File size: 3,182 Bytes
dc7b215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from dataclasses import dataclass

import torch
from torch import Tensor, nn
from transformers import (
    AutoConfig,
    AutoModel,
    PretrainedConfig,
    PreTrainedModel,
)


class MtgGloVeConfig(PretrainedConfig):
    """Hyperparameters for the MTG Commander GloVe card-embedding model.

    Args:
        vocab_size: Number of rows in each embedding and bias table.
        embedding_dim: Width of the center and context embedding vectors.
        unk_token_id: Id presumably reserved for unknown cards; only stored here.
        cooccurrence_alpha: Exponent ``alpha`` of the GloVe weighting
            ``f(x) = min(x / xmax, 1) ** alpha``.
        cooccurrence_xmax: Count ``xmax`` at which that weighting saturates at 1.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "mtg-glove-commander"

    def __init__(
        self,
        vocab_size: int = 10000,
        embedding_dim: int = 128,
        unk_token_id: int = 0,
        cooccurrence_alpha: float = 0.75,
        cooccurrence_xmax: float = 100.0,
        **kwargs,
    ):
        # Let the base config consume its own kwargs first.
        super().__init__(**kwargs)
        # Table geometry (tuple assignment runs left to right).
        self.vocab_size, self.embedding_dim = vocab_size, embedding_dim
        self.unk_token_id = unk_token_id
        # Loss-weighting hyperparameters.
        self.cooccurrence_alpha, self.cooccurrence_xmax = cooccurrence_alpha, cooccurrence_xmax


# Register this custom config class with transformers' AutoConfig factory.
# NOTE(review): per the transformers docs this is meant for custom (remote-code)
# classes so saved checkpoints record how to load them — confirm for the pinned version.
MtgGloVeConfig.register_for_auto_class(AutoConfig)


@dataclass
class MtgGloVeOutput:
    loss: Tensor | None = None


class MtgGloVeModel(PreTrainedModel):
    """GloVe-style card embedding model trained on co-occurrence counts.

    Holds separate center and context embedding tables plus per-token center
    and context biases. ``forward`` computes the weighted least-squares GloVe
    objective for a batch of (center, context, count) triples.
    """

    config_class = MtgGloVeConfig

    def __init__(self, config: MtgGloVeConfig):
        super().__init__(config)
        dim = config.embedding_dim
        vocab = config.vocab_size
        self.center_embeddings = nn.Embedding(vocab, dim)
        self.context_embeddings = nn.Embedding(vocab, dim)
        self.center_bias = nn.Embedding(vocab, 1)
        self.context_bias = nn.Embedding(vocab, 1)
        self._reset_parameters()
        # post_init() runs transformers' weight-init hook (_init_weights); the
        # override below makes that hook apply the GloVe scheme instead of the
        # library's generic embedding init.
        self.post_init()

    def _reset_parameters(self) -> None:
        """(Re)initialise all four tables with the GloVe scheme."""
        for module in (
            self.center_embeddings,
            self.context_embeddings,
            self.center_bias,
            self.context_bias,
        ):
            self._init_weights(module)

    def _init_weights(self, module: nn.Module) -> None:
        """Initialise one submodule; called per-module by transformers.

        Embedding tables get U(-0.5/dim, 0.5/dim) and bias tables get zeros.
        Any other module (e.g. the model itself) is left untouched.
        """
        if module is self.center_embeddings or module is self.context_embeddings:
            bound = 0.5 / self.config.embedding_dim
            nn.init.uniform_(module.weight, -bound, bound)
        elif module is self.center_bias or module is self.context_bias:
            nn.init.zeros_(module.weight)

    def forward(
        self,
        *,
        center_ids: Tensor,
        context_ids: Tensor,
        cooccurrence_counts: Tensor,
    ) -> MtgGloVeOutput:
        """Compute the GloVe loss for aligned (center, context, count) tensors.

        Args:
            center_ids: Ids of the center cards.
            context_ids: Ids of the context cards; same shape as ``center_ids``.
            cooccurrence_counts: Co-occurrence counts ``X_ij``; same shape as
                the ids.

        Returns:
            ``MtgGloVeOutput`` whose ``loss`` is the mean over all entries of
            ``f(X) * (w_i . w~_j + b_i + b~_j - log X) ** 2``, where
            ``f(x) = min(x / xmax, 1) ** alpha``. Entries with a zero count
            contribute exactly 0 rather than turning the loss into NaN.

        Raises:
            ValueError: If the three tensors do not share one shape.
        """
        if center_ids.shape != context_ids.shape or center_ids.shape != cooccurrence_counts.shape:
            raise ValueError("center_ids, context_ids, and cooccurrence_counts must align")

        center = self.center_embeddings(center_ids)
        context = self.context_embeddings(context_ids)
        bias_center = self.center_bias(center_ids).squeeze(-1)
        bias_context = self.context_bias(context_ids).squeeze(-1)

        # Reduce over the embedding axis (always last) so ids of any rank work;
        # dim=1 only matched it for 1-D id batches.
        inner = torch.sum(center * context, dim=-1) + bias_center + bias_context

        # log(0) = -inf and 0 * inf = NaN. Substitute 1 (log 1 = 0) for zero
        # counts and force their weight to 0 so they drop out of the sum.
        has_count = cooccurrence_counts != 0
        safe_counts = torch.where(has_count, cooccurrence_counts, torch.ones_like(cooccurrence_counts))
        log_counts = torch.log(safe_counts)

        xmax = self.config.cooccurrence_xmax
        alpha = self.config.cooccurrence_alpha
        weights = torch.pow(torch.clamp(cooccurrence_counts / xmax, max=1.0), alpha)
        weights = torch.where(has_count, weights, torch.zeros_like(weights))

        loss = torch.mean(weights * (inner - log_counts) ** 2)
        return MtgGloVeOutput(loss=loss)

    @torch.no_grad()
    def get_combined_embeddings(self) -> Tensor:
        """Return the element-wise mean of the center and context tables."""
        return (self.center_embeddings.weight + self.context_embeddings.weight) / 2.0

    @torch.no_grad()
    def get_card_bias(self, card_id: int) -> float:
        """Return the mean of the center and context bias for ``card_id``."""
        # Build each index on its table's device so lookups also work when the
        # model lives on an accelerator (or the tables are split across devices).
        center_bias = self.center_bias(torch.tensor(card_id, device=self.center_bias.weight.device))
        context_bias = self.context_bias(torch.tensor(card_id, device=self.context_bias.weight.device))
        return (center_bias.item() + context_bias.item()) / 2.0


# Register this custom model class with transformers' AutoModel factory.
# NOTE(review): per the transformers docs this is meant for custom (remote-code)
# classes so saved checkpoints record how to load them — confirm for the pinned version.
MtgGloVeModel.register_for_auto_class(AutoModel)