from dataclasses import dataclass

import torch
from torch import Tensor, nn
from transformers import (
    AutoConfig,
    AutoModel,
    PretrainedConfig,
    PreTrainedModel,
)
|
|
class MtgGloVeConfig(PretrainedConfig):
    """Configuration for the MTG GloVe commander embedding model."""

    model_type = "mtg-glove-commander"

    def __init__(
        self,
        vocab_size: int = 10000,
        embedding_dim: int = 128,
        unk_token_id: int = 0,
        cooccurrence_alpha: float = 0.75,
        cooccurrence_xmax: float = 100.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.unk_token_id = unk_token_id
        self.cooccurrence_alpha = cooccurrence_alpha
        self.cooccurrence_xmax = cooccurrence_xmax


MtgGloVeConfig.register_for_auto_class(AutoConfig)
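# Note: because the config is registered for AutoConfig, one way to round-trip
# it is roughly the sketch below (the directory name is just a placeholder, and
# this module must be shipped alongside the saved files for remote code loading):
#
#     config = MtgGloVeConfig(vocab_size=25_000, embedding_dim=256)
#     config.save_pretrained("./mtg-glove-commander")
#     reloaded = AutoConfig.from_pretrained(
#         "./mtg-glove-commander", trust_remote_code=True
#     )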
|
|
@dataclass
class MtgGloVeOutput:
    """Model output holding the scalar GloVe training loss."""

    loss: Tensor | None = None
|
|
class MtgGloVeModel(PreTrainedModel):
    """GloVe-style embedding model over card co-occurrence counts."""

    config_class = MtgGloVeConfig

    def __init__(self, config: MtgGloVeConfig):
        super().__init__(config)
        dim = config.embedding_dim
        vocab = config.vocab_size
        # Separate center/context embedding tables plus per-card bias terms,
        # as in the standard GloVe formulation.
        self.center_embeddings = nn.Embedding(vocab, dim)
        self.context_embeddings = nn.Embedding(vocab, dim)
        self.center_bias = nn.Embedding(vocab, 1)
        self.context_bias = nn.Embedding(vocab, 1)
        self._reset_parameters()
        self.post_init()

    def _reset_parameters(self) -> None:
        bound = 0.5 / self.config.embedding_dim
        nn.init.uniform_(self.center_embeddings.weight, -bound, bound)
        nn.init.uniform_(self.context_embeddings.weight, -bound, bound)
        nn.init.zeros_(self.center_bias.weight)
        nn.init.zeros_(self.context_bias.weight)

    def forward(
        self,
        *,
        center_ids: Tensor,
        context_ids: Tensor,
        cooccurrence_counts: Tensor,
    ) -> MtgGloVeOutput:
        """Weighted least-squares GloVe loss over a batch of (i, j, X_ij) triples.

        Each triple contributes f(X_ij) * (w_i . w~_j + b_i + b~_j - log X_ij)^2,
        with f(x) = min(x / xmax, 1) ** alpha. Counts must be strictly positive.
        """
        if center_ids.shape != context_ids.shape or center_ids.shape != cooccurrence_counts.shape:
            raise ValueError("center_ids, context_ids, and cooccurrence_counts must align")

        center = self.center_embeddings(center_ids)
        context = self.context_embeddings(context_ids)
        bias_center = self.center_bias(center_ids).squeeze(-1)
        bias_context = self.context_bias(context_ids).squeeze(-1)

        # Predicted log co-occurrence: embedding dot product plus both biases.
        inner = torch.sum(center * context, dim=-1) + bias_center + bias_context
        log_counts = torch.log(cooccurrence_counts)

        # GloVe weighting: down-weights rare pairs, caps frequent pairs at 1.
        xmax = self.config.cooccurrence_xmax
        alpha = self.config.cooccurrence_alpha
        weights = torch.pow(torch.clamp(cooccurrence_counts / xmax, max=1.0), alpha)

        loss = torch.mean(weights * (inner - log_counts) ** 2)
        return MtgGloVeOutput(loss=loss)

    @torch.no_grad()
    def get_combined_embeddings(self) -> Tensor:
        """Average of center and context embeddings, the usual GloVe export."""
        return (self.center_embeddings.weight + self.context_embeddings.weight) / 2.0

    @torch.no_grad()
    def get_card_bias(self, card_id: int) -> float:
        # Build the lookup index on the same device as the model weights.
        card = torch.tensor(card_id, device=self.center_bias.weight.device)
        center_bias = self.center_bias(card)
        context_bias = self.context_bias(card)
        return (center_bias + context_bias).item() / 2.0
|
|
MtgGloVeModel.register_for_auto_class(AutoModel)
|
|
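# Minimal smoke-test sketch (the sizes, batch contents, and Adagrad optimizer
# below are illustrative assumptions, not values from a real training
# pipeline): build a tiny model, run one weighted least-squares update on
# random co-occurrence triples, and export the combined embeddings.
if __name__ == "__main__":
    config = MtgGloVeConfig(vocab_size=500, embedding_dim=32)
    model = MtgGloVeModel(config)
    optimizer = torch.optim.Adagrad(model.parameters(), lr=0.05)

    # A toy batch of (center card, context card, co-occurrence count) triples.
    center_ids = torch.randint(0, config.vocab_size, (64,))
    context_ids = torch.randint(0, config.vocab_size, (64,))
    counts = torch.randint(1, 50, (64,)).float()

    output = model(
        center_ids=center_ids,
        context_ids=context_ids,
        cooccurrence_counts=counts,
    )
    output.loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    embeddings = model.get_combined_embeddings()
    print(f"loss={output.loss.item():.4f}, embeddings shape={tuple(embeddings.shape)}")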