"""GloVe-style card-embedding model for MTG commander decks.

Trains center/context embeddings plus scalar biases with the weighted
least-squares GloVe objective (Pennington et al., 2014), packaged as a
Hugging Face `PreTrainedModel` and registered with the Auto* machinery.
"""

from dataclasses import dataclass

import torch
from torch import Tensor, nn
from transformers import (
    AutoConfig,
    AutoModel,
    PretrainedConfig,
    PreTrainedModel,
)


class MtgGloVeConfig(PretrainedConfig):
    """Configuration for the MTG commander GloVe model.

    Args:
        vocab_size: Number of card tokens in the vocabulary.
        embedding_dim: Dimensionality of each embedding vector.
        unk_token_id: Token id reserved for unknown cards.
        cooccurrence_alpha: Exponent of the GloVe weighting function f(x).
        cooccurrence_xmax: Saturation point of the GloVe weighting function;
            counts at or above this value get weight 1.
    """

    model_type = "mtg-glove-commander"

    def __init__(
        self,
        vocab_size: int = 10000,
        embedding_dim: int = 128,
        unk_token_id: int = 0,
        cooccurrence_alpha: float = 0.75,
        cooccurrence_xmax: float = 100.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.unk_token_id = unk_token_id
        self.cooccurrence_alpha = cooccurrence_alpha
        self.cooccurrence_xmax = cooccurrence_xmax


MtgGloVeConfig.register_for_auto_class(AutoConfig)


@dataclass
class MtgGloVeOutput:
    """Model output container; only the scalar training loss is produced."""

    loss: Tensor | None = None


class MtgGloVeModel(PreTrainedModel):
    """GloVe model over card co-occurrence counts.

    Holds two embedding tables (center and context) plus per-token scalar
    biases; `forward` returns the weighted least-squares GloVe loss for a
    batch of (center, context, count) triples.
    """

    config_class = MtgGloVeConfig

    def __init__(self, config: MtgGloVeConfig):
        super().__init__(config)
        dim = config.embedding_dim
        vocab = config.vocab_size
        self.center_embeddings = nn.Embedding(vocab, dim)
        self.context_embeddings = nn.Embedding(vocab, dim)
        # Per-token scalar biases, stored as width-1 embedding tables.
        self.center_bias = nn.Embedding(vocab, 1)
        self.context_bias = nn.Embedding(vocab, 1)
        self._reset_parameters()
        self.post_init()

    def _reset_parameters(self) -> None:
        """Initialize weights: small uniform embeddings, zero biases."""
        bound = 0.5 / self.config.embedding_dim
        nn.init.uniform_(self.center_embeddings.weight, -bound, bound)
        nn.init.uniform_(self.context_embeddings.weight, -bound, bound)
        nn.init.zeros_(self.center_bias.weight)
        nn.init.zeros_(self.context_bias.weight)

    def forward(
        self,
        *,
        center_ids: Tensor,
        context_ids: Tensor,
        cooccurrence_counts: Tensor,
    ) -> MtgGloVeOutput:
        """Compute the weighted least-squares GloVe loss.

        Args:
            center_ids: Integer tensor of center-card token ids.
            context_ids: Integer tensor of context-card token ids, same
                shape as `center_ids`.
            cooccurrence_counts: Float tensor of co-occurrence counts, same
                shape as `center_ids`.

        Returns:
            MtgGloVeOutput with the scalar mean loss.

        Raises:
            ValueError: If the three input tensors do not share one shape.
        """
        if (
            center_ids.shape != context_ids.shape
            or center_ids.shape != cooccurrence_counts.shape
        ):
            raise ValueError("center_ids, context_ids, and cooccurrence_counts must align")
        center = self.center_embeddings(center_ids)
        context = self.context_embeddings(context_ids)
        bias_center = self.center_bias(center_ids).squeeze(-1)
        bias_context = self.context_bias(context_ids).squeeze(-1)
        # Sum over the embedding axis (last dim) so id tensors of any batch
        # rank are supported; identical to the 1-D case's dim=1.
        inner = torch.sum(center * context, dim=-1) + bias_center + bias_context
        # Clamp before the log: a zero count would give log(0) = -inf and the
        # 0-weight * inf**2 product is NaN, which poisons the batch mean.
        # For the usual counts >= 1 the clamp is a no-op.
        log_counts = torch.log(torch.clamp(cooccurrence_counts, min=1e-12))
        xmax = self.config.cooccurrence_xmax
        alpha = self.config.cooccurrence_alpha
        # GloVe weighting f(x) = min(x / xmax, 1) ** alpha.
        weights = torch.pow(torch.clamp(cooccurrence_counts / xmax, max=1.0), alpha)
        loss = torch.mean(weights * (inner - log_counts) ** 2)
        return MtgGloVeOutput(loss=loss)

    @torch.no_grad()
    def get_combined_embeddings(self) -> Tensor:
        """Return the (vocab, dim) mean of center and context embeddings."""
        return (self.center_embeddings.weight + self.context_embeddings.weight) / 2.0

    @torch.no_grad()
    def get_card_bias(self, card_id: int) -> float:
        """Return the mean of the center and context biases for one card id."""
        # Build the lookup index on the weights' device so this also works
        # when the model has been moved to GPU.
        idx = torch.tensor(card_id, device=self.center_bias.weight.device)
        center_bias = self.center_bias(idx)
        context_bias = self.context_bias(idx)
        return (center_bias + context_bias).item() / 2.0


MtgGloVeModel.register_for_auto_class(AutoModel)