import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig


class SykoSLMConfig(PretrainedConfig):
    """Configuration for SykoSLM, a chunked language model with recurrent
    memory tokens carried between chunks through a gated update."""

    model_type = "sykollm"

    def __init__(self, vocab_size=32000, d_model=768, n_layers=24, n_heads=6,
                 num_memory_tokens=16, chunk_size=128, context_size=1024,
                 overlap_size=16, code_overlap_size=64,
                 abstract_head_hidden=256, abstract_head_layers=2,
                 intermediate_size=3072, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.num_memory_tokens = num_memory_tokens
        self.chunk_size = chunk_size
        self.context_size = context_size
        self.overlap_size = overlap_size
        self.code_overlap_size = code_overlap_size
        self.abstract_head_hidden = abstract_head_hidden
        self.abstract_head_layers = abstract_head_layers
        self.intermediate_size = intermediate_size


def apply_rotary_emb(x, cos, sin):
    """Apply rotary position embedding to *x* (rotate-half formulation).

    x is (..., head_dim); cos/sin broadcast over the leading dims.
    """
    cos, sin = cos.to(x.dtype), sin.to(x.dtype)
    d = x.shape[-1]
    x1, x2 = x[..., :d // 2], x[..., d // 2:]
    return (x * cos) + (torch.cat([-x2, x1], dim=-1) * sin)


class SykoRoPE(nn.Module):
    """Rotary position embedding table for one attention head dimension."""

    def __init__(self, dim, base=10000.0):
        super().__init__()
        self.dim, self.base = dim, base
        # Hoisted out of forward(): inv_freq is a pure function of (dim, base),
        # so recomputing it per call was wasted work. Non-persistent buffer
        # keeps the state_dict identical to the original and follows the
        # module to the right device.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, positions):
        """Return (cos, sin), each shaped [1, 1, len(positions), dim],
        ready to broadcast over (batch, heads) in attention."""
        freqs = torch.outer(positions.float(),
                            self.inv_freq.to(positions.device))
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos()[None, None, :, :], emb.sin()[None, None, :, :]


class SykoAttention(nn.Module):
    """Multi-head causal self-attention with fused QKV projection and RoPE."""

    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads, self.head_dim = n_heads, d_model // n_heads
        self.qkv = nn.Linear(d_model, d_model * 3, bias=False)
        self.out = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, cos, sin):
        B, L, D = x.shape
        # (B, L, 3, H, hd) -> (3, B, H, L, hd) so q/k/v are SDPA-ready.
        qkv = (self.qkv(x)
               .reshape(B, L, 3, self.n_heads, self.head_dim)
               .permute(2, 0, 3, 1, 4))
        q, k, v = qkv[0], qkv[1], qkv[2]
        q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
        # Causal mask also applies across the prepended memory tokens:
        # memory attends only backward, word tokens see memory + earlier words.
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.out(out.transpose(1, 2).reshape(B, L, D))


class SykoTransformerLayer(nn.Module):
    """Pre-norm transformer block: attention + GELU MLP, residual on both."""

    def __init__(self, d_model, n_heads, intermediate_size):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = SykoAttention(d_model, n_heads)
        self.norm2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, intermediate_size),
            nn.GELU(),
            nn.Dropout(0.0),
            nn.Linear(intermediate_size, d_model),
        )

    def forward(self, x, cos, sin):
        x = x + self.attn(self.norm1(x), cos, sin)
        return x + self.mlp(self.norm2(x))


class SykoMemoryGate(nn.Module):
    """GRU-style gate blending previous memory with the new chunk summary.

    new_memory = norm(f * prev + (1 - f) * tanh(W_u @ current)),
    where f = sigmoid(W_f @ [current; prev]).
    """

    def __init__(self, d_model):
        super().__init__()
        self.forget_linear = nn.Linear(d_model * 2, d_model)
        self.update_linear = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, current_context, prev_memory):
        combined = torch.cat([current_context, prev_memory], dim=-1)
        forget_ratio = torch.sigmoid(self.forget_linear(combined))
        new_candidate = torch.tanh(self.update_linear(current_context))
        return self.norm((forget_ratio * prev_memory)
                         + ((1 - forget_ratio) * new_candidate))


class SykoSLM(PreTrainedModel):
    """Chunked causal LM: each forward processes one chunk with
    `num_memory_tokens` learned-position memory slots prepended; the gated
    memory output is carried into the next chunk."""

    config_class = SykoSLMConfig

    def __init__(self, config):
        super().__init__(config)
        self.mem_tokens = config.num_memory_tokens
        self.d_model = config.d_model
        # `or 0` guards against an explicit pad_token_id=None on the config.
        pad_idx = getattr(config, "pad_token_id", 0) or 0
        self.embedding = nn.Embedding(config.vocab_size, config.d_model,
                                      padding_idx=pad_idx)
        self.mem_pos_emb = nn.Embedding(config.num_memory_tokens,
                                        config.d_model)
        self.rope = SykoRoPE(config.d_model // config.n_heads)
        self.layers = nn.ModuleList([
            SykoTransformerLayer(config.d_model, config.n_heads,
                                 config.intermediate_size)
            for _ in range(config.n_layers)
        ])
        self.final_norm = nn.LayerNorm(config.d_model)
        self.memory_gate = SykoMemoryGate(config.d_model)
        self.fc_out = nn.Linear(config.d_model, config.vocab_size)

    def forward(self, input_ids, prev_memory=None, chunk_start_idx=0,
                **kwargs):
        """Run one chunk.

        Args:
            input_ids: (B, L) token ids for this chunk.
            prev_memory: (B, num_memory_tokens, d_model) carried state, or
                None to start from zeros.
            chunk_start_idx: absolute position of the chunk's first token,
                used for the word-token RoPE positions.

        Returns:
            (logits, new_memory): (B, L, vocab_size) and the gated memory
            for the next chunk.
        """
        B = input_ids.size(0)
        if prev_memory is None:
            prev_memory = torch.zeros(B, self.mem_tokens, self.d_model,
                                      device=input_ids.device)
        x = self.embedding(input_ids)
        mem_idx = torch.arange(self.mem_tokens, device=input_ids.device)
        memory_with_pos = prev_memory + self.mem_pos_emb(mem_idx).unsqueeze(0)
        x_with_memory = torch.cat([memory_with_pos, x], dim=1)
        # Memory slots all sit at RoPE position 0; word tokens use their
        # absolute positions so rotations are consistent across chunks.
        mem_pos = torch.zeros(self.mem_tokens, dtype=torch.long,
                              device=input_ids.device)
        word_pos = torch.arange(chunk_start_idx,
                                chunk_start_idx + x.size(1),
                                device=input_ids.device)
        cos, sin = self.rope(torch.cat([mem_pos, word_pos]))
        for layer in self.layers:
            x_with_memory = layer(x_with_memory, cos, sin)
        x_with_memory = self.final_norm(x_with_memory)
        memory_output = x_with_memory[:, :self.mem_tokens, :]
        token_outputs = x_with_memory[:, self.mem_tokens:, :]
        return (self.fc_out(token_outputs),
                self.memory_gate(memory_output, prev_memory))

    def generate_text(self, input_ids, max_new_tokens=100, temperature=0.8,
                      top_k=50):
        """Autoregressive top-k sampling.

        Args:
            input_ids: (B, L) prompt tokens.
            max_new_tokens: maximum tokens to append.
            temperature: logit divisor (> 0).
            top_k: number of candidate tokens kept per step.

        Returns:
            (B, L + new) tensor of prompt + generated tokens.
        """
        self.eval()
        device = input_ids.device
        # FIX: memory was hard-coded to batch size 1 and crashed for B > 1;
        # size it from the actual prompt batch.
        B = input_ids.size(0)
        prev_memory = torch.zeros(B, self.mem_tokens, self.d_model,
                                  device=device)
        generated = input_ids.clone()
        # NOTE(review): each step re-feeds the whole trailing chunk, so the
        # gated memory absorbs overlapping content repeatedly — presumably
        # intentional for this recurrence scheme, but worth confirming.
        with torch.no_grad():
            for _ in range(max_new_tokens):
                chunk = generated[:, -self.config.chunk_size:]
                logits, prev_memory = self.forward(chunk, prev_memory)
                next_logits = logits[:, -1, :] / temperature
                k = min(top_k, next_logits.size(-1))
                top_k_vals, top_k_idx = torch.topk(next_logits, k=k)
                filtered = torch.full_like(next_logits, float("-inf"))
                filtered.scatter_(1, top_k_idx, top_k_vals)
                next_token = torch.multinomial(
                    torch.softmax(filtered, dim=-1), 1)
                generated = torch.cat([generated, next_token], dim=1)
                eos = getattr(self.config, "eos_token_id", None)
                # FIX: `if eos and ...` silently disabled early stopping when
                # eos_token_id == 0, and `.item()` broke for B > 1. Stop only
                # once every sequence in the batch emitted EOS.
                if eos is not None and (next_token == eos).all():
                    break
        return generated