---
license: mit
---

# magicBERT

A masked-language-model-style transformer for Commander deck completion. Given a partial deck (the "context") and a sequence of masked slots, magicBERT predicts the full 100-card deck in a permutation-invariant way using Hungarian matching.
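
The matching step is a minimum-cost assignment, solvable with `scipy.optimize.linear_sum_assignment`. A self-contained toy illustration (the numbers are made up): each slot is paired with the candidate it explains at the lowest total cost, and shuffling the candidate order only relabels the pairing, which is what makes the result permutation-invariant.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# Cost matrix of negative log-probabilities: rows are output slots,
# columns are candidate cards, and lower means "this slot prefers this card".
cost = np.array([
    [0.1, 2.0, 3.0],  # slot 0 prefers card 0
    [2.5, 2.2, 0.3],  # slot 1 prefers card 2
    [1.9, 0.2, 2.8],  # slot 2 prefers card 1
])
rows, cols = linear_sum_assignment(cost)  # globally optimal one-to-one pairing
print(list(zip(rows.tolist(), cols.tolist())))  # [(0, 0), (1, 2), (2, 1)]
```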

## Architecture

`magicBERT` uses a standard transformer encoder with one addition: after each encoder layer, a cross-attention layer attends to a set of **context card embeddings**.
The context cards serve as the conditioning signal: "given these cards, complete the rest of the deck."

```
input_ids (masked slots to fill)
        |
[Token + Positional Embeddings]
        |
   Encoder Layer 1
        |
Cross-Attention ← context_cards
        |
   Encoder Layer 2
        |
Cross-Attention ← context_cards
       ...
        |
LayerNorm → LM Head → logits (B, seq_len, vocab_size)
```
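
By way of illustration, here is a minimal PyTorch sketch of one such block; the module names and exact wiring are assumptions about the design, not the checkpoint's actual internals.

```python
import torch
import torch.nn as nn

class ConditionedEncoderBlock(nn.Module):
    """One self-attention encoder layer followed by cross-attention into the context."""

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.encoder_layer = nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        # x: (B, seq_len, d_model) deck-slot states; context: (B, C, d_model) card embeddings
        x = self.encoder_layer(x)  # self-attention over the deck slots
        attended, _ = self.cross_attn(query=x, key=context, value=context)
        return self.norm(x + attended)  # residual connection + layer norm
```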

## Generation

Two generation modes are available:

- **`generate`** (single-pass): run one forward pass, apply a legality mask (Commander-legal cards only), then solve the global assignment problem with `linear_sum_assignment`. Basics are allowed to repeat; non-basics are constrained to appear at most once. (A sketch of this decode follows the list.)

- **`iterative_generate`** (multi-pass refinement): after each step, the lowest-confidence slots are re-masked and the model is run again, allowing it to revise uncertain picks in light of its other choices. (A sketch of this loop follows the decode sketch below.)
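
As a rough illustration of the single-pass decode: offer only legal candidate columns (so the legality mask is implicit), then run one global assignment in which each legal non-basic contributes a single column (and so can be picked at most once) while each basic contributes many (and so can fill many slots). The function name, shapes, and id lists below are assumptions for illustration, not the model's actual implementation.

```python
import torch
from scipy.optimize import linear_sum_assignment

def assign_slots(logits: torch.Tensor, legal_ids: list[int], basic_ids: list[int]) -> torch.Tensor:
    """logits: (S, V) for one deck; legal_ids: legal non-basics; basic_ids: basic lands."""
    S = logits.shape[0]
    log_probs = torch.log_softmax(logits, dim=-1)
    # Candidate columns double as the legality mask: illegal cards are simply absent.
    # Non-basics appear once each; basics are replicated S times so they may repeat.
    columns = list(legal_ids) + list(basic_ids) * S
    cost = -log_probs[:, torch.tensor(columns)].detach().cpu().numpy()  # (S, len(columns))
    rows, cols = linear_sum_assignment(cost)  # each candidate column is used at most once
    return torch.tensor([columns[c] for c in cols], dtype=torch.long)  # rows come back sorted 0..S-1
```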
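And a hedged sketch of the refinement idea behind `iterative_generate`: commit the current best guess for every slot, then re-mask the least confident slots so the next pass can revise them. The call `model(input_ids, context_ids=...).logits` and the `refine`/`remask_frac` names are assumptions for illustration, not the model's API.

```python
import torch

@torch.no_grad()
def refine(model, input_ids: torch.Tensor, context_ids: torch.Tensor,
           steps: int = 5, remask_frac: float = 0.3) -> torch.Tensor:
    mask_id = model.config.mask_token_id
    for step in range(steps):
        logits = model(input_ids, context_ids=context_ids).logits  # (1, S, V), assumed API
        conf, picks = logits.softmax(dim=-1).max(dim=-1)  # per-slot confidence and card id
        input_ids = picks  # commit the current best guess for every slot
        if step < steps - 1:  # leave the final pass's predictions intact
            k = int(remask_frac * input_ids.shape[1])
            _, low_conf = conf.topk(k, dim=-1, largest=False)
            input_ids = input_ids.scatter(1, low_conf, mask_id)
    return input_ids
```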

## Usage

```python
import torch
from transformers import AutoModel, AutoTokenizer

model_name = "nishtahir/magicBERT"
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)  # type: ignore[assignment]
model.eval()

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

cards = ["Yuriko, the Tiger's Shadow"]

# Tokenize context cards
context_token_ids: list[int] = tokenizer.convert_tokens_to_ids(cards)  # type: ignore[assignment]
unknown = [
    name
    for name, tid in zip(cards, context_token_ids, strict=True)
    if tid == tokenizer.unk_token_id
]
if unknown:
    print(f"Warning: the following cards were not found in the vocabulary: {unknown}")

# Build (1, C) context tensor
context_ids = torch.tensor([context_token_ids], dtype=torch.long)
# Build the input tensor of masked slots
seq_len: int = model.config.seq_len
input_ids = torch.full((1, seq_len), model.config.mask_token_id, dtype=torch.long)

# Make prediction
token_ids = model.generate(input_ids, context_ids=context_ids)  # (1, seq_len)

# Decode token IDs back into card names
slot_ids: list[int] = token_ids[0].tolist()
card_names: list[str] = tokenizer.convert_ids_to_tokens(slot_ids)  # type: ignore[assignment]

pad_token = tokenizer.pad_token
deck = [name for name in card_names if name != pad_token]

print(f"\nGenerated deck ({len(deck)} cards):")
for i, name in enumerate(deck, 1):
    print(f" {i:>3}. {name}")

# Generated deck (100 cards):
#    1. Watery Grave
#    2. Yuriko, the Tiger's Shadow
#    3. Verdant Catacombs
#    4. Island
#    5. Prosperous Thief
#    6. Clearwater Pathway // Murkwater Pathway
#    7. Island
#    8. Island
#    9. Mist-Syndicate Naga
#   10. Marsh Flats
```