---
license: mit
---

# magicBERT

A masked-language-model-style transformer for Commander deck completion. Given a partial deck (the "context") and a sequence of masked slots, magicBERT predicts the full 100-card deck in a permutation-invariant way using Hungarian matching.
| 8 |
+
|
| 9 |
+
## Architecture
|
| 10 |
+
|
| 11 |
+
`magicBERT` uses a standard transformer encoder with one addition: after each encoder layer, a cross-attention layer attends to a set of **context card embeddings**.
|
| 12 |
+
The context cards serve as the conditioning signal — "given these cards, complete the rest of the deck."
|
| 13 |
+
|
| 14 |
+
```
|
| 15 |
+
input_ids (masked slots to fill)
|
| 16 |
+
|
|
| 17 |
+
[Token + Positional Embeddings]
|
| 18 |
+
|
|
| 19 |
+
Encoder Layer 1
|
| 20 |
+
|
|
| 21 |
+
Cross-Attention → context_cards
|
| 22 |
+
|
|
| 23 |
+
Encoder Layer 2
|
| 24 |
+
|
|
| 25 |
+
Cross-Attention → context_cards
|
| 26 |
+
...
|
| 27 |
+
|
|
| 28 |
+
LayerNorm → LM Head → logits (B, seq_len, vocab_size)
|
| 29 |
+
```
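The interleaving above can be sketched in PyTorch. This is a minimal illustration only: the class name `ConditionedEncoder`, the layer sizes, and the residual placement around the cross-attention are assumptions, not the model's actual internals.

```python
import torch
import torch.nn as nn


class ConditionedEncoder(nn.Module):
    """Sketch: encoder layers interleaved with cross-attention over
    context card embeddings (names/shapes are illustrative)."""

    def __init__(self, vocab_size: int, d_model: int = 64,
                 n_layers: int = 2, n_heads: int = 4) -> None:
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.encoder_layers = nn.ModuleList(
            nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True)
            for _ in range(n_layers)
        )
        self.cross_attn = nn.ModuleList(
            nn.MultiheadAttention(d_model, n_heads, batch_first=True)
            for _ in range(n_layers)
        )
        self.norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids: torch.Tensor,
                context_ids: torch.Tensor) -> torch.Tensor:
        x = self.embed(input_ids)      # (B, seq_len, d_model)
        ctx = self.embed(context_ids)  # (B, C, d_model)
        for enc, xattn in zip(self.encoder_layers, self.cross_attn):
            x = enc(x)                        # self-attention over slots
            attended, _ = xattn(x, ctx, ctx)  # attend to context cards
            x = x + attended                  # residual merge
        return self.lm_head(self.norm(x))  # (B, seq_len, vocab_size)


model = ConditionedEncoder(vocab_size=100)
logits = model(torch.zeros(1, 10, dtype=torch.long),
               torch.tensor([[1, 2]]))
print(logits.shape)  # torch.Size([1, 10, 100])
```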
## Generation

Two generation modes are available:

- **`generate`** — single-pass: run one forward pass, apply a legality mask (Commander-legal cards only), then solve the global assignment problem with `linear_sum_assignment`. Basics are allowed to repeat; non-basics are constrained to appear at most once.

- **`iterative_generate`** — multi-pass refinement: after each step, the lowest-confidence slots are re-masked and the model is run again, allowing it to revise uncertain picks in light of its other choices.
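The single-pass scheme can be sketched with SciPy's `linear_sum_assignment`. Everything here is a toy stand-in (card ids, sizes, and the `assign` helper are invented for illustration); the key idea is that legality and the at-most-once constraint are encoded in how candidate columns are built before the Hungarian solve.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# Toy setup: 6 deck slots, 5-card vocabulary (all ids are illustrative).
seq_len, vocab_size = 6, 5
BASICS = {0}   # card 0 stands in for a basic land: free to repeat
ILLEGAL = {4}  # card 4 stands in for a non-Commander-legal card
rng = np.random.default_rng(0)
logits = rng.normal(size=(seq_len, vocab_size))  # mock model output


def assign(logits: np.ndarray) -> np.ndarray:
    """Legality mask + global assignment.

    Illegal cards contribute no columns; each legal non-basic
    contributes one column (picked at most once); each basic
    contributes seq_len columns (free to repeat).
    """
    cols: list[int] = []
    for card in range(vocab_size):
        if card in ILLEGAL:
            continue
        cols.extend([card] * (seq_len if card in BASICS else 1))
    cost = -logits[:, cols]  # negate: maximize total log-score
    rows, picks = linear_sum_assignment(cost)
    deck = np.empty(seq_len, dtype=int)
    deck[rows] = np.asarray(cols)[picks]
    return deck


deck = assign(logits)
assert len(deck) == seq_len and 4 not in deck
assert all((deck == c).sum() <= 1
           for c in range(vocab_size) if c not in BASICS)

# An iterative_generate-style refinement would then re-mask the slots
# whose chosen cards scored lowest and run the model again:
conf = logits[np.arange(seq_len), deck]
to_remask = np.argsort(conf)[:2]  # the two least-confident slots
```

Duplicating columns for basics keeps the problem a plain one-to-one assignment while still letting basics fill multiple slots.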
+
## Usage
|
| 41 |
+
|
| 42 |
+
```
|
| 43 |
+
import torch
|
| 44 |
+
from transformers import AutoModel, AutoTokenizer
|
| 45 |
+
|
| 46 |
+
model_name = "nishtahir/magicBERT"
|
| 47 |
+
model = AutoModel.from_pretrained(model_name, trust_remote_code=True) # type: ignore[assignment]
|
| 48 |
+
model.eval()
|
| 49 |
+
|
| 50 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
cards = ["Yuriko, the Tiger's Shadow"]
|
| 54 |
+
|
| 55 |
+
# Tokenize context cards
|
| 56 |
+
context_token_ids: list[int] = tokenizer.convert_tokens_to_ids(cards) # type: ignore[assignment]
|
| 57 |
+
unknown = [
|
| 58 |
+
name
|
| 59 |
+
for name, tid in zip(cards, context_token_ids, strict=True)
|
| 60 |
+
if tid == tokenizer.unk_token_id
|
| 61 |
+
]
|
| 62 |
+
if unknown:
|
| 63 |
+
print(f"Warning: the following cards were not found in the vocabulary: {unknown}")
|
| 64 |
+
|
| 65 |
+
# Build (1, C) context tensor
|
| 66 |
+
context_ids = torch.tensor([context_token_ids], dtype=torch.long)
|
| 67 |
+
|
| 68 |
+
# Built input vector of masked cards.
|
| 69 |
+
seq_len: int = model.config.seq_len
|
| 70 |
+
input_ids = torch.full((1, seq_len), model.config.mask_token_id, dtype=torch.long)
|
| 71 |
+
|
| 72 |
+
# Make prediction
|
| 73 |
+
token_ids = model.generate(input_ids, context_ids=context_ids) # (1, seq_len)
|
| 74 |
+
|
| 75 |
+
# Decode token Ids back into card names
|
| 76 |
+
slot_ids: list[int] = token_ids[0].tolist()
|
| 77 |
+
card_names: list[str] = tokenizer.convert_ids_to_tokens(slot_ids) # type: ignore[assignment]
|
| 78 |
+
|
| 79 |
+
pad_token = tokenizer.pad_token
|
| 80 |
+
deck = [name for name in card_names if name != pad_token]
|
| 81 |
+
|
| 82 |
+
print(f"\nGenerated deck ({len(deck)} cards):")
|
| 83 |
+
for i, name in enumerate(deck, 1):
|
| 84 |
+
print(f" {i:>3}. {name}")
|
| 85 |
+
|
| 86 |
+
# Generated deck (100 cards):
|
| 87 |
+
# 1. Watery Grave
|
| 88 |
+
# 2. Yuriko, the Tiger's Shadow
|
| 89 |
+
# 3. Verdant Catacombs
|
| 90 |
+
# 4. Island
|
| 91 |
+
# 5. Prosperous Thief
|
| 92 |
+
# 6. Clearwater Pathway // Murkwater Pathway
|
| 93 |
+
# 7. Island
|
| 94 |
+
# 8. Island
|
| 95 |
+
# 9. Mist-Syndicate Naga
|
| 96 |
+
# 10. Marsh Flats
|
| 97 |
+
```
|