| """ |
| Kayra Turkish GPT Model |
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import PreTrainedModel |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
| from .configuration_kayra import KayraConfig |
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean-centering, no bias).

    Args:
        dim: Size of the last dimension, which is normalized; also the size
            of the learnable per-channel ``weight`` (initialized to ones).
        eps: Added to the mean square before the square root for numerical
            stability.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Normalize ``x`` over its last dimension and scale by ``weight``."""
        input_dtype = x.dtype
        # Compute the statistics in at least float32: squaring fp16/bf16
        # activations overflows once |x| exceeds ~256 (fp16 max is 65504),
        # which would make rms infinite and zero out the whole row.
        # promote_types leaves float32/float64 inputs untouched.
        x_acc = x.to(torch.promote_types(input_dtype, torch.float32))
        rms = torch.sqrt(torch.mean(x_acc ** 2, dim=-1, keepdim=True) + self.eps)
        return (x_acc / rms).to(input_dtype) * self.weight
|
|
|
|
class Attention(nn.Module):
    """Multi-head causal self-attention with a fused QKV projection.

    Reads ``hidden_size``, ``num_attention_heads``, ``hidden_dropout`` and
    ``max_position_embeddings`` from ``config``.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by
            ``num_attention_heads``.
    """

    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({config.hidden_size}) must be divisible by "
                f"num_attention_heads ({config.num_attention_heads})"
            )
        self.n_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads

        self.qkv = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=False)
        self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        # Applied to the attention probabilities.
        self.dropout = nn.Dropout(config.hidden_dropout)

        # True strictly above the diagonal: query i must not see key j > i.
        # A (persistent) buffer, so it follows .to(device) and is part of
        # state_dict under the same key as before.
        max_len = config.max_position_embeddings
        mask = torch.triu(torch.ones(max_len, max_len), diagonal=1).bool()
        self.register_buffer("mask", mask)

    def forward(self, x):
        """Attend causally over ``x`` of shape (batch, seq, hidden_size).

        Raises:
            ValueError: if the sequence is longer than the causal mask.
        """
        B, T, C = x.shape
        if T > self.mask.size(0):
            raise ValueError(
                f"sequence length {T} exceeds the causal mask size "
                f"{self.mask.size(0)} (max_position_embeddings)"
            )

        # (B, T, 3*C) -> (3, B, n_heads, T, head_dim)
        qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)

        scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
        # The diagonal is never masked, so no row is fully -inf (no NaNs).
        scores = scores.masked_fill(self.mask[:T, :T], float('-inf'))
        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)

        # (B, n_heads, T, head_dim) -> (B, T, C)
        out = (weights @ v).transpose(1, 2).reshape(B, T, C)
        return self.proj(out)
|
|
|
|
class FeedForward(nn.Module):
    """SwiGLU feed-forward layer: ``dropout(w2(silu(w1(x)) * w3(x)))``.

    ``w1`` produces the SiLU-activated gate, ``w3`` the value branch, and
    ``w2`` projects their elementwise product from ``intermediate_size`` back
    to ``hidden_size``. None of the projections has a bias.
    """

    def __init__(self, config):
        super().__init__()
        d_model = config.hidden_size
        d_ff = config.intermediate_size
        # Creation order (w1, w2, w3) is kept so random init draws are unchanged.
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        self.w3 = nn.Linear(d_model, d_ff, bias=False)
        self.dropout = nn.Dropout(config.hidden_dropout)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        projected = self.w2(gate * value)
        return self.dropout(projected)
|
|
|
|
class Block(nn.Module):
    """Pre-norm transformer layer.

    Causal self-attention and a feed-forward sub-layer, each applied to an
    RMSNorm-normalized input and added back onto the residual stream.
    """

    def __init__(self, config):
        super().__init__()
        # Submodule creation order is preserved so random init is unchanged.
        self.norm1 = RMSNorm(config.hidden_size)
        self.attn = Attention(config)
        self.norm2 = RMSNorm(config.hidden_size)
        self.ff = FeedForward(config)

    def forward(self, x):
        # Residual 1: attention over the normalized input.
        hidden = x + self.attn(self.norm1(x))
        # Residual 2: feed-forward over the normalized attention output.
        return hidden + self.ff(self.norm2(hidden))
|
|
|
|
class KayraPreTrainedModel(PreTrainedModel):
    """Base class hooking Kayra models into the transformers PreTrainedModel API.

    Declares the config class and the checkpoint prefix, and provides the
    per-module weight initializer.
    """

    config_class = KayraConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the parameters of a single submodule.

        - ``nn.Linear`` / ``nn.Embedding`` weights: normal(mean=0, std=0.02).
        - ``nn.Linear`` biases (if any): zeros.
        - ``nn.Embedding`` padding row (if ``padding_idx`` is set): zeros.
        - ``RMSNorm`` scales: ones. This matches the value RMSNorm.__init__
          assigns, so modules materialized without running their constructor
          still start out as identity scaling.
        """
        std = 0.02
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, RMSNorm):
            torch.nn.init.ones_(module.weight)
|
|
|
|
class KayraForCausalLM(KayraPreTrainedModel):
    """Decoder-only causal language model.

    Token embeddings plus learned absolute position embeddings, a stack of
    ``num_hidden_layers`` pre-norm ``Block`` layers, a final RMSNorm and a
    bias-free linear LM head. When ``config.tie_word_embeddings`` is set, the
    LM head reuses the token-embedding matrix.
    """

    def __init__(self, config):
        super().__init__(config)  # PreTrainedModel stores ``self.config``.

        self.tok_emb = nn.Embedding(config.vocab_size, config.hidden_size)
        self.pos_emb = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.drop = nn.Dropout(config.hidden_dropout)

        self.blocks = nn.ModuleList([Block(config) for _ in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        if config.tie_word_embeddings:
            # Share a single Parameter between input embedding and LM head.
            self.lm_head.weight = self.tok_emb.weight

        # The class advertises supports_gradient_checkpointing. Transformers'
        # gradient_checkpointing_enable() looks for this attribute and flips it.
        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.tok_emb

    def set_input_embeddings(self, value):
        self.tok_emb = value

    def get_output_embeddings(self):
        # Exposing the head lets transformers tie / resize it together with
        # the input embeddings (e.g. in resize_token_embeddings).
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def _checkpoint_block(self, block, x):
        """Run ``block`` on ``x`` under activation checkpointing."""
        # Prefer the function installed by gradient_checkpointing_enable();
        # fall back to PyTorch's checkpoint if it is not present.
        ckpt_fn = getattr(self, "_gradient_checkpointing_func", None)
        if ckpt_fn is not None:
            return ckpt_fn(block.__call__, x)
        from torch.utils.checkpoint import checkpoint
        return checkpoint(block, x, use_reentrant=False)

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Compute next-token logits and, optionally, the LM loss.

        Args:
            input_ids: Token ids of shape (batch, seq). Required.
            attention_mask: Accepted for API compatibility but NOT applied.
                Every position attends to all earlier positions, padding
                included, and positions always start at 0.
            labels: Same shape as ``input_ids``. They are shifted left by
                one inside this method. Entries equal to -100 are ignored
                (``F.cross_entropy`` default ``ignore_index``).
            **kwargs: Extra arguments (e.g. from ``generate``) are ignored.

        Returns:
            ``CausalLMOutputWithPast`` with ``logits`` of shape
            (batch, seq, vocab) and ``loss`` (None when ``labels`` is None).
            No key/value cache is produced.

        Raises:
            ValueError: if ``input_ids`` is missing or the sequence exceeds
                ``config.max_position_embeddings``.
        """
        if input_ids is None:
            raise ValueError("KayraForCausalLM.forward requires `input_ids`.")
        _, seq_len = input_ids.shape
        max_len = self.config.max_position_embeddings
        if seq_len > max_len:
            # Otherwise pos_emb would index out of range (a device-side assert
            # on CUDA).
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_position_embeddings={max_len}."
            )

        positions = torch.arange(seq_len, device=input_ids.device)
        x = self.drop(self.tok_emb(input_ids) + self.pos_emb(positions))

        for block in self.blocks:
            if self.gradient_checkpointing and self.training:
                x = self._checkpoint_block(block, x)
            else:
                x = block(x)

        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Position t predicts token t+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            # Use the actual head width and compute the loss in float32 for
            # half-precision stability (no change for float32 logits).
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)).float(),
                shift_labels.view(-1),
            )

        return CausalLMOutputWithPast(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # No KV cache: the full sequence is re-run on every generation step.
        return {"input_ids": input_ids}
|
|