from typing import NamedTuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
from torch import Tensor
from transformers import AutoModel, PreTrainedModel

from .config import MagicBERTConfig


class HungarianTokenLoss(nn.Module):
    """
    Permutation-invariant token classification loss using Hungarian matching.

    logits: (B, N, V) - N slot queries, V vocab
    targets: (B, M) - M target token ids (unordered multiset)
    target_mask: (B, M) bool/0-1 mask; True for valid targets, False for padding (optional)
    """

    def __init__(self, reduction: str = "mean", label_smoothing: float = 0.0):
        super().__init__()
        if reduction not in {"mean", "sum", "none"}:
            raise ValueError("reduction must be one of: mean, sum, none")
        if not (0.0 <= label_smoothing < 1.0):
            raise ValueError("label_smoothing must be in [0, 1)")
        self.reduction = reduction
        self.label_smoothing = float(label_smoothing)

    def forward(
        self,
        logits: torch.Tensor,
        targets: torch.Tensor,
        *,
        target_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if logits.dim() != 3:
            raise ValueError("logits must be (B, N, V)")
        if targets.dim() != 2:
            raise ValueError("targets must be (B, M)")
        if logits.size(0) != targets.size(0):
            raise ValueError("batch size mismatch between logits and targets")

        B, N, V = logits.shape
        _, M = targets.shape

        if target_mask is not None:
            if target_mask.shape != targets.shape:
                raise ValueError("target_mask must have same shape as targets (B, M)")
            valid_mask = target_mask.bool()
        else:
            valid_mask = torch.ones_like(targets, dtype=torch.bool)

        log_probs = F.log_softmax(logits, dim=-1)

        batch_losses: list[torch.Tensor] = []
        for b in range(B):
            # Keep only the valid (non-padding) target ids for this example.
            ids = targets[b][valid_mask[b]]
            m = int(ids.numel())
            if m == 0 or N == 0:
                # No valid targets: contribute a zero loss that stays connected to the graph.
                batch_losses.append(log_probs[b].sum() * 0.0)
                continue

            # Cost matrix (N slots x m targets): negative log-probability of placing
            # target ids[j] in slot i.
            lp = log_probs[b]
            cost = -lp[:, ids]

            # Hungarian matching runs on CPU/NumPy; gradients flow through the gather below.
            row_ind, col_ind = linear_sum_assignment(cost.detach().cpu().numpy())

            row = torch.tensor(row_ind, device=logits.device, dtype=torch.long)
            col = torch.tensor(col_ind, device=logits.device, dtype=torch.long)

            matched_cost = cost[row, col]

            # Optional label smoothing: blend the matched NLL with a uniform-over-vocab term.
            if self.label_smoothing > 0.0:
                matched_lp = lp[row]
                smooth = -matched_lp.mean(dim=-1)
                eps = self.label_smoothing
                matched_cost = (1.0 - eps) * matched_cost + eps * smooth

            if self.reduction == "sum":
                batch_losses.append(matched_cost.sum())
            else:
                batch_losses.append(matched_cost.mean())

        out = torch.stack(batch_losses) if batch_losses else torch.tensor(0.0, device=logits.device)

        if self.reduction == "none":
            return out
        if self.reduction == "sum":
            return out.sum()
        return out.mean()
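
# Usage sketch (illustrative only; shapes and values are made up). Because matching is
# permutation-invariant, shuffling `targets` along dim=1 should not change the loss:
#
#   loss_fn = HungarianTokenLoss(reduction="mean", label_smoothing=0.1)
#   logits = torch.randn(2, 8, 32000)            # (B, N slot queries, V vocab)
#   targets = torch.randint(0, 32000, (2, 5))    # (B, M unordered target ids)
#   mask = torch.ones(2, 5, dtype=torch.bool)    # all targets valid
#   loss = loss_fn(logits, targets, target_mask=mask)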


class MagicBERTOutput(NamedTuple):
    logits: Tensor
    loss: Tensor | None


class MagicBERTModel(nn.Module):
    def __init__(
        self,
        *,
        attention_dropout: float,
        d_model: int,
        dim_feed_forward: int,
        embedding_dropout: float,
        mask_token_id: int,
        num_attention_heads: int,
        num_encoder_layers: int,
        pad_token_id: int,
        seq_len: int,
        tie_embeddings: bool,
        vocab_size: int,
    ):
        super().__init__()
        self.seq_len = seq_len
        self.tie_embeddings = tie_embeddings
        self.pad_token_id = pad_token_id
        self.mask_token_id = mask_token_id

        # Token and position embeddings; the learned scale gates the context-attention residual.
        self.semantic_E = nn.Embedding(vocab_size, d_model)
        self.pos_E = nn.Embedding(seq_len, d_model)
        self.embedding_dropout = nn.Dropout(embedding_dropout)
        self.context_scale = nn.Parameter(torch.ones(1))

        # Self-attention encoder stack over the slot sequence.
        self.encoder_layers = nn.ModuleList(
            [
                nn.TransformerEncoderLayer(
                    batch_first=True,
                    d_model=d_model,
                    dim_feedforward=dim_feed_forward,
                    dropout=attention_dropout,
                    nhead=num_attention_heads,
                )
                for _ in range(num_encoder_layers)
            ]
        )

        # Per-layer cross-attention from slot states (queries) into the context tokens (keys/values).
        self.context_query_norms = nn.ModuleList(
            [nn.LayerNorm(d_model) for _ in range(num_encoder_layers)]
        )
        self.context_kv_norms = nn.ModuleList(
            [nn.LayerNorm(d_model) for _ in range(num_encoder_layers)]
        )
        self.context_attention_layers = nn.ModuleList(
            [
                nn.MultiheadAttention(
                    embed_dim=d_model,
                    num_heads=num_attention_heads,
                    dropout=attention_dropout,
                    batch_first=True,
                )
                for _ in range(num_encoder_layers)
            ]
        )
        self.layer_norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self.loss_fn = HungarianTokenLoss()
        if tie_embeddings:
            self.tie_weights()

    def _attention_mask(self, input_ids: Tensor, attention_mask: Tensor | None) -> Tensor:
        if attention_mask is not None:
            if attention_mask.shape != input_ids.shape:
                raise ValueError("attention_mask must have the same shape as input_ids")
            return attention_mask.bool()
        return input_ids.ne(self.pad_token_id)

    def forward(
        self,
        *,
        input_ids: Tensor,
        attention_mask: Tensor | None = None,
        context_ids: Tensor,
        context_attention_mask: Tensor | None = None,
        target_ids: Tensor | None = None,
        target_attention_mask: Tensor | None = None,
    ) -> MagicBERTOutput:
        if input_ids.dim() != 2:
            raise ValueError("input_ids must be of shape (batch, seq_len)")
        if input_ids.size(0) == 0:
            raise ValueError("input_ids batch dimension must be > 0")
        if context_ids.size(0) != input_ids.size(0):
            raise ValueError("context_ids batch dimension must match input_ids")
        if context_attention_mask is None:
            context_attention_mask = context_ids.ne(self.pad_token_id)
        if context_attention_mask.shape != context_ids.shape:
            raise ValueError("context_attention_mask must have the same shape as context_ids")

        # Slot embeddings: token + learned position, with dropout.
        padding_mask = ~self._attention_mask(input_ids, attention_mask)
        positions = torch.arange(input_ids.size(1), device=input_ids.device).unsqueeze(0)
        src_embeddings = self.embedding_dropout(self.semantic_E(input_ids) + self.pos_E(positions))

        # Context tokens share the semantic embedding table but carry no position information.
        context_embeddings = self.semantic_E(context_ids)
        context_embeddings = self.embedding_dropout(context_embeddings)

        # MultiheadAttention expects True where keys should be ignored.
        context_padding_mask = ~context_attention_mask.bool()

        # Alternate self-attention over the slots with cross-attention into the context.
        encoded = src_embeddings
        for idx, layer in enumerate(self.encoder_layers):
            encoded = layer(encoded, src_key_padding_mask=padding_mask)
            norm_encoded = self.context_query_norms[idx](encoded)
            norm_context = self.context_kv_norms[idx](context_embeddings)
            attn_output, _ = self.context_attention_layers[idx](
                norm_encoded,
                norm_context,
                norm_context,
                key_padding_mask=context_padding_mask,
                need_weights=False,
            )
            encoded = encoded + self.context_scale * attn_output

        encoded = self.layer_norm(encoded)
        logits = self.lm_head(encoded)

        loss = None
        if target_ids is not None:
            loss = self.loss_fn(logits, target_ids, target_mask=target_attention_mask)

        return MagicBERTOutput(logits=logits, loss=loss)

    def tie_weights(self, **kwargs) -> None:
        if self.tie_embeddings:
            self.lm_head.weight = self.semantic_E.weight
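
# Minimal forward-pass sketch (illustrative only; the small hyperparameters below are
# arbitrary placeholders, not project defaults):
#
#   model = MagicBERTModel(
#       attention_dropout=0.1, d_model=64, dim_feed_forward=128,
#       embedding_dropout=0.1, mask_token_id=1, num_attention_heads=4,
#       num_encoder_layers=2, pad_token_id=0, seq_len=16,
#       tie_embeddings=True, vocab_size=1000,
#   )
#   input_ids = torch.full((2, 16), 1, dtype=torch.long)   # all-[MASK] deck slots
#   context_ids = torch.randint(2, 1000, (2, 8))            # conditioning tokens
#   target_ids = torch.randint(2, 1000, (2, 16))            # unordered target deck
#   out = model(input_ids=input_ids, context_ids=context_ids, target_ids=target_ids)
#   out.logits.shape  # (2, 16, 1000); out.loss is a scalar Hungarian-matching loss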


class MagicBERT(PreTrainedModel):
    config_class = MagicBERTConfig
    _tied_weights_keys = {"model.lm_head.weight": "model.semantic_E.weight"}

    def __init__(self, config: MagicBERTConfig):
        super().__init__(config)
        self.model = MagicBERTModel(
            attention_dropout=config.attention_dropout,
            d_model=config.d_model,
            dim_feed_forward=config.dim_feed_forward,
            embedding_dropout=config.embedding_dropout,
            mask_token_id=config.mask_token_id,
            num_attention_heads=config.num_attention_heads,
            num_encoder_layers=config.num_encoder_layers,
            pad_token_id=config.pad_token_id,
            seq_len=config.seq_len,
            tie_embeddings=config.tie_embeddings,
            vocab_size=config.vocab_size,
        )
        self.post_init()

    def tie_weights(self, **kwargs) -> None:
        if self.config.tie_embeddings:
            self.model.tie_weights()

    def get_input_embeddings(self) -> nn.Module:
        return self.model.semantic_E

    def set_input_embeddings(self, value: nn.Module):
        self.model.semantic_E = value
        if self.config.tie_embeddings:
            self.tie_weights()

    def get_output_embeddings(self) -> nn.Module:
        return self.model.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Module):
        self.model.lm_head = new_embeddings
        if self.config.tie_embeddings:
            self.tie_weights()

    def forward(
        self,
        *,
        input_ids: Tensor,
        attention_mask: Tensor | None = None,
        context_ids: Tensor,
        context_attention_mask: Tensor | None = None,
        target_ids: Tensor | None = None,
        target_attention_mask: Tensor | None = None,
    ) -> MagicBERTOutput:
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            context_ids=context_ids,
            context_attention_mask=context_attention_mask,
            target_ids=target_ids,
            target_attention_mask=target_attention_mask,
        )

    def _build_legal_token_mask(
        self,
        *,
        device: torch.device,
        cards: list[dict[str, object]],
    ) -> Tensor:
        # True for token ids the decoder may emit: pad, mask, and any Commander-legal card.
        legal_token_mask = torch.zeros(self.config.vocab_size, device=device, dtype=torch.bool)
        legal_token_mask[self.config.pad_token_id] = True
        legal_token_mask[self.config.mask_token_id] = True
        for card in cards:
            if card.get("commander_legal"):
                token_id = card.get("token_id")
                if isinstance(token_id, int) and 0 <= token_id < self.config.vocab_size:
                    legal_token_mask[token_id] = True
        return legal_token_mask

    def _build_basic_token_mask(
        self,
        *,
        device: torch.device,
        cards: list[dict[str, object]],
    ) -> Tensor:
        # True for token ids whose type line marks them as basic lands.
        basic_token_mask = torch.zeros(self.config.vocab_size, device=device, dtype=torch.bool)
        for card in cards:
            token_id = card.get("token_id")
            type_line = card.get("type_line", "")
            if isinstance(token_id, int) and 0 <= token_id < self.config.vocab_size:
                if isinstance(type_line, str) and "Basic" in type_line:
                    basic_token_mask[token_id] = True
        return basic_token_mask
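
    # The `cards` records are only required to carry the fields read above: `token_id`,
    # `commander_legal`, and `type_line`. A sketch of one entry (values are illustrative,
    # and any extra fields are ignored):
    #
    #   {"token_id": 1234, "commander_legal": True, "type_line": "Basic Land"}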

    @torch.no_grad()
    def generate(
        self,
        input_ids: Tensor,
        *,
        context_ids: Tensor | None = None,
        context_attention_mask: Tensor | None = None,
    ) -> Tensor:
        cards = getattr(self.generation_config, "cards", None)
        if not cards:
            raise ValueError("generation_config.cards is required for legality masking")

        pad_token_id: int = self.config.pad_token_id
        mask_token_id: int = self.config.mask_token_id

        if context_ids is None:
            context_ids = input_ids.masked_fill(input_ids.eq(pad_token_id), mask_token_id)

        legal_token_mask = self._build_legal_token_mask(device=input_ids.device, cards=cards)
        basic_token_mask = self._build_basic_token_mask(device=input_ids.device, cards=cards)

        output = self(
            input_ids=input_ids,
            context_ids=context_ids,
            context_attention_mask=context_attention_mask,
        )
        logits = output.logits
        logits = logits.masked_fill(~legal_token_mask, -1e9)

        B, num_slots, V = logits.shape
        log_probs = F.log_softmax(logits, dim=-1)

        # Build the Hungarian columns: each legal non-basic card appears once (so it can be
        # picked at most once), while each basic land is repeated num_slots times so it can
        # fill any number of slots.
        legal_non_basic = legal_token_mask & ~basic_token_mask
        legal_non_basic[pad_token_id] = False
        legal_non_basic[mask_token_id] = False
        non_basic_ids = legal_non_basic.nonzero(as_tuple=False).flatten().tolist()
        basic_ids = basic_token_mask.nonzero(as_tuple=False).flatten().tolist()
        col_ids: list[int] = non_basic_ids + basic_ids * num_slots
        col_ids_t = torch.tensor(col_ids, device=logits.device, dtype=torch.long)

        # Assign each slot a distinct column (card copy) by maximising total log-probability.
        result = torch.full((B, num_slots), pad_token_id, device=logits.device, dtype=torch.long)
        for b in range(B):
            cost = -log_probs[b][:, col_ids_t]
            row_ind, col_ind = linear_sum_assignment(cost.cpu().numpy())
            rows = torch.tensor(row_ind, device=logits.device, dtype=torch.long)
            result[b, rows] = col_ids_t[torch.tensor(col_ind, device=logits.device)]

        return result
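
    # Generation sketch (illustrative only; `cards` must be attached to the generation config
    # by the caller, and the all-[MASK] prompt below is just one way to seed the slots):
    #
    #   model.generation_config.cards = cards    # list of card dicts as described above
    #   prompt = torch.full((1, model.config.seq_len), model.config.mask_token_id, dtype=torch.long)
    #   deck_ids = model.generate(prompt)        # (1, seq_len) token ids; only basics may repeat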

    @torch.no_grad()
    def iterative_generate(
        self,
        input_ids: Tensor,
        *,
        context_ids: Tensor | None = None,
        context_attention_mask: Tensor | None = None,
        steps: int = 5,
        remask_ratio: float = 0.3,
    ) -> list[Tensor]:
        """Iteratively generate a deck, remasking low-confidence slots between steps.

        Returns a list of token_id tensors, one per step (each of shape (B, num_slots)).
        """
        cards = getattr(self.generation_config, "cards", None)
        if not cards:
            raise ValueError("generation_config.cards is required for legality masking")

        pad_token_id: int = self.config.pad_token_id
        mask_token_id: int = self.config.mask_token_id

        if context_ids is None:
            context_ids = input_ids.masked_fill(input_ids.eq(pad_token_id), mask_token_id)

        legal_token_mask = self._build_legal_token_mask(device=input_ids.device, cards=cards)
        basic_token_mask = self._build_basic_token_mask(device=input_ids.device, cards=cards)

        # Same column construction as generate(): non-basics once, basics repeated per slot.
        legal_non_basic = legal_token_mask & ~basic_token_mask
        legal_non_basic[pad_token_id] = False
        legal_non_basic[mask_token_id] = False
        non_basic_ids = legal_non_basic.nonzero(as_tuple=False).flatten().tolist()
        basic_ids = basic_token_mask.nonzero(as_tuple=False).flatten().tolist()

        x = input_ids.clone()
        B, num_slots = x.shape
        col_ids: list[int] = non_basic_ids + basic_ids * num_slots
        col_ids_t = torch.tensor(col_ids, device=x.device, dtype=torch.long)

        all_steps: list[Tensor] = []

        for step in range(steps):
            is_last = step == steps - 1

            output = self(
                input_ids=x,
                context_ids=context_ids,
                context_attention_mask=context_attention_mask,
            )
            logits = output.logits.masked_fill(~legal_token_mask, -1e9)
            log_probs = F.log_softmax(logits, dim=-1)

            result = torch.full((B, num_slots), pad_token_id, device=x.device, dtype=torch.long)
            confidence = torch.full((B, num_slots), float("-inf"), device=x.device)

            for b in range(B):
                cost = -log_probs[b][:, col_ids_t]
                row_ind, col_ind = linear_sum_assignment(cost.cpu().numpy())
                rows = torch.tensor(row_ind, device=x.device, dtype=torch.long)
                cols = torch.tensor(col_ind, device=x.device, dtype=torch.long)
                result[b, rows] = col_ids_t[cols]
                # Confidence is the log-probability of the token assigned to each slot.
                confidence[b, rows] = -cost[rows, cols]

            all_steps.append(result.clone())

            if is_last or remask_ratio <= 0.0:
                x = result
                continue

            # Remask the lowest-confidence filled slots so the next pass can revise them.
            x = result.clone()
            for b in range(B):
                filled = result[b].ne(pad_token_id).nonzero(as_tuple=False).flatten()
                n_remask = max(0, int(filled.numel() * remask_ratio))
                if n_remask == 0:
                    continue
                _, worst = torch.topk(confidence[b, filled], k=n_remask, largest=False)
                x[b, filled[worst]] = mask_token_id

        return all_steps
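
# Iterative-refinement sketch (illustrative only): every step produces a full assignment,
# and the lowest-confidence fraction of slots is remasked before the next pass.
#
#   model.generation_config.cards = cards
#   prompt = torch.full((1, model.config.seq_len), model.config.mask_token_id, dtype=torch.long)
#   drafts = model.iterative_generate(prompt, steps=5, remask_ratio=0.3)
#   final_deck = drafts[-1]    # (1, seq_len) token ids after the last refinement step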


MagicBERT.register_for_auto_class(AutoModel)
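
# With the auto-class registration above, a checkpoint written via save_pretrained() can be
# reloaded through the Auto API (sketch; the path is a placeholder, and loading custom code
# from the Hub requires trust_remote_code=True):
#
#   from transformers import AutoModel
#   model = AutoModel.from_pretrained("path/to/magicbert-checkpoint", trust_remote_code=True)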