# Source: Hugging Face model repository (initial commit dc7b215, author nishtahir).
from dataclasses import dataclass
import torch
from torch import Tensor, nn
from transformers import (
AutoConfig,
AutoModel,
PretrainedConfig,
PreTrainedModel,
)
class MtgGloVeConfig(PretrainedConfig):
    """Configuration for the MTG GloVe commander-embedding model.

    Holds the embedding-table geometry plus the hyperparameters of the
    GloVe weighting function f(x) = min(x / xmax, 1) ** alpha.
    """

    model_type = "mtg-glove-commander"

    def __init__(
        self,
        vocab_size: int = 10000,
        embedding_dim: int = 128,
        unk_token_id: int = 0,
        cooccurrence_alpha: float = 0.75,
        cooccurrence_xmax: float = 100.0,
        **kwargs,
    ):
        # Forward everything else (name_or_path, etc.) to the base config.
        super().__init__(**kwargs)
        # Embedding-table geometry.
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        # Token id used for out-of-vocabulary cards.
        self.unk_token_id = unk_token_id
        # GloVe weighting-function hyperparameters.
        self.cooccurrence_alpha = cooccurrence_alpha
        self.cooccurrence_xmax = cooccurrence_xmax


MtgGloVeConfig.register_for_auto_class(AutoConfig)
@dataclass
class MtgGloVeOutput:
loss: Tensor | None = None
class MtgGloVeModel(PreTrainedModel):
    """GloVe-style embedding model trained on card co-occurrence counts.

    Maintains separate center/context embedding tables and per-token
    biases, trained with the weighted least-squares GloVe objective.
    The two tables are averaged at export time
    (:meth:`get_combined_embeddings`).
    """

    config_class = MtgGloVeConfig

    def __init__(self, config: MtgGloVeConfig):
        super().__init__(config)
        dim = config.embedding_dim
        vocab = config.vocab_size
        # Separate center/context parameters, as in the original GloVe
        # formulation.
        self.center_embeddings = nn.Embedding(vocab, dim)
        self.context_embeddings = nn.Embedding(vocab, dim)
        self.center_bias = nn.Embedding(vocab, 1)
        self.context_bias = nn.Embedding(vocab, 1)
        self._reset_parameters()
        self.post_init()

    def _reset_parameters(self) -> None:
        """Initialize embeddings uniformly in [-0.5/dim, 0.5/dim]; zero biases."""
        bound = 0.5 / self.config.embedding_dim
        nn.init.uniform_(self.center_embeddings.weight, -bound, bound)
        nn.init.uniform_(self.context_embeddings.weight, -bound, bound)
        nn.init.zeros_(self.center_bias.weight)
        nn.init.zeros_(self.context_bias.weight)

    def forward(
        self,
        *,
        center_ids: Tensor,
        context_ids: Tensor,
        cooccurrence_counts: Tensor,
    ) -> MtgGloVeOutput:
        """Compute the weighted least-squares GloVe loss for a batch of pairs.

        Args:
            center_ids: Integer token ids of the center cards.
            context_ids: Integer token ids of the context cards; same shape
                as ``center_ids``.
            cooccurrence_counts: Co-occurrence count X_ij for each pair;
                same shape as ``center_ids``.

        Returns:
            ``MtgGloVeOutput`` whose ``loss`` is the mean of
            f(X_ij) * (w_i . w_j + b_i + b_j - log X_ij) ** 2.

        Raises:
            ValueError: If the three input tensors do not share a shape.
        """
        if center_ids.shape != context_ids.shape or center_ids.shape != cooccurrence_counts.shape:
            raise ValueError("center_ids, context_ids, and cooccurrence_counts must align")
        center = self.center_embeddings(center_ids)
        context = self.context_embeddings(context_ids)
        bias_center = self.center_bias(center_ids).squeeze(-1)
        bias_context = self.context_bias(context_ids).squeeze(-1)
        # Sum over the embedding (last) dimension so any leading batch shape
        # works, not just 1-D id vectors; identical to dim=1 for 1-D ids.
        inner = torch.sum(center * context, dim=-1) + bias_center + bias_context
        # Clamp before the log: a zero count would give log(0) = -inf and,
        # multiplied by its zero weight below, a NaN loss (0 * inf). Positive
        # counts are unaffected by the clamp, and zero-count pairs contribute
        # exactly 0 because their weight is 0.
        log_counts = torch.log(torch.clamp(cooccurrence_counts, min=1e-12))
        xmax = self.config.cooccurrence_xmax
        alpha = self.config.cooccurrence_alpha
        # GloVe weighting f(x) = min(x / xmax, 1) ** alpha.
        weights = torch.pow(torch.clamp(cooccurrence_counts / xmax, max=1.0), alpha)
        loss = torch.mean(weights * (inner - log_counts) ** 2)
        return MtgGloVeOutput(loss=loss)

    @torch.no_grad()
    def get_combined_embeddings(self) -> Tensor:
        """Return the (vocab, dim) average of center and context embeddings."""
        return (self.center_embeddings.weight + self.context_embeddings.weight) / 2.0

    @torch.no_grad()
    def get_card_bias(self, card_id: int) -> float:
        """Return the average of a card's center and context biases.

        Args:
            card_id: Vocabulary index of the card.
        """
        # Build the lookup index on the same device as the weights so this
        # also works for models that have been moved to GPU.
        idx = torch.tensor(card_id, device=self.center_bias.weight.device)
        center_bias = self.center_bias(idx)
        context_bias = self.context_bias(idx)
        return (center_bias + context_bias).item() / 2.0


MtgGloVeModel.register_for_auto_class(AutoModel)