# magicBERT / modeling.py
from typing import NamedTuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
from torch import Tensor
from transformers import AutoModel, PreTrainedModel
from .config import MagicBERTConfig
class HungarianTokenLoss(nn.Module):
"""
Permutation-invariant token classification loss using Hungarian matching.
logits: (B, N, V) - N slot queries, V vocab
targets: (B, M) - M target token ids (unordered multiset)
target_mask: (B, M) bool/0-1 mask; True for valid targets, False for padding (optional)
"""
def __init__(self, reduction: str = "mean", label_smoothing: float = 0.0):
super().__init__()
if reduction not in {"mean", "sum", "none"}:
raise ValueError("reduction must be one of: mean, sum, none")
if not (0.0 <= label_smoothing < 1.0):
raise ValueError("label_smoothing must be in [0, 1)")
self.reduction = reduction
self.label_smoothing = float(label_smoothing)
def forward(
self,
logits: torch.Tensor,
targets: torch.Tensor,
*,
target_mask: torch.Tensor | None = None,
) -> torch.Tensor:
if logits.dim() != 3:
raise ValueError("logits must be (B, N, V)")
if targets.dim() != 2:
raise ValueError("targets must be (B, M)")
if logits.size(0) != targets.size(0):
raise ValueError("batch size mismatch between logits and targets")
B, N, V = logits.shape
_, M = targets.shape
if target_mask is not None:
if target_mask.shape != targets.shape:
raise ValueError("target_mask must have same shape as targets (B, M)")
valid_mask = target_mask.bool()
else:
valid_mask = torch.ones_like(targets, dtype=torch.bool)
log_probs = F.log_softmax(logits, dim=-1) # (B, N, V)
batch_losses: list[torch.Tensor] = []
for b in range(B):
# Select valid targets for this sample: ids shape (m,)
ids = targets[b][valid_mask[b]]
m = int(ids.numel())
if m == 0 or N == 0:
# No targets or no predictions -> zero loss
batch_losses.append(log_probs[b].sum() * 0.0)
continue
# Cost matrix: (N, m) where cost[i, j] = -log p_i(ids[j])
# Gather: log_probs[b] is (N, V), ids is (m,) -> result (N, m)
lp = log_probs[b] # (N, V)
cost = -lp[:, ids] # (N, m)
# Hungarian assignment (CPU, non-differentiable)
row_ind, col_ind = linear_sum_assignment(cost.detach().cpu().numpy())
row = torch.tensor(row_ind, device=logits.device, dtype=torch.long)
col = torch.tensor(col_ind, device=logits.device, dtype=torch.long)
matched_cost = cost[row, col] # (k,) where k = min(N, m)
# Optional label smoothing, applied only on matched pairs
if self.label_smoothing > 0.0:
# nll for matched pairs is matched_cost
# smooth loss is -mean log_probs over vocab
matched_lp = lp[row] # (k, V)
smooth = -matched_lp.mean(dim=-1) # (k,)
eps = self.label_smoothing
matched_cost = (1.0 - eps) * matched_cost + eps * smooth
if self.reduction == "sum":
batch_losses.append(matched_cost.sum())
else:
batch_losses.append(matched_cost.mean())
out = torch.stack(batch_losses) if batch_losses else torch.tensor(0.0, device=logits.device)
if self.reduction == "none":
return out
if self.reduction == "sum":
return out.sum()
return out.mean()
class MagicBERTOutput(NamedTuple):
logits: Tensor # (B, seq_len, vocab_size)
loss: Tensor | None # scalar, present when target_ids were supplied
class MagicBERTModel(nn.Module):
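    """Bidirectional encoder with per-layer cross-attention to a context sequence.

    Each ``nn.TransformerEncoderLayer`` is followed by a multi-head cross-attention
    block whose queries come from the encoded sequence and whose keys/values come
    from the context embeddings; its output is added back through a learned
    ``context_scale`` gate. A linear LM head (optionally tied to the input
    embeddings) produces per-slot vocabulary logits, trained with
    ``HungarianTokenLoss``.
    """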
def __init__(
self,
*,
attention_dropout: float,
d_model: int,
dim_feed_forward: int,
embedding_dropout: float,
mask_token_id: int,
num_attention_heads: int,
num_encoder_layers: int,
pad_token_id: int,
seq_len: int,
tie_embeddings: bool,
vocab_size: int,
):
super().__init__()
self.seq_len = seq_len
self.tie_embeddings = tie_embeddings
self.pad_token_id = pad_token_id
self.mask_token_id = mask_token_id
self.semantic_E = nn.Embedding(vocab_size, d_model)
self.pos_E = nn.Embedding(seq_len, d_model)
self.embedding_dropout = nn.Dropout(embedding_dropout)
self.context_scale = nn.Parameter(torch.ones(1))
self.encoder_layers = nn.ModuleList(
[
nn.TransformerEncoderLayer(
batch_first=True,
d_model=d_model,
dim_feedforward=dim_feed_forward,
dropout=attention_dropout,
nhead=num_attention_heads,
)
for _ in range(num_encoder_layers)
]
)
self.context_query_norms = nn.ModuleList(
[nn.LayerNorm(d_model) for _ in range(num_encoder_layers)]
)
self.context_kv_norms = nn.ModuleList(
[nn.LayerNorm(d_model) for _ in range(num_encoder_layers)]
)
self.context_attention_layers = nn.ModuleList(
[
nn.MultiheadAttention(
embed_dim=d_model,
num_heads=num_attention_heads,
dropout=attention_dropout,
batch_first=True,
)
for _ in range(num_encoder_layers)
]
)
self.layer_norm = nn.LayerNorm(d_model)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
self.loss_fn = HungarianTokenLoss()
if tie_embeddings:
self.tie_weights()
def _attention_mask(self, input_ids: Tensor, attention_mask: Tensor | None) -> Tensor:
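        """Return a boolean mask that is True for valid (non-padding) positions."""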
if attention_mask is not None:
if attention_mask.shape != input_ids.shape:
raise ValueError("attention_mask must have the same shape as input_ids")
return attention_mask.bool()
return input_ids.ne(self.pad_token_id)
def forward(
self,
*,
input_ids: Tensor,
attention_mask: Tensor | None = None,
context_ids: Tensor,
context_attention_mask: Tensor | None = None,
target_ids: Tensor | None = None,
target_attention_mask: Tensor | None = None,
) -> MagicBERTOutput:
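        """Encode ``input_ids`` with cross-attention to ``context_ids``.

        Returns per-slot vocabulary logits of shape (B, seq_len, vocab_size) and,
        when ``target_ids`` (an unordered multiset of token ids) is supplied, the
        Hungarian-matching loss.
        """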
if input_ids.dim() != 2:
raise ValueError("input_ids must be of shape (batch, seq_len)")
if input_ids.size(0) == 0:
raise ValueError("input_ids batch dimension must be > 0")
if context_ids.size(0) != input_ids.size(0):
raise ValueError("context_ids batch dimension must match input_ids")
if context_attention_mask is None:
context_attention_mask = context_ids.ne(self.pad_token_id)
if context_attention_mask.shape != context_ids.shape:
raise ValueError("context_attention_mask must have the same shape as context_ids")
padding_mask = ~self._attention_mask(input_ids, attention_mask)
positions = torch.arange(input_ids.size(1), device=input_ids.device).unsqueeze(0)
src_embeddings = self.embedding_dropout(self.semantic_E(input_ids) + self.pos_E(positions))
context_embeddings = self.semantic_E(context_ids)
context_embeddings = self.embedding_dropout(context_embeddings)
context_padding_mask = ~context_attention_mask.bool()
encoded = src_embeddings
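        # Interleave self-attention encoder layers with gated cross-attention:
        # queries are the encoded slots, keys/values are the context embeddings.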
for idx, layer in enumerate(self.encoder_layers):
encoded = layer(encoded, src_key_padding_mask=padding_mask)
norm_encoded = self.context_query_norms[idx](encoded)
norm_context = self.context_kv_norms[idx](context_embeddings)
attn_output, _ = self.context_attention_layers[idx](
norm_encoded,
norm_context,
norm_context,
key_padding_mask=context_padding_mask,
need_weights=False,
)
encoded = encoded + self.context_scale * attn_output
encoded = self.layer_norm(encoded)
logits = self.lm_head(encoded)
loss = None
if target_ids is not None:
loss = self.loss_fn(logits, target_ids, target_mask=target_attention_mask)
return MagicBERTOutput(logits=logits, loss=loss)
def tie_weights(self, **kwargs) -> None:
if self.tie_embeddings:
self.lm_head.weight = self.semantic_E.weight
class MagicBERT(PreTrainedModel):
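    """Hugging Face ``PreTrainedModel`` wrapper around ``MagicBERTModel``."""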
config_class = MagicBERTConfig
_tied_weights_keys = {"model.lm_head.weight": "model.semantic_E.weight"}
def __init__(self, config: MagicBERTConfig):
super().__init__(config)
self.model = MagicBERTModel(
attention_dropout=config.attention_dropout,
d_model=config.d_model,
dim_feed_forward=config.dim_feed_forward,
embedding_dropout=config.embedding_dropout,
mask_token_id=config.mask_token_id,
num_attention_heads=config.num_attention_heads,
num_encoder_layers=config.num_encoder_layers,
pad_token_id=config.pad_token_id, # type: ignore
seq_len=config.seq_len,
tie_embeddings=config.tie_embeddings,
vocab_size=config.vocab_size,
)
self.post_init()
def tie_weights(self, **kwargs) -> None: # type: ignore
if self.config.tie_embeddings:
self.model.tie_weights()
def get_input_embeddings(self) -> nn.Module:
return self.model.semantic_E
def set_input_embeddings(self, value: nn.Module):
self.model.semantic_E = value
if self.config.tie_embeddings:
self.tie_weights()
def get_output_embeddings(self) -> nn.Module:
return self.model.lm_head
def set_output_embeddings(self, new_embeddings: nn.Module):
self.model.lm_head = new_embeddings
if self.config.tie_embeddings:
self.tie_weights()
def forward(
self,
*,
input_ids: Tensor,
attention_mask: Tensor | None = None,
context_ids: Tensor,
context_attention_mask: Tensor | None = None,
target_ids: Tensor | None = None,
target_attention_mask: Tensor | None = None,
) -> MagicBERTOutput:
return self.model(
input_ids=input_ids,
attention_mask=attention_mask,
context_ids=context_ids,
context_attention_mask=context_attention_mask,
target_ids=target_ids,
target_attention_mask=target_attention_mask,
)
def _build_legal_token_mask(
self,
*,
device: torch.device,
cards: list[dict[str, object]],
) -> Tensor:
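        """Vocab-sized boolean mask marking pad, mask, and commander-legal card tokens."""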
legal_token_mask = torch.zeros(self.config.vocab_size, device=device, dtype=torch.bool)
legal_token_mask[self.config.pad_token_id] = True
legal_token_mask[self.config.mask_token_id] = True
for card in cards:
if card.get("commander_legal"):
token_id = card.get("token_id")
if isinstance(token_id, int) and 0 <= token_id < self.config.vocab_size:
legal_token_mask[token_id] = True
return legal_token_mask
def _build_basic_token_mask(
self,
*,
device: torch.device,
cards: list[dict[str, object]],
) -> Tensor:
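        """Vocab-sized boolean mask marking tokens whose type line contains 'Basic'."""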
basic_token_mask = torch.zeros(self.config.vocab_size, device=device, dtype=torch.bool)
for card in cards:
token_id = card.get("token_id")
type_line = card.get("type_line", "")
if isinstance(token_id, int) and 0 <= token_id < self.config.vocab_size:
if isinstance(type_line, str) and "Basic" in type_line:
basic_token_mask[token_id] = True
return basic_token_mask
@torch.no_grad()
def generate(
self,
input_ids: Tensor,
*,
context_ids: Tensor | None = None,
context_attention_mask: Tensor | None = None,
) -> Tensor:
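        """One-shot deck generation via Hungarian assignment.

        Slots are matched against a column pool in which every legal non-basic
        card appears once and every basic land may appear up to ``num_slots``
        times; slots that receive no column are filled with ``pad_token_id``.
        """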
cards = getattr(self.generation_config, "cards", None)
if not cards:
raise ValueError("generation_config.cards is required for legality masking")
pad_token_id: int = self.config.pad_token_id # type: ignore
mask_token_id: int = self.config.mask_token_id
if context_ids is None:
context_ids = input_ids.masked_fill(input_ids.eq(pad_token_id), mask_token_id)
legal_token_mask = self._build_legal_token_mask(device=input_ids.device, cards=cards)
basic_token_mask = self._build_basic_token_mask(device=input_ids.device, cards=cards)
output = self(
input_ids=input_ids,
context_ids=context_ids,
context_attention_mask=context_attention_mask,
)
logits = output.logits # (B, seq_len, V)
logits = logits.masked_fill(~legal_token_mask, -1e9)
B, num_slots, V = logits.shape
log_probs = F.log_softmax(logits, dim=-1)
# Column pool: non-basics appear once (singleton), basics appear num_slots times
legal_non_basic = legal_token_mask & ~basic_token_mask
legal_non_basic[pad_token_id] = False
legal_non_basic[mask_token_id] = False
non_basic_ids = legal_non_basic.nonzero(as_tuple=False).flatten().tolist()
basic_ids = basic_token_mask.nonzero(as_tuple=False).flatten().tolist()
col_ids: list[int] = non_basic_ids + basic_ids * num_slots
col_ids_t = torch.tensor(col_ids, device=logits.device, dtype=torch.long)
result = torch.full((B, num_slots), pad_token_id, device=logits.device, dtype=torch.long)
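        # Assign each slot to a column with Hungarian matching; any slot left
        # unassigned (fewer columns than slots) keeps pad_token_id.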
for b in range(B):
cost = -log_probs[b][:, col_ids_t] # (num_slots, num_cols)
row_ind, col_ind = linear_sum_assignment(cost.cpu().numpy())
rows = torch.tensor(row_ind, device=logits.device, dtype=torch.long)
result[b, rows] = col_ids_t[torch.tensor(col_ind, device=logits.device)]
return result
@torch.no_grad()
def iterative_generate(
self,
input_ids: Tensor,
*,
context_ids: Tensor | None = None,
context_attention_mask: Tensor | None = None,
steps: int = 5,
remask_ratio: float = 0.3,
) -> list[Tensor]:
"""Iteratively generate a deck, remasking low-confidence slots between steps.
Returns a list of token_id tensors, one per step (each shape (B, num_slots)).
"""
cards = getattr(self.generation_config, "cards", None)
if not cards:
raise ValueError("generation_config.cards is required for legality masking")
pad_token_id: int = self.config.pad_token_id # type: ignore
mask_token_id: int = self.config.mask_token_id
if context_ids is None:
context_ids = input_ids.masked_fill(input_ids.eq(pad_token_id), mask_token_id)
legal_token_mask = self._build_legal_token_mask(device=input_ids.device, cards=cards)
basic_token_mask = self._build_basic_token_mask(device=input_ids.device, cards=cards)
legal_non_basic = legal_token_mask & ~basic_token_mask
legal_non_basic[pad_token_id] = False
legal_non_basic[mask_token_id] = False
non_basic_ids = legal_non_basic.nonzero(as_tuple=False).flatten().tolist()
basic_ids = basic_token_mask.nonzero(as_tuple=False).flatten().tolist()
x = input_ids.clone()
B, num_slots = x.shape
col_ids: list[int] = non_basic_ids + basic_ids * num_slots
col_ids_t = torch.tensor(col_ids, device=x.device, dtype=torch.long)
all_steps: list[Tensor] = []
for step in range(steps):
is_last = step == steps - 1
output = self(
input_ids=x,
context_ids=context_ids,
context_attention_mask=context_attention_mask,
)
logits = output.logits.masked_fill(~legal_token_mask, -1e9)
log_probs = F.log_softmax(logits, dim=-1)
result = torch.full((B, num_slots), pad_token_id, device=x.device, dtype=torch.long)
confidence = torch.full((B, num_slots), float("-inf"), device=x.device)
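            # confidence[b, slot] holds the log-probability of the token assigned
            # to that slot; the lowest-confidence slots are remasked below.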
for b in range(B):
cost = -log_probs[b][:, col_ids_t]
row_ind, col_ind = linear_sum_assignment(cost.cpu().numpy())
rows = torch.tensor(row_ind, device=x.device, dtype=torch.long)
cols = torch.tensor(col_ind, device=x.device, dtype=torch.long)
result[b, rows] = col_ids_t[cols]
confidence[b, rows] = -cost[rows, cols]
all_steps.append(result.clone())
if is_last or remask_ratio <= 0.0:
x = result
continue
# Remask the lowest-confidence slots so the next step can revise them.
x = result.clone()
for b in range(B):
filled = result[b].ne(pad_token_id).nonzero(as_tuple=False).flatten()
n_remask = max(0, int(filled.numel() * remask_ratio))
if n_remask == 0:
continue
_, worst = torch.topk(confidence[b, filled], k=n_remask, largest=False)
x[b, filled[worst]] = mask_token_id
return all_steps
MagicBERT.register_for_auto_class(AutoModel)
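

# Illustrative usage sketch (the repo id, card metadata, and shapes below are
# assumptions for demonstration only):
#
#     from transformers import AutoModel
#
#     model = AutoModel.from_pretrained("nishtahir/magicBERT", trust_remote_code=True)
#     model.generation_config.cards = card_metadata  # entries with token_id,
#                                                    # commander_legal, type_line
#     deck_ids = model.generate(input_ids)                   # (B, seq_len)
#     drafts = model.iterative_generate(input_ids, steps=5)  # list of (B, seq_len)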