from typing import NamedTuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
from torch import Tensor
from transformers import AutoModel, PreTrainedModel

from .config import MagicBERTConfig


class HungarianTokenLoss(nn.Module):
    """
    Permutation-invariant token classification loss using Hungarian matching.

    logits: (B, N, V) - N slot queries, V vocab
    targets: (B, M) - M target token ids (unordered multiset)
    target_mask: (B, M) bool/0-1 mask; True for valid targets, False for padding (optional)
    """

    def __init__(self, reduction: str = "mean", label_smoothing: float = 0.0):
        super().__init__()
        if reduction not in {"mean", "sum", "none"}:
            raise ValueError("reduction must be one of: mean, sum, none")
        if not (0.0 <= label_smoothing < 1.0):
            raise ValueError("label_smoothing must be in [0, 1)")
        self.reduction = reduction
        self.label_smoothing = float(label_smoothing)

    def forward(
        self,
        logits: torch.Tensor,
        targets: torch.Tensor,
        *,
        target_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if logits.dim() != 3:
            raise ValueError("logits must be (B, N, V)")
        if targets.dim() != 2:
            raise ValueError("targets must be (B, M)")
        if logits.size(0) != targets.size(0):
            raise ValueError("batch size mismatch between logits and targets")

        B, N, V = logits.shape
        _, M = targets.shape

        if target_mask is not None:
            if target_mask.shape != targets.shape:
                raise ValueError("target_mask must have same shape as targets (B, M)")
            valid_mask = target_mask.bool()
        else:
            valid_mask = torch.ones_like(targets, dtype=torch.bool)

        log_probs = F.log_softmax(logits, dim=-1)  # (B, N, V)

        batch_losses: list[torch.Tensor] = []
        for b in range(B):
            # Select valid targets for this sample: ids shape (m,)
            ids = targets[b][valid_mask[b]]
            m = int(ids.numel())
            if m == 0 or N == 0:
                # No targets or no predictions -> zero loss (keeps the graph connected)
                batch_losses.append(log_probs[b].sum() * 0.0)
                continue

            # Cost matrix: (N, m) where cost[i, j] = -log p_i(ids[j])
            # Gather: log_probs[b] is (N, V), ids is (m,) -> result (N, m)
            lp = log_probs[b]  # (N, V)
            cost = -lp[:, ids]  # (N, m)

            # Hungarian assignment (CPU, non-differentiable)
            row_ind, col_ind = linear_sum_assignment(cost.detach().cpu().numpy())
            row = torch.tensor(row_ind, device=logits.device, dtype=torch.long)
            col = torch.tensor(col_ind, device=logits.device, dtype=torch.long)
            matched_cost = cost[row, col]  # (k,) where k = min(N, m)

            # Optional label smoothing, applied only on matched pairs
            if self.label_smoothing > 0.0:
                # nll for matched pairs is matched_cost;
                # smooth loss is -mean log_probs over vocab
                matched_lp = lp[row]  # (k, V)
                smooth = -matched_lp.mean(dim=-1)  # (k,)
                eps = self.label_smoothing
                matched_cost = (1.0 - eps) * matched_cost + eps * smooth

            if self.reduction == "sum":
                batch_losses.append(matched_cost.sum())
            else:
                batch_losses.append(matched_cost.mean())

        out = torch.stack(batch_losses) if batch_losses else torch.tensor(0.0, device=logits.device)
        if self.reduction == "none":
            return out
        if self.reduction == "sum":
            return out.sum()
        return out.mean()
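
# A minimal usage sketch of the loss above; shapes and values are illustrative,
# not taken from any real config:
#
#   loss_fn = HungarianTokenLoss(label_smoothing=0.1)
#   logits = torch.randn(2, 8, 100)                    # (B=2, N=8 slots, V=100)
#   targets = torch.randint(0, 100, (2, 5))            # (B=2, M=5 unordered ids)
#   mask = torch.ones(2, 5, dtype=torch.bool)          # all five targets valid
#   loss = loss_fn(logits, targets, target_mask=mask)  # scalar (reduction="mean")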

class MagicBERTOutput(NamedTuple):
    logits: Tensor  # (B, seq_len, vocab_size)
    loss: Tensor | None  # scalar, present when target_ids were supplied


class MagicBERTModel(nn.Module):
    def __init__(
        self,
        *,
        attention_dropout: float,
        d_model: int,
        dim_feed_forward: int,
        embedding_dropout: float,
        mask_token_id: int,
        num_attention_heads: int,
        num_encoder_layers: int,
        pad_token_id: int,
        seq_len: int,
        tie_embeddings: bool,
        vocab_size: int,
    ):
        super().__init__()
        self.seq_len = seq_len
        self.tie_embeddings = tie_embeddings
        self.pad_token_id = pad_token_id
        self.mask_token_id = mask_token_id
        self.semantic_E = nn.Embedding(vocab_size, d_model)
        self.pos_E = nn.Embedding(seq_len, d_model)
        self.embedding_dropout = nn.Dropout(embedding_dropout)
        self.context_scale = nn.Parameter(torch.ones(1))
        self.encoder_layers = nn.ModuleList(
            [
                nn.TransformerEncoderLayer(
                    batch_first=True,
                    d_model=d_model,
                    dim_feedforward=dim_feed_forward,
                    dropout=attention_dropout,
                    nhead=num_attention_heads,
                )
                for _ in range(num_encoder_layers)
            ]
        )
        self.context_query_norms = nn.ModuleList(
            [nn.LayerNorm(d_model) for _ in range(num_encoder_layers)]
        )
        self.context_kv_norms = nn.ModuleList(
            [nn.LayerNorm(d_model) for _ in range(num_encoder_layers)]
        )
        self.context_attention_layers = nn.ModuleList(
            [
                nn.MultiheadAttention(
                    embed_dim=d_model,
                    num_heads=num_attention_heads,
                    dropout=attention_dropout,
                    batch_first=True,
                )
                for _ in range(num_encoder_layers)
            ]
        )
        self.layer_norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self.loss_fn = HungarianTokenLoss()
        if tie_embeddings:
            self.tie_weights()

    def _attention_mask(self, input_ids: Tensor, attention_mask: Tensor | None) -> Tensor:
        if attention_mask is not None:
            if attention_mask.shape != input_ids.shape:
                raise ValueError("attention_mask must have the same shape as input_ids")
            return attention_mask.bool()
        return input_ids.ne(self.pad_token_id)

    def forward(
        self,
        *,
        input_ids: Tensor,
        attention_mask: Tensor | None = None,
        context_ids: Tensor,
        context_attention_mask: Tensor | None = None,
        target_ids: Tensor | None = None,
        target_attention_mask: Tensor | None = None,
    ) -> MagicBERTOutput:
        if input_ids.dim() != 2:
            raise ValueError("input_ids must be of shape (batch, seq_len)")
        if input_ids.size(0) == 0:
            raise ValueError("input_ids batch dimension must be > 0")
        if input_ids.size(1) > self.seq_len:
            raise ValueError("input_ids length exceeds the model's positional table (seq_len)")
        if context_ids.size(0) != input_ids.size(0):
            raise ValueError("context_ids batch dimension must match input_ids")
        if context_attention_mask is None:
            context_attention_mask = context_ids.ne(self.pad_token_id)
        if context_attention_mask.shape != context_ids.shape:
            raise ValueError("context_attention_mask must have the same shape as context_ids")

        padding_mask = ~self._attention_mask(input_ids, attention_mask)
        positions = torch.arange(input_ids.size(1), device=input_ids.device).unsqueeze(0)
        src_embeddings = self.embedding_dropout(self.semantic_E(input_ids) + self.pos_E(positions))
        context_embeddings = self.semantic_E(context_ids)
        context_embeddings = self.embedding_dropout(context_embeddings)
        context_padding_mask = ~context_attention_mask.bool()

        encoded = src_embeddings
        for idx, layer in enumerate(self.encoder_layers):
            # Self-attention over the (partially masked) deck slots ...
            encoded = layer(encoded, src_key_padding_mask=padding_mask)
            # ... followed by a gated cross-attention read of the context tokens.
            norm_encoded = self.context_query_norms[idx](encoded)
            norm_context = self.context_kv_norms[idx](context_embeddings)
            attn_output, _ = self.context_attention_layers[idx](
                norm_encoded,
                norm_context,
                norm_context,
                key_padding_mask=context_padding_mask,
                need_weights=False,
            )
            encoded = encoded + self.context_scale * attn_output

        encoded = self.layer_norm(encoded)
        logits = self.lm_head(encoded)

        loss = None
        if target_ids is not None:
            loss = self.loss_fn(logits, target_ids, target_mask=target_attention_mask)
        return MagicBERTOutput(logits=logits, loss=loss)

    def tie_weights(self, **kwargs) -> None:
        if self.tie_embeddings:
            self.lm_head.weight = self.semantic_E.weight
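
# A minimal instantiation sketch for the bare model; every hyperparameter below
# is an illustrative placeholder, not a value from MagicBERTConfig:
#
#   model = MagicBERTModel(
#       attention_dropout=0.1, d_model=256, dim_feed_forward=1024,
#       embedding_dropout=0.1, mask_token_id=1, num_attention_heads=8,
#       num_encoder_layers=4, pad_token_id=0, seq_len=100,
#       tie_embeddings=True, vocab_size=30_000,
#   )
#   ids = torch.randint(2, 30_000, (2, 100))  # avoid pad/mask ids for the demo
#   out = model(input_ids=ids, context_ids=ids)
#   out.logits.shape  # (2, 100, 30_000); out.loss is None without target_ids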

class MagicBERT(PreTrainedModel):
    config_class = MagicBERTConfig
    _tied_weights_keys = {"model.lm_head.weight": "model.semantic_E.weight"}

    def __init__(self, config: MagicBERTConfig):
        super().__init__(config)
        self.model = MagicBERTModel(
            attention_dropout=config.attention_dropout,
            d_model=config.d_model,
            dim_feed_forward=config.dim_feed_forward,
            embedding_dropout=config.embedding_dropout,
            mask_token_id=config.mask_token_id,
            num_attention_heads=config.num_attention_heads,
            num_encoder_layers=config.num_encoder_layers,
            pad_token_id=config.pad_token_id,  # type: ignore
            seq_len=config.seq_len,
            tie_embeddings=config.tie_embeddings,
            vocab_size=config.vocab_size,
        )
        self.post_init()

    def tie_weights(self, **kwargs) -> None:  # type: ignore
        if self.config.tie_embeddings:
            self.model.tie_weights()

    def get_input_embeddings(self) -> nn.Module:
        return self.model.semantic_E

    def set_input_embeddings(self, value: nn.Module):
        self.model.semantic_E = value
        if self.config.tie_embeddings:
            self.tie_weights()

    def get_output_embeddings(self) -> nn.Module:
        return self.model.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Module):
        self.model.lm_head = new_embeddings
        if self.config.tie_embeddings:
            self.tie_weights()

    def forward(
        self,
        *,
        input_ids: Tensor,
        attention_mask: Tensor | None = None,
        context_ids: Tensor,
        context_attention_mask: Tensor | None = None,
        target_ids: Tensor | None = None,
        target_attention_mask: Tensor | None = None,
    ) -> MagicBERTOutput:
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            context_ids=context_ids,
            context_attention_mask=context_attention_mask,
            target_ids=target_ids,
            target_attention_mask=target_attention_mask,
        )

    def _build_legal_token_mask(
        self,
        *,
        device: torch.device,
        cards: list[dict[str, object]],
    ) -> Tensor:
        legal_token_mask = torch.zeros(self.config.vocab_size, device=device, dtype=torch.bool)
        legal_token_mask[self.config.pad_token_id] = True
        legal_token_mask[self.config.mask_token_id] = True
        for card in cards:
            if card.get("commander_legal"):
                token_id = card.get("token_id")
                if isinstance(token_id, int) and 0 <= token_id < self.config.vocab_size:
                    legal_token_mask[token_id] = True
        return legal_token_mask

    def _build_basic_token_mask(
        self,
        *,
        device: torch.device,
        cards: list[dict[str, object]],
    ) -> Tensor:
        basic_token_mask = torch.zeros(self.config.vocab_size, device=device, dtype=torch.bool)
        for card in cards:
            token_id = card.get("token_id")
            type_line = card.get("type_line", "")
            if isinstance(token_id, int) and 0 <= token_id < self.config.vocab_size:
                if isinstance(type_line, str) and "Basic" in type_line:
                    basic_token_mask[token_id] = True
        return basic_token_mask
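
    # The two mask builders above consume generation_config.cards, a list of
    # dicts. A hypothetical entry (field names follow the .get() calls above;
    # the concrete card is only an example):
    #
    #   {
    #       "token_id": 1234,                    # vocab row, 0 <= id < vocab_size
    #       "commander_legal": True,             # included in the legality mask
    #       "type_line": "Basic Land — Island",  # "Basic" marks it repeatable
    #   }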

    @torch.no_grad()
    def generate(
        self,
        input_ids: Tensor,
        *,
        context_ids: Tensor | None = None,
        context_attention_mask: Tensor | None = None,
    ) -> Tensor:
        cards = getattr(self.generation_config, "cards", None)
        if not cards:
            raise ValueError("generation_config.cards is required for legality masking")

        pad_token_id: int = self.config.pad_token_id  # type: ignore
        mask_token_id: int = self.config.mask_token_id
        if context_ids is None:
            context_ids = input_ids.masked_fill(input_ids.eq(pad_token_id), mask_token_id)

        legal_token_mask = self._build_legal_token_mask(device=input_ids.device, cards=cards)
        basic_token_mask = self._build_basic_token_mask(device=input_ids.device, cards=cards)

        output = self(
            input_ids=input_ids,
            context_ids=context_ids,
            context_attention_mask=context_attention_mask,
        )
        logits = output.logits  # (B, seq_len, V)
        logits = logits.masked_fill(~legal_token_mask, -1e9)
        B, num_slots, V = logits.shape
        log_probs = F.log_softmax(logits, dim=-1)

        # Column pool: non-basics appear once (singleton), basics appear num_slots times
        legal_non_basic = legal_token_mask & ~basic_token_mask
        legal_non_basic[pad_token_id] = False
        legal_non_basic[mask_token_id] = False
        non_basic_ids = legal_non_basic.nonzero(as_tuple=False).flatten().tolist()
        basic_ids = basic_token_mask.nonzero(as_tuple=False).flatten().tolist()
        col_ids: list[int] = non_basic_ids + basic_ids * num_slots
        col_ids_t = torch.tensor(col_ids, device=logits.device, dtype=torch.long)

        result = torch.full((B, num_slots), pad_token_id, device=logits.device, dtype=torch.long)
        for b in range(B):
            cost = -log_probs[b][:, col_ids_t]  # (num_slots, num_cols)
            row_ind, col_ind = linear_sum_assignment(cost.cpu().numpy())
            rows = torch.tensor(row_ind, device=logits.device, dtype=torch.long)
            result[b, rows] = col_ids_t[torch.tensor(col_ind, device=logits.device)]
        return result

    @torch.no_grad()
    def iterative_generate(
        self,
        input_ids: Tensor,
        *,
        context_ids: Tensor | None = None,
        context_attention_mask: Tensor | None = None,
        steps: int = 5,
        remask_ratio: float = 0.3,
    ) -> list[Tensor]:
        """Iteratively generate a deck, remasking low-confidence slots between steps.

        Returns a list of token_id tensors, one per step (each shape (B, num_slots)).
        """
        cards = getattr(self.generation_config, "cards", None)
        if not cards:
            raise ValueError("generation_config.cards is required for legality masking")

        pad_token_id: int = self.config.pad_token_id  # type: ignore
        mask_token_id: int = self.config.mask_token_id
        if context_ids is None:
            context_ids = input_ids.masked_fill(input_ids.eq(pad_token_id), mask_token_id)

        legal_token_mask = self._build_legal_token_mask(device=input_ids.device, cards=cards)
        basic_token_mask = self._build_basic_token_mask(device=input_ids.device, cards=cards)
        legal_non_basic = legal_token_mask & ~basic_token_mask
        legal_non_basic[pad_token_id] = False
        legal_non_basic[mask_token_id] = False
        non_basic_ids = legal_non_basic.nonzero(as_tuple=False).flatten().tolist()
        basic_ids = basic_token_mask.nonzero(as_tuple=False).flatten().tolist()

        x = input_ids.clone()
        B, num_slots = x.shape
        col_ids: list[int] = non_basic_ids + basic_ids * num_slots
        col_ids_t = torch.tensor(col_ids, device=x.device, dtype=torch.long)

        all_steps: list[Tensor] = []
        for step in range(steps):
            is_last = step == steps - 1
            output = self(
                input_ids=x,
                context_ids=context_ids,
                context_attention_mask=context_attention_mask,
            )
            logits = output.logits.masked_fill(~legal_token_mask, -1e9)
            log_probs = F.log_softmax(logits, dim=-1)

            result = torch.full((B, num_slots), pad_token_id, device=x.device, dtype=torch.long)
            confidence = torch.full((B, num_slots), float("-inf"), device=x.device)
            for b in range(B):
                cost = -log_probs[b][:, col_ids_t]
                row_ind, col_ind = linear_sum_assignment(cost.cpu().numpy())
                rows = torch.tensor(row_ind, device=x.device, dtype=torch.long)
                cols = torch.tensor(col_ind, device=x.device, dtype=torch.long)
                result[b, rows] = col_ids_t[cols]
                confidence[b, rows] = -cost[rows, cols]

            all_steps.append(result.clone())
            if is_last or remask_ratio <= 0.0:
                x = result
                continue

            # Remask the lowest-confidence slots so the next step can revise them.
            x = result.clone()
            for b in range(B):
                filled = result[b].ne(pad_token_id).nonzero(as_tuple=False).flatten()
                n_remask = max(0, int(filled.numel() * remask_ratio))
                if n_remask == 0:
                    continue
                _, worst = torch.topk(confidence[b, filled], k=n_remask, largest=False)
                x[b, filled[worst]] = mask_token_id
        return all_steps


MagicBERT.register_for_auto_class(AutoModel)
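
# A hedged end-to-end sketch; it assumes MagicBERTConfig accepts the same field
# names MagicBERTModel takes, and that generation_config is populated:
#
#   model = MagicBERT(config)              # config: a MagicBERTConfig instance
#   model.generation_config.cards = cards  # card dicts shaped as sketched above
#   prompt = torch.full(
#       (1, model.config.seq_len), model.config.mask_token_id, dtype=torch.long
#   )
#   deck = model.generate(prompt)                       # (1, seq_len), one-shot matching
#   drafts = model.iterative_generate(prompt, steps=5)  # five (1, seq_len) refinements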