nishtahir committed on
Commit
7cf5414
·
verified ·
1 Parent(s): 21beef1

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ generation_config.json filter=lfs diff=lfs merge=lfs -text
config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "MagicBERT"
4
+ ],
5
+ "attention_dropout": 0.15,
6
+ "auto_map": {
7
+ "AutoConfig": "config.MagicBERTConfig",
8
+ "AutoModel": "modeling.MagicBERT"
9
+ },
10
+ "d_model": 128,
11
+ "dim_feed_forward": 341,
12
+ "dtype": "float32",
13
+ "embedding_dropout": 0.15,
14
+ "mask_token_id": 0,
15
+ "model_type": "magicBERT",
16
+ "num_attention_heads": 8,
17
+ "num_encoder_layers": 4,
18
+ "pad_token_id": 1,
19
+ "seq_len": 100,
20
+ "tie_embeddings": true,
21
+ "transformers_version": "4.57.3",
22
+ "vocab_size": 36476
23
+ }
config.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+
3
+ from transformers import AutoConfig, GenerationConfig, PretrainedConfig
4
+
5
+
6
class MagicBERTConfig(PretrainedConfig):
    """Hyper-parameter container for the MagicBERT model.

    All arguments are keyword-only. ``tie_embeddings`` is forwarded to the
    base class as ``tie_word_embeddings`` unless the caller supplied that key
    explicitly. A falsy ``dim_feed_forward`` (e.g. ``0`` or ``None``) falls
    back to ``int(d_model * 8 / 3)``.
    """

    model_type = "magicBERT"

    def __init__(
        self,
        *,
        attention_dropout: float = 0.15,
        d_model: int = 768,
        dim_feed_forward: int = 3072,
        embedding_dropout: float = 0.15,
        mask_token_id: int = 0,
        num_attention_heads: int = 8,
        num_encoder_layers: int = 4,
        pad_token_id: int = 1,
        seq_len: int = 100,
        tie_embeddings: bool = True,
        vocab_size: int = 35000,
        **kwargs,
    ):
        # An explicit tie_word_embeddings from the caller takes precedence.
        kwargs.setdefault("tie_word_embeddings", tie_embeddings)
        super().__init__(**kwargs)

        if dim_feed_forward:
            feed_forward_width = dim_feed_forward
        else:
            feed_forward_width = int(d_model * 8 / 3)

        self.attention_dropout = attention_dropout
        self.d_model = d_model
        self.dim_feed_forward = feed_forward_width
        self.embedding_dropout = embedding_dropout
        self.num_attention_heads = num_attention_heads
        self.mask_token_id = mask_token_id
        self.num_encoder_layers = num_encoder_layers
        self.seq_len = seq_len
        self.tie_embeddings = tie_embeddings
        self.vocab_size = vocab_size
        # Assigned after super().__init__ so this value is the one that sticks.
        self.pad_token_id = pad_token_id
39
+
40
+
41
class MagicBERTGenerationConfig(GenerationConfig):
    """Generation config that additionally carries a list of card records.

    ``cards`` is stored as given when it is a non-empty list; ``None`` or an
    empty list is replaced by a fresh empty list.
    """

    model_type = MagicBERTConfig.model_type

    def __init__(self, *, cards: list[dict[str, Any]] | None = None, **kwargs):
        super().__init__(**kwargs)
        if cards:
            self.cards = cards
        else:
            self.cards = []
48
+
49
+ MagicBERTConfig.register_for_auto_class(AutoConfig)
generation_config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:092cc25d053bd5adb42df7598ec6361876ae65c3a1b2121290cd8609cf5c2693
3
+ size 11883187
metrics.json ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:888c376e42ba1008aa38b70ac69b00f9fc233bb9e798a631c603a79753ff65f2
3
+ size 22271036
modeling.py ADDED
@@ -0,0 +1,458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import NamedTuple
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from scipy.optimize import linear_sum_assignment
7
+ from torch import Tensor
8
+ from transformers import AutoModel, PreTrainedModel
9
+
10
+ from .config import MagicBERTConfig
11
+
12
+
13
+ class HungarianTokenLoss(nn.Module):
14
+ """
15
+ Permutation-invariant token classification loss using Hungarian matching.
16
+
17
+ logits: (B, N, V) - N slot queries, V vocab
18
+ targets: (B, M) - M target token ids (unordered multiset)
19
+ target_mask: (B, M) bool/0-1 mask; True for valid targets, False for padding (optional)
20
+ """
21
+
22
+ def __init__(self, reduction: str = "mean", label_smoothing: float = 0.0):
23
+ super().__init__()
24
+ if reduction not in {"mean", "sum", "none"}:
25
+ raise ValueError("reduction must be one of: mean, sum, none")
26
+ if not (0.0 <= label_smoothing < 1.0):
27
+ raise ValueError("label_smoothing must be in [0, 1)")
28
+ self.reduction = reduction
29
+ self.label_smoothing = float(label_smoothing)
30
+
31
+ def forward(
32
+ self,
33
+ logits: torch.Tensor,
34
+ targets: torch.Tensor,
35
+ *,
36
+ target_mask: torch.Tensor | None = None,
37
+ ) -> torch.Tensor:
38
+ if logits.dim() != 3:
39
+ raise ValueError("logits must be (B, N, V)")
40
+ if targets.dim() != 2:
41
+ raise ValueError("targets must be (B, M)")
42
+ if logits.size(0) != targets.size(0):
43
+ raise ValueError("batch size mismatch between logits and targets")
44
+
45
+ B, N, V = logits.shape
46
+ _, M = targets.shape
47
+
48
+ if target_mask is not None:
49
+ if target_mask.shape != targets.shape:
50
+ raise ValueError("target_mask must have same shape as targets (B, M)")
51
+ valid_mask = target_mask.bool()
52
+ else:
53
+ valid_mask = torch.ones_like(targets, dtype=torch.bool)
54
+
55
+ log_probs = F.log_softmax(logits, dim=-1) # (B, N, V)
56
+
57
+ batch_losses: list[torch.Tensor] = []
58
+ for b in range(B):
59
+ # Select valid targets for this sample: ids shape (m,)
60
+ ids = targets[b][valid_mask[b]]
61
+ m = int(ids.numel())
62
+ if m == 0 or N == 0:
63
+ # No targets or no predictions -> zero loss
64
+ batch_losses.append(log_probs[b].sum() * 0.0)
65
+ continue
66
+
67
+ # Cost matrix: (N, m) where cost[i, j] = -log p_i(ids[j])
68
+ # Gather: log_probs[b] is (N, V), ids is (m,) -> result (N, m)
69
+ lp = log_probs[b] # (N, V)
70
+ cost = -lp[:, ids] # (N, m)
71
+
72
+ # Hungarian assignment (CPU, non-differentiable)
73
+ row_ind, col_ind = linear_sum_assignment(cost.detach().cpu().numpy())
74
+
75
+ row = torch.tensor(row_ind, device=logits.device, dtype=torch.long)
76
+ col = torch.tensor(col_ind, device=logits.device, dtype=torch.long)
77
+
78
+ matched_cost = cost[row, col] # (k,) where k = min(N, m)
79
+
80
+ # Optional label smoothing, applied only on matched pairs
81
+ if self.label_smoothing > 0.0:
82
+ # nll for matched pairs is matched_cost
83
+ # smooth loss is -mean log_probs over vocab
84
+ matched_lp = lp[row] # (k, V)
85
+ smooth = -matched_lp.mean(dim=-1) # (k,)
86
+ eps = self.label_smoothing
87
+ matched_cost = (1.0 - eps) * matched_cost + eps * smooth
88
+
89
+ if self.reduction == "sum":
90
+ batch_losses.append(matched_cost.sum())
91
+ else:
92
+ batch_losses.append(matched_cost.mean())
93
+
94
+ out = torch.stack(batch_losses) if batch_losses else torch.tensor(0.0, device=logits.device)
95
+
96
+ if self.reduction == "none":
97
+ return out
98
+ if self.reduction == "sum":
99
+ return out.sum()
100
+ return out.mean()
101
+
102
+
103
+ class MagicBERTOutput(NamedTuple):
104
+ logits: Tensor # (B, seq_len, vocab_size)
105
+ loss: Tensor | None # scalar, present when target_ids were supplied
106
+
107
+
108
class MagicBERTModel(nn.Module):
    """Transformer encoder over slot tokens with per-layer cross-attention to a context.

    Input tokens get semantic plus learned positional embeddings. Context
    tokens use the same semantic embedding table without positions. After
    every encoder layer the hidden states attend to the (layer-normed)
    context embeddings, and the result is added back scaled by a single
    learned scalar ``context_scale``. A final LayerNorm and a bias-free linear
    head produce vocabulary logits.
    """

    def __init__(
        self,
        *,
        attention_dropout: float,
        d_model: int,
        dim_feed_forward: int,
        embedding_dropout: float,
        mask_token_id: int,
        num_attention_heads: int,
        num_encoder_layers: int,
        pad_token_id: int,
        seq_len: int,
        tie_embeddings: bool,
        vocab_size: int,
    ):
        super().__init__()
        self.seq_len = seq_len
        self.tie_embeddings = tie_embeddings
        self.pad_token_id = pad_token_id
        self.mask_token_id = mask_token_id

        self.semantic_E = nn.Embedding(vocab_size, d_model)
        # One learned position vector per slot index in [0, seq_len).
        self.pos_E = nn.Embedding(seq_len, d_model)
        self.embedding_dropout = nn.Dropout(embedding_dropout)
        self.context_scale = nn.Parameter(torch.ones(1))

        self.encoder_layers = nn.ModuleList(
            [
                nn.TransformerEncoderLayer(
                    batch_first=True,
                    d_model=d_model,
                    dim_feedforward=dim_feed_forward,
                    dropout=attention_dropout,
                    nhead=num_attention_heads,
                )
                for _ in range(num_encoder_layers)
            ]
        )

        # One query norm, key/value norm and cross-attention per encoder layer.
        self.context_query_norms = nn.ModuleList(
            [nn.LayerNorm(d_model) for _ in range(num_encoder_layers)]
        )
        self.context_kv_norms = nn.ModuleList(
            [nn.LayerNorm(d_model) for _ in range(num_encoder_layers)]
        )
        self.context_attention_layers = nn.ModuleList(
            [
                nn.MultiheadAttention(
                    embed_dim=d_model,
                    num_heads=num_attention_heads,
                    dropout=attention_dropout,
                    batch_first=True,
                )
                for _ in range(num_encoder_layers)
            ]
        )
        self.layer_norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self.loss_fn = HungarianTokenLoss()
        if tie_embeddings:
            self.tie_weights()

    def _attention_mask(self, input_ids: Tensor, attention_mask: Tensor | None) -> Tensor:
        """Return a bool mask that is True at positions to attend to.

        Uses ``attention_mask`` when given (after a shape check), otherwise
        marks every non-pad token of ``input_ids``.
        """
        if attention_mask is not None:
            if attention_mask.shape != input_ids.shape:
                raise ValueError("attention_mask must have the same shape as input_ids")
            return attention_mask.bool()
        return input_ids.ne(self.pad_token_id)

    def forward(
        self,
        *,
        input_ids: Tensor,
        attention_mask: Tensor | None = None,
        context_ids: Tensor,
        context_attention_mask: Tensor | None = None,
        target_ids: Tensor | None = None,
        target_attention_mask: Tensor | None = None,
    ) -> MagicBERTOutput:
        """Encode ``input_ids`` conditioned on ``context_ids``.

        Args:
            input_ids: ``(batch, length)`` token ids with ``length <= seq_len``.
            attention_mask: Optional mask for ``input_ids``; defaults to
                non-pad positions.
            context_ids: ``(batch, context_length)`` token ids to cross-attend to.
            context_attention_mask: Optional mask for ``context_ids``;
                defaults to non-pad positions.
            target_ids: Optional target token ids; when given, the Hungarian
                loss is computed against the logits.
            target_attention_mask: Optional validity mask for ``target_ids``.

        Returns:
            ``MagicBERTOutput`` with ``logits`` and ``loss`` (``None`` when no
            targets were supplied).

        Raises:
            ValueError: On invalid ranks, shapes, batch sizes, or when the
                input is longer than ``seq_len``.
        """
        if input_ids.dim() != 2:
            raise ValueError("input_ids must be of shape (batch, seq_len)")
        if input_ids.size(0) == 0:
            raise ValueError("input_ids batch dimension must be > 0")
        if input_ids.size(1) > self.seq_len:
            # Positions beyond the embedding table would otherwise fail inside
            # the embedding lookup with an opaque index error.
            raise ValueError(
                f"input_ids length {input_ids.size(1)} exceeds seq_len {self.seq_len}"
            )
        if context_ids.dim() != 2:
            raise ValueError("context_ids must be of shape (batch, context_len)")
        if context_ids.size(0) != input_ids.size(0):
            raise ValueError("context_ids batch dimension must match input_ids")
        if context_attention_mask is None:
            context_attention_mask = context_ids.ne(self.pad_token_id)
        if context_attention_mask.shape != context_ids.shape:
            raise ValueError("context_attention_mask must have the same shape as context_ids")

        # PyTorch attention modules expect True at positions to IGNORE.
        padding_mask = ~self._attention_mask(input_ids, attention_mask)
        positions = torch.arange(input_ids.size(1), device=input_ids.device).unsqueeze(0)
        src_embeddings = self.embedding_dropout(self.semantic_E(input_ids) + self.pos_E(positions))

        # Context tokens carry no positional embedding.
        context_embeddings = self.semantic_E(context_ids)
        context_embeddings = self.embedding_dropout(context_embeddings)

        context_padding_mask = ~context_attention_mask.bool()

        encoded = src_embeddings
        for idx, layer in enumerate(self.encoder_layers):
            encoded = layer(encoded, src_key_padding_mask=padding_mask)
            norm_encoded = self.context_query_norms[idx](encoded)
            norm_context = self.context_kv_norms[idx](context_embeddings)
            attn_output, _ = self.context_attention_layers[idx](
                norm_encoded,
                norm_context,
                norm_context,
                key_padding_mask=context_padding_mask,
                need_weights=False,
            )
            # Residual connection, scaled by the learned scalar.
            encoded = encoded + self.context_scale * attn_output

        encoded = self.layer_norm(encoded)
        logits = self.lm_head(encoded)

        loss = None
        if target_ids is not None:
            loss = self.loss_fn(logits, target_ids, target_mask=target_attention_mask)

        return MagicBERTOutput(logits=logits, loss=loss)

    def tie_weights(self, **kwargs) -> None:
        """Share the output head's weight with the semantic embedding table when tying is enabled."""
        if self.tie_embeddings:
            self.lm_head.weight = self.semantic_E.weight
234
+
235
+
236
class MagicBERT(PreTrainedModel):
    """Hugging Face wrapper around ``MagicBERTModel`` with card-aware decoding.

    ``generate`` and ``iterative_generate`` fill every slot through a Hungarian
    assignment over a pool of candidate tokens built from
    ``generation_config.cards``. Each record is read for the keys
    ``"token_id"``, ``"commander_legal"`` and ``"type_line"``.
    """

    config_class = MagicBERTConfig
    _tied_weights_keys = {"model.lm_head.weight": "model.semantic_E.weight"}

    def __init__(self, config: MagicBERTConfig):
        super().__init__(config)
        self.model = MagicBERTModel(
            attention_dropout=config.attention_dropout,
            d_model=config.d_model,
            dim_feed_forward=config.dim_feed_forward,
            embedding_dropout=config.embedding_dropout,
            mask_token_id=config.mask_token_id,
            num_attention_heads=config.num_attention_heads,
            num_encoder_layers=config.num_encoder_layers,
            pad_token_id=config.pad_token_id,  # type: ignore
            seq_len=config.seq_len,
            tie_embeddings=config.tie_embeddings,
            vocab_size=config.vocab_size,
        )
        self.post_init()

    def tie_weights(self, **kwargs) -> None:  # type: ignore
        """Delegate weight tying to the inner model when the config enables it."""
        if self.config.tie_embeddings:
            self.model.tie_weights()

    def get_input_embeddings(self) -> nn.Module:
        return self.model.semantic_E

    def set_input_embeddings(self, value: nn.Module):
        self.model.semantic_E = value
        if self.config.tie_embeddings:
            self.tie_weights()

    def get_output_embeddings(self) -> nn.Module:
        return self.model.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Module):
        self.model.lm_head = new_embeddings
        if self.config.tie_embeddings:
            self.tie_weights()

    def forward(
        self,
        *,
        input_ids: Tensor,
        attention_mask: Tensor | None = None,
        context_ids: Tensor,
        context_attention_mask: Tensor | None = None,
        target_ids: Tensor | None = None,
        target_attention_mask: Tensor | None = None,
    ) -> MagicBERTOutput:
        """Forward all arguments unchanged to the wrapped ``MagicBERTModel``."""
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            context_ids=context_ids,
            context_attention_mask=context_attention_mask,
            target_ids=target_ids,
            target_attention_mask=target_attention_mask,
        )

    def _build_legal_token_mask(
        self,
        *,
        device: torch.device,
        cards: list[dict[str, object]],
    ) -> Tensor:
        """Return a ``(vocab_size,)`` bool mask of tokens allowed in the logits.

        The pad and mask tokens are always allowed. A card is allowed when its
        ``"commander_legal"`` value is truthy and its ``"token_id"`` is an int
        inside the vocabulary.
        """
        legal_token_mask = torch.zeros(self.config.vocab_size, device=device, dtype=torch.bool)
        legal_token_mask[self.config.pad_token_id] = True
        legal_token_mask[self.config.mask_token_id] = True
        for card in cards:
            if card.get("commander_legal"):
                token_id = card.get("token_id")
                if isinstance(token_id, int) and 0 <= token_id < self.config.vocab_size:
                    legal_token_mask[token_id] = True
        return legal_token_mask

    def _build_basic_token_mask(
        self,
        *,
        device: torch.device,
        cards: list[dict[str, object]],
    ) -> Tensor:
        """Return a ``(vocab_size,)`` bool mask of cards whose ``"type_line"`` contains ``"Basic"``."""
        basic_token_mask = torch.zeros(self.config.vocab_size, device=device, dtype=torch.bool)
        for card in cards:
            token_id = card.get("token_id")
            type_line = card.get("type_line", "")
            if isinstance(token_id, int) and 0 <= token_id < self.config.vocab_size:
                if isinstance(type_line, str) and "Basic" in type_line:
                    basic_token_mask[token_id] = True
        return basic_token_mask

    def _require_cards(self) -> list[dict[str, object]]:
        """Return ``generation_config.cards`` or raise if it is missing or empty."""
        cards = getattr(self.generation_config, "cards", None)
        if not cards:
            raise ValueError("generation_config.cards is required for legality masking")
        return cards

    def _build_column_pool(
        self,
        *,
        legal_token_mask: Tensor,
        basic_token_mask: Tensor,
        num_slots: int,
        device: torch.device,
    ) -> Tensor:
        """Build the candidate token ids used as assignment columns.

        Legal non-basic tokens (excluding pad and mask) appear once, so each
        can be assigned to at most one slot. Basic tokens are repeated
        ``num_slots`` times so they can fill any number of slots.
        """
        # `&` allocates a new tensor, so the caller's masks are not mutated.
        legal_non_basic = legal_token_mask & ~basic_token_mask
        legal_non_basic[self.config.pad_token_id] = False
        legal_non_basic[self.config.mask_token_id] = False
        non_basic_ids = legal_non_basic.nonzero(as_tuple=False).flatten().tolist()
        basic_ids = basic_token_mask.nonzero(as_tuple=False).flatten().tolist()
        col_ids: list[int] = non_basic_ids + basic_ids * num_slots
        return torch.tensor(col_ids, device=device, dtype=torch.long)

    @staticmethod
    def _assign_slots(sample_log_probs: Tensor, col_ids_t: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        """Solve the slot-to-column assignment for one sample.

        Returns ``(rows, cols, cost)`` where ``rows[i]`` is matched with
        column ``cols[i]`` and ``cost`` is the ``(num_slots, num_cols)``
        negative log-probability matrix.
        """
        cost = -sample_log_probs[:, col_ids_t]
        # Cast to float32 before leaving torch: NumPy has no bfloat16.
        row_ind, col_ind = linear_sum_assignment(cost.float().cpu().numpy())
        rows = torch.tensor(row_ind, device=cost.device, dtype=torch.long)
        cols = torch.tensor(col_ind, device=cost.device, dtype=torch.long)
        return rows, cols, cost

    @torch.no_grad()
    def generate(
        self,
        input_ids: Tensor,
        *,
        context_ids: Tensor | None = None,
        context_attention_mask: Tensor | None = None,
    ) -> Tensor:
        """Fill every slot in a single pass.

        Args:
            input_ids: ``(B, num_slots)`` token ids.
            context_ids: Optional context; defaults to ``input_ids`` with pad
                tokens replaced by the mask token.
            context_attention_mask: Optional mask for ``context_ids``.

        Returns:
            ``(B, num_slots)`` tensor of assigned token ids. Slots left
            unassigned (fewer candidate columns than slots) hold the pad id.

        Raises:
            ValueError: If ``generation_config.cards`` is missing or empty.
        """
        cards = self._require_cards()

        pad_token_id: int = self.config.pad_token_id  # type: ignore
        mask_token_id: int = self.config.mask_token_id

        if context_ids is None:
            context_ids = input_ids.masked_fill(input_ids.eq(pad_token_id), mask_token_id)

        legal_token_mask = self._build_legal_token_mask(device=input_ids.device, cards=cards)
        basic_token_mask = self._build_basic_token_mask(device=input_ids.device, cards=cards)

        output = self(
            input_ids=input_ids,
            context_ids=context_ids,
            context_attention_mask=context_attention_mask,
        )
        # Push illegal tokens to a very low score before normalizing.
        logits = output.logits.masked_fill(~legal_token_mask, -1e9)  # (B, seq_len, V)

        B, num_slots, _ = logits.shape
        log_probs = F.log_softmax(logits, dim=-1)

        col_ids_t = self._build_column_pool(
            legal_token_mask=legal_token_mask,
            basic_token_mask=basic_token_mask,
            num_slots=num_slots,
            device=logits.device,
        )

        result = torch.full((B, num_slots), pad_token_id, device=logits.device, dtype=torch.long)
        for b in range(B):
            rows, cols, _ = self._assign_slots(log_probs[b], col_ids_t)
            result[b, rows] = col_ids_t[cols]

        return result

    @torch.no_grad()
    def iterative_generate(
        self,
        input_ids: Tensor,
        *,
        context_ids: Tensor | None = None,
        context_attention_mask: Tensor | None = None,
        steps: int = 5,
        remask_ratio: float = 0.3,
    ) -> list[Tensor]:
        """Iteratively generate a deck, remasking low-confidence slots between steps.

        Args:
            input_ids: ``(B, num_slots)`` token ids.
            context_ids: Optional context; defaults to ``input_ids`` with pad
                tokens replaced by the mask token. It is not updated between
                steps.
            context_attention_mask: Optional mask for ``context_ids``.
            steps: Number of refinement passes.
            remask_ratio: Fraction of filled slots to remask after each
                non-final step. Values ``<= 0`` disable remasking; values
                above ``1`` are treated as "remask every filled slot".

        Returns:
            A list of token_id tensors, one per step (each shape (B, num_slots)).

        Raises:
            ValueError: If ``generation_config.cards`` is missing or empty.
        """
        cards = self._require_cards()

        pad_token_id: int = self.config.pad_token_id  # type: ignore
        mask_token_id: int = self.config.mask_token_id

        if context_ids is None:
            context_ids = input_ids.masked_fill(input_ids.eq(pad_token_id), mask_token_id)

        legal_token_mask = self._build_legal_token_mask(device=input_ids.device, cards=cards)
        basic_token_mask = self._build_basic_token_mask(device=input_ids.device, cards=cards)

        x = input_ids.clone()
        B, num_slots = x.shape
        col_ids_t = self._build_column_pool(
            legal_token_mask=legal_token_mask,
            basic_token_mask=basic_token_mask,
            num_slots=num_slots,
            device=x.device,
        )

        all_steps: list[Tensor] = []

        for step in range(steps):
            is_last = step == steps - 1

            output = self(
                input_ids=x,
                context_ids=context_ids,
                context_attention_mask=context_attention_mask,
            )
            logits = output.logits.masked_fill(~legal_token_mask, -1e9)
            log_probs = F.log_softmax(logits, dim=-1)

            result = torch.full((B, num_slots), pad_token_id, device=x.device, dtype=torch.long)
            confidence = torch.full((B, num_slots), float("-inf"), device=x.device)

            for b in range(B):
                rows, cols, cost = self._assign_slots(log_probs[b], col_ids_t)
                result[b, rows] = col_ids_t[cols]
                # Confidence is the log-probability of the assigned token.
                # Match dtypes explicitly: indexed assignment does not cast.
                confidence[b, rows] = (-cost[rows, cols]).to(confidence.dtype)

            all_steps.append(result.clone())

            if is_last or remask_ratio <= 0.0:
                x = result
                continue

            # Remask the lowest-confidence slots so the next step can revise them.
            x = result.clone()
            for b in range(B):
                filled = result[b].ne(pad_token_id).nonzero(as_tuple=False).flatten()
                num_filled = int(filled.numel())
                # Clamp to the number of filled slots: topk raises when k
                # exceeds the input length (remask_ratio > 1).
                n_remask = min(num_filled, max(0, int(num_filled * remask_ratio)))
                if n_remask == 0:
                    continue
                _, worst = torch.topk(confidence[b, filled], k=n_remask, largest=False)
                x[b, filled[worst]] = mask_token_id

        return all_steps
456
+
457
+
458
+ MagicBERT.register_for_auto_class(AutoModel)
special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "mask_token": "<|mask|>",
3
+ "pad_token": "<|pad|>",
4
+ "unk_token": "<|unk|>"
5
+ }
summary.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "eval_loss": {
3
+ "average": 2.8317361718596192
4
+ }
5
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Iterable
2
+
3
+ import torch
4
+ from transformers import AutoTokenizer, PreTrainedTokenizerFast
5
+
6
+
7
class MagicBERTTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer in which every non-special vocabulary entry is treated as a card."""

    def card_token_ids(self) -> list[int]:
        """Return the ids of all non-special vocabulary entries in ascending order."""
        special_tokens = set(self.all_special_tokens)
        card_ids = []
        for token, token_id in self.get_vocab().items():
            if token in special_tokens:
                continue
            card_ids.append(token_id)
        card_ids.sort()
        return card_ids

    def is_card_token(self, token: str) -> bool:
        """Return True when ``token`` is not one of the special tokens."""
        special_tokens = set(self.all_special_tokens)
        return not (token in special_tokens)

    def is_card_id(self, token_id: int) -> bool:
        """Return True when ``token_id`` maps to a non-empty, non-special token."""
        token: str = self.convert_ids_to_tokens(token_id)  # type: ignore
        if not token:
            return False
        return self.is_card_token(token)

    def convert_card_names_to_ids(self, names: Iterable[str]) -> torch.Tensor:
        """Look up each name as a single token and return the ids as a long tensor."""
        token_ids = []
        for card_name in names:
            token_ids.append(self.convert_tokens_to_ids(card_name))
        return torch.tensor(token_ids, dtype=torch.long)

    def encode(
        self,
        text,
        text_pair=None,
        add_special_tokens=True,
        padding="max_length",
        truncation=False,
        max_length=None,
        stride=0,
        padding_side=None,
        return_tensors=None,
        **kwargs,
    ):
        """Encode text, treating list inputs as pre-split tokens.

        With no ``text_pair``:
        - an empty list or a list of strings is encoded as one pre-split
          sequence;
        - a list of lists is encoded as a batch of pre-split sequences and
          only its ``"input_ids"`` are returned.

        Any other input goes to the parent ``encode`` unchanged. Unlike the
        parent, ``padding`` defaults to ``"max_length"``.
        """
        # Options passed identically to every parent call below.
        shared_options = {
            "add_special_tokens": add_special_tokens,
            "padding": padding,
            "truncation": truncation,
            "max_length": max_length,
            "stride": stride,
            "padding_side": padding_side,
            "return_tensors": return_tensors,
        }

        presplit_candidate = isinstance(text, list) and text_pair is None
        if presplit_candidate:
            if not text or isinstance(text[0], str):
                # One sequence given as a list of tokens.
                return super().encode(
                    text,
                    text_pair=text_pair,
                    is_split_into_words=True,
                    **shared_options,
                    **kwargs,
                )
            if isinstance(text[0], list):
                # A batch of token lists.
                batch_encoding = super().__call__(
                    text=text,
                    is_split_into_words=True,
                    **shared_options,
                    **kwargs,
                )
                return batch_encoding["input_ids"]

        return super().encode(
            text,
            text_pair=text_pair,
            **shared_options,
            **kwargs,
        )
78
+
79
+
80
+ MagicBERTTokenizer.register_for_auto_class(AutoTokenizer)
tokenizer_config.json ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<|mask|>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "<|pad|>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "<|unk|>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ }
27
+ },
28
+ "auto_map": {
29
+ "AutoTokenizer": [
30
+ "tokenizer.MagicBERTTokenizer",
31
+ null
32
+ ]
33
+ },
34
+ "clean_up_tokenization_spaces": false,
35
+ "extra_special_tokens": {},
36
+ "mask_token": "<|mask|>",
37
+ "model_max_length": 100,
38
+ "pad_token": "<|pad|>",
39
+ "tokenizer_class": "MagicBERTTokenizer",
40
+ "unk_token": "<|unk|>"
41
+ }