| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import PreTrainedModel, PretrainedConfig |
|
|
class SykoSLMConfig(PretrainedConfig):
    """Configuration for the SykoSLM chunked memory-token language model.

    Args:
        vocab_size: size of the token vocabulary.
        d_model: hidden size; must be divisible by ``n_heads``.
        n_layers: number of transformer layers.
        n_heads: number of attention heads.
        num_memory_tokens: number of recurrent memory slots prepended to each chunk.
        chunk_size: number of word tokens processed per forward pass.
        context_size: maximum context length.
        overlap_size: token overlap between consecutive chunks.
        code_overlap_size: token overlap used for code inputs.
        abstract_head_hidden: hidden size of the abstract head.
        abstract_head_layers: number of layers in the abstract head.
        intermediate_size: feed-forward inner dimension.
    """
    model_type = "sykollm"

    def __init__(self, vocab_size=32000, d_model=768, n_layers=24, n_heads=6,
                 num_memory_tokens=16, chunk_size=128, context_size=1024,
                 overlap_size=16, code_overlap_size=64, abstract_head_hidden=256,
                 abstract_head_layers=2, intermediate_size=3072, **kwargs):
        super().__init__(**kwargs)
        # SykoAttention splits d_model evenly across heads (head_dim = d_model
        # // n_heads); fail fast here rather than silently truncating later.
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.num_memory_tokens = num_memory_tokens
        self.chunk_size = chunk_size
        self.context_size = context_size
        self.overlap_size = overlap_size
        self.code_overlap_size = code_overlap_size
        self.abstract_head_hidden = abstract_head_hidden
        self.abstract_head_layers = abstract_head_layers
        self.intermediate_size = intermediate_size
|
|
def apply_rotary_emb(x, cos, sin):
    """Apply rotary position embeddings to ``x`` using precomputed tables.

    Uses the "rotate-half" convention: the last dimension is split into two
    halves (lo, hi) and the result is ``x * cos + concat(-hi, lo) * sin``.
    ``cos``/``sin`` are cast to ``x``'s dtype before mixing.
    """
    cos = cos.to(x.dtype)
    sin = sin.to(x.dtype)
    half = x.shape[-1] // 2
    lo = x[..., :half]
    hi = x[..., half:]
    rotated = torch.cat((-hi, lo), dim=-1)
    return x * cos + rotated * sin
|
|
class SykoRoPE(nn.Module):
    """Rotary position embedding table generator.

    ``forward(positions)`` returns ``(cos, sin)`` tensors of shape
    ``(1, 1, len(positions), dim)``, where the frequency table duplicates
    each of the ``dim // 2`` inverse frequencies across both halves.
    """

    def __init__(self, dim, base=10000.0):
        super().__init__()
        self.dim, self.base = dim, base

    def forward(self, positions):
        # Inverse frequencies: base^(-2i/dim) for i in [0, dim/2).
        exponents = torch.arange(0, self.dim, 2, device=positions.device).float() / self.dim
        inv_freq = self.base ** -exponents
        angles = torch.outer(positions.float(), inv_freq)
        table = torch.cat((angles, angles), dim=-1)
        # Add broadcast dims for (batch, heads).
        return table.cos()[None, None, :, :], table.sin()[None, None, :, :]
|
|
class SykoAttention(nn.Module):
    """Multi-head causal self-attention with rotary position embeddings.

    A single fused projection produces Q, K, V; rotary embeddings are applied
    to Q and K before PyTorch's scaled_dot_product_attention with a causal mask.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, d_model * 3, bias=False)
        self.out = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, cos, sin):
        batch, seq_len, dim = x.shape
        # Fused projection, then split into per-head (B, H, L, head_dim) tensors.
        projected = self.qkv(x).reshape(batch, seq_len, 3, self.n_heads, self.head_dim)
        q, k, v = projected.permute(2, 0, 3, 1, 4).unbind(0)
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)
        attended = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        merged = attended.transpose(1, 2).reshape(batch, seq_len, dim)
        return self.out(merged)
|
|
class SykoTransformerLayer(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual add."""

    def __init__(self, d_model, n_heads, intermediate_size):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = SykoAttention(d_model, n_heads)
        self.norm2 = nn.LayerNorm(d_model)
        # Sequential layout kept exactly as (0: Linear, 1: GELU, 2: Dropout,
        # 3: Linear) so existing state_dict keys continue to load.
        self.mlp = nn.Sequential(
            nn.Linear(d_model, intermediate_size),
            nn.GELU(),
            nn.Dropout(0.0),
            nn.Linear(intermediate_size, d_model),
        )

    def forward(self, x, cos, sin):
        attn_out = self.attn(self.norm1(x), cos, sin)
        x = x + attn_out
        mlp_out = self.mlp(self.norm2(x))
        return x + mlp_out
|
|
class SykoMemoryGate(nn.Module):
    """GRU-style gate that blends previous memory with a fresh candidate state.

    A sigmoid gate computed from (context, prev_memory) decides how much of the
    old memory survives; the remainder is filled by a tanh candidate derived
    from the context alone. The blend is layer-normalized before returning.
    """

    def __init__(self, d_model):
        super().__init__()
        self.forget_linear = nn.Linear(d_model * 2, d_model)
        self.update_linear = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, current_context, prev_memory):
        gate_input = torch.cat([current_context, prev_memory], dim=-1)
        # keep in (0, 1): fraction of the previous memory retained per feature.
        keep = torch.sigmoid(self.forget_linear(gate_input))
        candidate = torch.tanh(self.update_linear(current_context))
        blended = keep * prev_memory + (1.0 - keep) * candidate
        return self.norm(blended)
|
|
class SykoSLM(PreTrainedModel):
    """Chunked small language model with recurrent memory tokens.

    Each forward pass processes one chunk of tokens together with a bank of
    learned memory slots prepended to the sequence; the gated, updated memory
    is returned so callers can thread state across chunks.
    """

    config_class = SykoSLMConfig

    def __init__(self, config):
        super().__init__(config)
        self.mem_tokens = config.num_memory_tokens
        self.d_model = config.d_model
        # A missing or None pad_token_id falls back to index 0.
        pad_idx = getattr(config, "pad_token_id", 0) or 0
        self.embedding = nn.Embedding(config.vocab_size, config.d_model, padding_idx=pad_idx)
        self.mem_pos_emb = nn.Embedding(config.num_memory_tokens, config.d_model)
        self.rope = SykoRoPE(config.d_model // config.n_heads)
        self.layers = nn.ModuleList([
            SykoTransformerLayer(config.d_model, config.n_heads, config.intermediate_size)
            for _ in range(config.n_layers)
        ])
        self.final_norm = nn.LayerNorm(config.d_model)
        self.memory_gate = SykoMemoryGate(config.d_model)
        self.fc_out = nn.Linear(config.d_model, config.vocab_size)

    def forward(self, input_ids, prev_memory=None, chunk_start_idx=0, **kwargs):
        """Run one chunk through the model.

        Args:
            input_ids: (batch, seq) token ids.
            prev_memory: (batch, num_memory_tokens, d_model) memory carried
                over from the previous chunk, or None to start from zeros.
            chunk_start_idx: absolute position of the chunk's first token,
                used for the word tokens' rotary positions.

        Returns:
            Tuple of (logits, new_memory): logits are (batch, seq, vocab_size);
            new_memory has the same shape as prev_memory.
        """
        B = input_ids.size(0)
        if prev_memory is None:
            # Match the embedding dtype so half-precision models don't mix
            # float32 memory with lower-precision activations.
            prev_memory = torch.zeros(
                B, self.mem_tokens, self.d_model,
                device=input_ids.device, dtype=self.embedding.weight.dtype,
            )
        x = self.embedding(input_ids)
        mem_idx = torch.arange(self.mem_tokens, device=input_ids.device)
        memory_with_pos = prev_memory + self.mem_pos_emb(mem_idx).unsqueeze(0)
        x_with_memory = torch.cat([memory_with_pos, x], dim=1)
        # Memory slots all sit at rotary position 0; word tokens use positions
        # offset by chunk_start_idx.
        mem_pos = torch.zeros(self.mem_tokens, dtype=torch.long, device=input_ids.device)
        word_pos = torch.arange(chunk_start_idx, chunk_start_idx + x.size(1), device=input_ids.device)
        cos, sin = self.rope(torch.cat([mem_pos, word_pos]))
        for layer in self.layers:
            x_with_memory = layer(x_with_memory, cos, sin)
        x_with_memory = self.final_norm(x_with_memory)
        memory_output = x_with_memory[:, :self.mem_tokens, :]
        token_outputs = x_with_memory[:, self.mem_tokens:, :]
        return self.fc_out(token_outputs), self.memory_gate(memory_output, prev_memory)

    def generate_text(self, input_ids, max_new_tokens=100, temperature=0.8, top_k=50):
        """Autoregressively sample up to max_new_tokens with top-k sampling.

        Args:
            input_ids: (batch, seq) prompt token ids.
            max_new_tokens: maximum number of tokens to append.
            temperature: softmax temperature; values <= 0 now select greedily
                (argmax) instead of dividing by zero.
            top_k: sample only from the k highest-probability tokens.

        Returns:
            (batch, seq + generated) tensor of token ids.
        """
        self.eval()
        device = input_ids.device
        # Bug fix: was hard-coded to batch size 1, which crashed on batched
        # prompts when concatenated with the (B, ...) embeddings in forward().
        prev_memory = torch.zeros(
            input_ids.size(0), self.mem_tokens, self.d_model,
            device=device, dtype=self.embedding.weight.dtype,
        )
        generated = input_ids.clone()
        eos = getattr(self.config, "eos_token_id", None)
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # NOTE(review): positions restart at 0 for every sliding
                # window (chunk_start_idx is never passed here); confirm this
                # matches how chunks are positioned during training.
                chunk = generated[:, -self.config.chunk_size:]
                logits, prev_memory = self.forward(chunk, prev_memory)
                next_logits = logits[:, -1, :]
                if temperature <= 0:
                    # Greedy decoding; avoids division by zero below.
                    next_token = next_logits.argmax(dim=-1, keepdim=True)
                else:
                    next_logits = next_logits / temperature
                    k = min(top_k, next_logits.size(-1))
                    top_k_vals, top_k_idx = torch.topk(next_logits, k=k)
                    filtered = torch.full_like(next_logits, float("-inf"))
                    filtered.scatter_(1, top_k_idx, top_k_vals)
                    next_token = torch.multinomial(torch.softmax(filtered, dim=-1), 1)
                generated = torch.cat([generated, next_token], dim=1)
                # Bug fix: `if eos and ...` skipped the check when eos == 0,
                # a perfectly valid token id; test for None explicitly.
                if eos is not None and next_token.item() == eos:
                    break
        return generated
|
|