"""
Custom Language Model (TransformerEncoder + BPE)
================================================

Small causal language model built on TransformerEncoder, operating on
tokens produced by `CustomSPTokenizer` (SentencePiece).

Main components:
  - `EpistemicLanguageModel`: PyTorch architecture
  - `generate_text`: autoregressive generation helper
  - helpers to save/load weights
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, List

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class LMConfig:
    vocab_size: int                # must match the tokenizer's vocabulary size
    d_model: int = 256             # embedding / hidden dimension
    n_heads: int = 4               # attention heads per layer
    num_layers: int = 4            # stacked encoder layers
    dim_feedforward: int = 512     # feed-forward inner dimension
    max_seq_len: int = 256         # also sizes the positional-embedding table
    dropout: float = 0.1


class EpistemicLanguageModel(nn.Module):
    """
    Modelo de linguagem simples (causal) com TransformerEncoder.
    """

    def __init__(self, config: LMConfig) -> None:
        super().__init__()
        self.config = config

        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_emb = nn.Embedding(config.max_seq_len, config.d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.d_model,
            nhead=config.n_heads,
            dim_feedforward=config.dim_feedforward,
            dropout=config.dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=config.num_layers,
        )
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
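        # Note: lm_head is not weight-tied to token_emb; tying them
        # (lm_head.weight = token_emb.weight) is a common variant that
        # saves parameters.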

    def forward(
        self,
        input_ids: torch.Tensor,  # (batch, seq_len)
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        bsz, seq_len = input_ids.shape
        device = input_ids.device
        assert seq_len <= self.config.max_seq_len, (
            f"seq_len {seq_len} exceeds max_seq_len {self.config.max_seq_len}"
        )

        pos_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(bsz, -1)
        x = self.token_emb(input_ids) + self.pos_emb(pos_ids)

        # Causal mask: each position attends only to earlier tokens.
        # True entries mark pairs that are *blocked* from attending.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=device, dtype=torch.bool),
            diagonal=1,
        )

        if attention_mask is not None:
            # attention_mask: (batch, seq_len) with 1 for valid tokens, 0 for padding.
            # TransformerEncoder expects src_key_padding_mask with True = pad.
            key_padding_mask = attention_mask == 0
        else:
            key_padding_mask = None

        hidden = self.encoder(
            x,
            mask=causal_mask,
            src_key_padding_mask=key_padding_mask,
        )
        logits = self.lm_head(hidden)
        return logits

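
# Hypothetical training helper (a sketch, not part of the original module):
# shows how the causal logits above are typically trained with next-token
# cross-entropy. Assumes padding positions carry 0 in `attention_mask`,
# matching the convention used in `forward`.
def compute_lm_loss(
    logits: torch.Tensor,  # (batch, seq_len, vocab)
    input_ids: torch.Tensor,  # (batch, seq_len)
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Shift so that position t predicts token t + 1.
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = input_ids[:, 1:].contiguous()
    if attention_mask is not None:
        # Exclude padded targets from the loss.
        shift_labels = shift_labels.masked_fill(attention_mask[:, 1:] == 0, -100)
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,
    )
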

def generate_text(
    model: EpistemicLanguageModel,
    tokenizer,
    prompt: str,
    max_new_tokens: int = 50,
    temperature: float = 1.0,
    top_k: int = 50,
    device: Optional[torch.device] = None,
) -> str:
    """
    Geração autoregressiva simples.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.eval()
    model.to(device)

    ids = tokenizer.encode(prompt, add_bos=True, add_eos=False)
    input_ids = torch.tensor([ids], dtype=torch.long, device=device)

    for _ in range(max_new_tokens):
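        # Stop once the context reaches the positional-embedding limit.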
        if input_ids.size(1) >= model.config.max_seq_len:
            break

        with torch.no_grad():
            logits = model(input_ids)  # (1, seq_len, vocab)
            next_token_logits = logits[0, -1, :] / max(temperature, 1e-4)

            if top_k > 0:
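                # Top-k sampling: keep only the k most likely tokens,
                # renormalize, and sample from that restricted distribution.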
                values, indices = torch.topk(next_token_logits, k=min(top_k, next_token_logits.size(-1)))
                probs = F.softmax(values, dim=-1)
                next_token = indices[torch.multinomial(probs, num_samples=1)]
            else:
                probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

        input_ids = torch.cat([input_ids, next_token.view(1, 1)], dim=1)

    generated_ids: List[int] = input_ids[0].tolist()
    return tokenizer.decode(generated_ids)


def save_lm(model: EpistemicLanguageModel, path: str) -> None:
    torch.save({"config": model.config.__dict__, "state_dict": model.state_dict()}, path)


def load_lm(path: str, vocab_size: int) -> EpistemicLanguageModel:
    data = torch.load(path, map_location="cpu")
    cfg_dict = data.get("config", {})
    cfg_dict["vocab_size"] = vocab_size  # garante compatibilidade
    config = LMConfig(**cfg_dict)
    model = EpistemicLanguageModel(config)
    model.load_state_dict(data["state_dict"])
    return model
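

# Minimal usage sketch (illustrative only; `CustomSPTokenizer` and its
# methods below are assumptions based on the module docstring):
#
#     tokenizer = CustomSPTokenizer("tokenizer.model")      # hypothetical
#     config = LMConfig(vocab_size=tokenizer.vocab_size())  # hypothetical API
#     model = EpistemicLanguageModel(config)
#     print(generate_text(model, tokenizer, "Once upon a time", max_new_tokens=30))
#     save_lm(model, "lm.pt")
#     model = load_lm("lm.pt", vocab_size=tokenizer.vocab_size())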