| """Rubi-RLM: 1B-class Recursive Language Model (RLM) prototype.
|
|
|
| Bu dosya, recursive düşünme + dual-loop öğrenme hedefiyle tasarlanmış bir
|
| araştırma prototipi içerir.
|
|
|
| Eklenen sohbet katmanı:
|
| - İngilizce/Türkçe çift dilli chat şablonu
|
| - HF tokenizer ile metin->id / id->metin köprüsü
|
| - Tek mesaj veya interaktif chat CLI
|
| """
|
|
|
| from __future__ import annotations
|
|
|
| import argparse
|
| import importlib
|
| import importlib.util
|
| from dataclasses import dataclass
|
| from typing import List, Optional, Protocol, Sequence, Tuple
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from rubi_train_stack import (
|
| TrainStackConfig,
|
| build_dataloader,
|
| build_dataset,
|
| build_optimizer,
|
| train_demo_steps,
|
| )
|
| from xqs_moe import build_deepspeed_moe
|
| from xqs_stack import choose_moe_backend, detect_xqs_backends, format_backend_report
|
| from x_quantum_sparse_ops import (
|
| build_linear,
|
| causal_scaled_dot_product_attention,
|
| fused_residual_add,
|
| maybe_compile_module,
|
| pack_rows,
|
| scatter_rows,
|
| )
|
|
|
|
|
| class TextTokenizer(Protocol):
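| """Structural protocol covering the only tokenizer methods used here (an HF AutoTokenizer satisfies it)."""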
|
| def encode(self, text: str, return_tensors: Optional[str] = None): ...
|
|
|
| def decode(self, token_ids: Sequence[int], skip_special_tokens: bool = True) -> str: ...
|
|
|
|
|
| @dataclass
|
| class ChatTurn:
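| """One chat message; role is expected to be "user" or "assistant"."""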
|
| role: str
|
| content: str
|
|
|
|
|
| @dataclass
|
| class RLMConfig:
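| """Model/runtime hyperparameters; scale_1b() below switches on the sparse MoE + layer-skip recipe."""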
|
| vocab_size: int = 50_257
|
| max_seq_len: int = 2_048
|
| d_model: int = 2_048
|
| n_layers: int = 14
|
| n_heads: int = 16
|
| ff_mult: int = 4
|
| dropout: float = 0.1
|
| recurse_steps: int = 6
|
| critique_threshold: float = 0.20
|
| tie_embeddings: bool = True
|
| use_moe: bool = False
|
| moe_num_experts: int = 0
|
| moe_top_k: int = 2
|
| moe_expert_hidden: int = 0
|
| moe_router_jitter: float = 0.0
|
| moe_aux_loss_weight: float = 0.01
|
| use_layer_skip: bool = False
|
| layer_skip_threshold: float = 0.50
|
| layer_skip_target: float = 1.0
|
| layer_skip_aux_weight: float = 0.01
|
| use_ternary_weights: bool = False
|
| use_flash_attention: bool = False
|
| use_fused_ops: bool = False
|
| packed_execution: bool = False
|
| use_torch_compile: bool = False
|
| moe_backend: str = "auto"
|
| moe_ep_size: int = 1
|
|
|
| @classmethod
|
| def scale_1b(cls) -> "RLMConfig":
|
| return cls(
|
| vocab_size=50_257,
|
| max_seq_len=2_048,
|
| d_model=1_024,
|
| n_layers=10,
|
| n_heads=16,
|
| ff_mult=4,
|
| recurse_steps=6,
|
| critique_threshold=0.20,
|
| use_moe=True,
|
| moe_num_experts=32,
|
| moe_top_k=1,
|
| moe_expert_hidden=1_280,
|
| moe_router_jitter=0.01,
|
| moe_aux_loss_weight=0.01,
|
| use_layer_skip=True,
|
| layer_skip_threshold=0.80,
|
| layer_skip_target=0.03,
|
| layer_skip_aux_weight=0.01,
|
| use_ternary_weights=True,
|
| use_flash_attention=True,
|
| use_fused_ops=True,
|
| packed_execution=True,
|
| use_torch_compile=False,
|
| moe_backend="auto",
|
| moe_ep_size=1,
|
| )
|
|
|
|
|
| class RMSNorm(nn.Module):
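| """RMS normalization: rescale features by their root mean square, no mean centering."""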
|
| def __init__(self, d_model: int, eps: float = 1e-6):
|
| super().__init__()
|
| self.scale = nn.Parameter(torch.ones(d_model))
|
| self.eps = eps
|
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
|
| return self.scale * (x / rms)
|
|
|
|
|
| class DenseFeedForward(nn.Module):
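| """Plain GELU MLP; returns (output, zero aux loss) so it is interchangeable with MoEFeedForward."""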
|
| def __init__(self, cfg: RLMConfig):
|
| super().__init__()
|
| hidden = cfg.d_model * cfg.ff_mult
|
| self.up_proj = build_linear(cfg.d_model, hidden, ternary=cfg.use_ternary_weights)
|
| self.down_proj = build_linear(hidden, cfg.d_model, ternary=cfg.use_ternary_weights)
|
| self.dropout = nn.Dropout(cfg.dropout)
|
|
|
| def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| return self.dropout(self.down_proj(F.gelu(self.up_proj(x)))), x.new_zeros(())
|
|
|
|
|
| class FastSelfAttention(nn.Module):
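| """Multi-head causal self-attention with optionally ternary projection weights."""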
|
| def __init__(self, cfg: RLMConfig):
|
| super().__init__()
|
| if cfg.d_model % cfg.n_heads != 0:
|
| raise ValueError("d_model must be divisible by n_heads.")
|
| self.n_heads = cfg.n_heads
|
| self.head_dim = cfg.d_model // cfg.n_heads
|
| self.dropout = cfg.dropout
|
| self.use_flash_attention = cfg.use_flash_attention
|
| self.q_proj = build_linear(cfg.d_model, cfg.d_model, bias=False, ternary=cfg.use_ternary_weights)
|
| self.k_proj = build_linear(cfg.d_model, cfg.d_model, bias=False, ternary=cfg.use_ternary_weights)
|
| self.v_proj = build_linear(cfg.d_model, cfg.d_model, bias=False, ternary=cfg.use_ternary_weights)
|
| self.out_proj = build_linear(cfg.d_model, cfg.d_model, bias=False, ternary=cfg.use_ternary_weights)
|
|
|
| def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
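| # Note: attn_mask is accepted for interface symmetry but unused; causal masking happens inside causal_scaled_dot_product_attention.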
|
| bsz, seq_len, _ = x.shape
|
| q = self.q_proj(x).view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
|
| k = self.k_proj(x).view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
|
| v = self.v_proj(x).view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
|
| attn_out = causal_scaled_dot_product_attention(
|
| q,
|
| k,
|
| v,
|
| dropout_p=self.dropout,
|
| training=self.training,
|
| )
|
| attn_out = attn_out.transpose(1, 2).contiguous().view(bsz, seq_len, self.n_heads * self.head_dim)
|
| return self.out_proj(attn_out)
|
|
|
|
|
| class MoEExpert(nn.Module):
|
| def __init__(self, d_model: int, hidden: int):
|
| super().__init__()
|
| self.up_proj = build_linear(d_model, hidden, ternary=True)
|
| self.down_proj = build_linear(hidden, d_model, ternary=True)
|
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| return self.down_proj(F.gelu(self.up_proj(x)))
|
|
|
|
|
| class MoEFeedForward(nn.Module):
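| """Token-level top-k mixture-of-experts FFN with an optional DeepSpeed backend and a load-balancing auxiliary loss."""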
|
| def __init__(self, cfg: RLMConfig):
|
| super().__init__()
|
| if cfg.moe_num_experts <= 0:
|
| raise ValueError("moe_num_experts must be positive when use_moe=True.")
|
| if cfg.moe_top_k <= 0 or cfg.moe_top_k > cfg.moe_num_experts:
|
| raise ValueError("moe_top_k must be in the range [1, moe_num_experts].")
|
|
|
| self.num_experts = cfg.moe_num_experts
|
| self.top_k = cfg.moe_top_k
|
| self.router_jitter = cfg.moe_router_jitter
|
| requested_backend = cfg.moe_backend.lower()
|
| self.backend = "native" if requested_backend == "native" else choose_moe_backend(prefer_deepspeed=requested_backend in {"auto", "deepspeed"})
|
| self.router = build_linear(cfg.d_model, cfg.moe_num_experts, ternary=cfg.use_ternary_weights)
|
| self.experts = nn.ModuleList([MoEExpert(cfg.d_model, cfg.moe_expert_hidden) for _ in range(cfg.moe_num_experts)])
|
| self.deepspeed_moe = None
|
| if self.backend == "deepspeed":
|
| self.deepspeed_moe = build_deepspeed_moe(
|
| hidden_size=cfg.d_model,
|
| expert=MoEExpert(cfg.d_model, cfg.moe_expert_hidden),
|
| num_experts=cfg.moe_num_experts,
|
| top_k=cfg.moe_top_k,
|
| ep_size=cfg.moe_ep_size,
|
| )
|
| if self.deepspeed_moe is None:
|
| self.backend = "native"
|
| self.dropout = nn.Dropout(cfg.dropout)
|
|
|
| def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| if self.deepspeed_moe is not None:
|
| out, aux_loss = self.deepspeed_moe(x)
|
| return self.dropout(out), aux_loss
|
| flat_x = x.reshape(-1, x.size(-1))
|
| router_logits = self.router(flat_x)
|
| if self.training and self.router_jitter > 0:
|
| router_logits = router_logits + torch.randn_like(router_logits) * self.router_jitter
|
|
|
| router_probs = F.softmax(router_logits, dim=-1)
|
| topk_weights, topk_indices = torch.topk(router_probs, self.top_k, dim=-1)
|
| topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
|
|
| mixed = flat_x.new_zeros(flat_x.shape)
|
| expert_load = router_probs.new_zeros(self.num_experts)
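| # Native routing path: send each token to its selected experts and mix the expert outputs with renormalized gate weights.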
|
|
|
| for expert_id, expert in enumerate(self.experts):
|
| expert_mask = topk_indices == expert_id
|
| if not expert_mask.any():
|
| continue
|
| token_indices, slot_indices = expert_mask.nonzero(as_tuple=True)
|
| expert_inputs = flat_x.index_select(0, token_indices)
|
| expert_outputs = expert(expert_inputs)
|
| weights = topk_weights[token_indices, slot_indices].unsqueeze(-1)
|
| mixed.index_add_(0, token_indices, expert_outputs * weights)
|
| expert_load[expert_id] = float(token_indices.numel())
|
|
|
| mixed = self.dropout(mixed.view_as(x))
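| # Load-balancing auxiliary loss (Switch-Transformer style): mean router probability ("importance") times realized token fraction ("load"), summed over experts.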
|
| importance = router_probs.mean(dim=0)
|
| load = expert_load / max(1, flat_x.size(0) * self.top_k)
|
| aux_loss = self.num_experts * torch.sum(importance * load)
|
| return mixed, aux_loss
|
|
|
|
|
| class RecursiveBlock(nn.Module):
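| """Transformer block with gated recurrent state fusion, a dense or MoE FFN, and an optional per-batch layer-skip router."""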
|
| def __init__(self, cfg: RLMConfig):
|
| super().__init__()
|
|
|
| self.use_layer_skip = cfg.use_layer_skip
|
| self.layer_skip_threshold = cfg.layer_skip_threshold
|
| self.layer_skip_target = cfg.layer_skip_target
|
| self.use_fused_ops = cfg.use_fused_ops
|
| self.packed_execution = cfg.packed_execution
|
| self.norm_attn = RMSNorm(cfg.d_model)
|
| self.norm_ff = RMSNorm(cfg.d_model)
|
| self.attn = FastSelfAttention(cfg)
|
| self.ffn = MoEFeedForward(cfg) if cfg.use_moe else DenseFeedForward(cfg)
|
| self.skip_router = build_linear(cfg.d_model, 1, ternary=cfg.use_ternary_weights) if cfg.use_layer_skip else None
|
|
|
| self.state_fuse = build_linear(cfg.d_model * 2, cfg.d_model, ternary=cfg.use_ternary_weights)
|
| self.state_update = build_linear(cfg.d_model, cfg.d_model, ternary=cfg.use_ternary_weights)
|
| self.state_gate = build_linear(cfg.d_model * 2, cfg.d_model, ternary=cfg.use_ternary_weights)
|
|
|
| def _run_core(
|
| self,
|
| x: torch.Tensor,
|
| state: torch.Tensor,
|
| attn_mask: Optional[torch.Tensor] = None,
|
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| x_norm = self.norm_attn(x)
|
| attn_out = self.attn(x_norm, attn_mask=attn_mask)
|
| fuse_input = torch.cat([attn_out, state], dim=-1)
|
| gate = torch.sigmoid(self.state_gate(fuse_input))
|
| fused = self.state_fuse(fuse_input)
|
| fused = gate * fused + (1.0 - gate) * state
|
| if self.use_fused_ops:
|
| x = fused_residual_add(x, fused)
|
| else:
|
| x = x + fused
|
| ff_out, moe_aux_loss = self.ffn(self.norm_ff(x))
|
| if self.use_fused_ops:
|
| x = fused_residual_add(x, ff_out)
|
| else:
|
| x = x + ff_out
|
| new_state = torch.tanh(self.state_update(x))
|
| return x, new_state, moe_aux_loss
|
|
|
| def forward(
|
| self,
|
| x: torch.Tensor,
|
| state: torch.Tensor,
|
| attn_mask: Optional[torch.Tensor] = None,
|
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| exec_prob = x.new_ones((x.size(0),))
|
| skip_aux_loss = x.new_zeros(())
|
| if self.skip_router is None:
|
| x, new_state, moe_aux_loss = self._run_core(x, state, attn_mask=attn_mask)
|
| return x, new_state, moe_aux_loss, skip_aux_loss, exec_prob.mean()
|
|
|
| router_input = x.mean(dim=1)
|
| exec_prob = torch.sigmoid(self.skip_router(router_input)).squeeze(-1)
|
| target = exec_prob.new_full(exec_prob.shape, self.layer_skip_target)
|
| skip_aux_loss = F.mse_loss(exec_prob, target)
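| # The MSE term above pulls the mean execution probability toward layer_skip_target; the hard threshold below decides whether the block runs at all.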
|
| hard_gate = exec_prob >= self.layer_skip_threshold
|
| if not torch.any(hard_gate):
|
| return x, state, x.new_zeros(()), skip_aux_loss, exec_prob.mean()
|
|
|
| if torch.all(hard_gate):
|
| x_exec, state_exec, moe_aux_loss = self._run_core(x, state, attn_mask=attn_mask)
|
| elif self.packed_execution:
|
| active_indices = torch.nonzero(hard_gate, as_tuple=False).squeeze(-1)
|
| x_active, state_active = pack_rows(active_indices, x, state)
|
| x_active, state_active, moe_aux_loss = self._run_core(x_active, state_active, attn_mask=attn_mask)
|
| x_exec = scatter_rows(x, active_indices, x_active)
|
| state_exec = scatter_rows(state, active_indices, state_active)
|
| else:
|
| x_exec, state_exec, moe_aux_loss = self._run_core(x, state, attn_mask=attn_mask)
|
|
|
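| # Straight-through estimator: the forward pass follows the hard gate, while gradients flow through the soft execution probability.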
| if self.training:
|
| exec_gate = exec_prob + (hard_gate.to(exec_prob.dtype) - exec_prob).detach()
|
| exec_scale = exec_gate.view(-1, 1, 1)
|
| x_exec = x + exec_scale * (x_exec - x)
|
| state_exec = state + exec_scale * (state_exec - state)
|
|
|
| return x_exec, state_exec, moe_aux_loss, skip_aux_loss, exec_prob.mean()
|
|
|
|
|
| class RubiRLM(nn.Module):
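| """Recursive language model: the layer stack is re-applied for several recursion steps with per-layer
| recurrent state, and a critique head can stop the recursion early.
|
| Minimal usage sketch (illustrative only):
|
| >>> cfg = RLMConfig(vocab_size=4096, max_seq_len=128, d_model=256, n_layers=2, n_heads=8)
| >>> model = RubiRLM(cfg)
| >>> ids = torch.randint(0, cfg.vocab_size, (1, 16))
| >>> logits, _, critique = model.forward_recursive(ids)
| """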
|
| def __init__(self, cfg: RLMConfig):
|
| super().__init__()
|
| self.cfg = cfg
|
| self._last_moe_aux_loss = torch.tensor(0.0)
|
| self._last_layer_skip_aux_loss = torch.tensor(0.0)
|
|
|
| self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
|
| self.pos_emb = nn.Embedding(cfg.max_seq_len, cfg.d_model)
|
| self.drop = nn.Dropout(cfg.dropout)
|
|
|
| self.layers = nn.ModuleList([maybe_compile_module(RecursiveBlock(cfg), cfg.use_torch_compile) for _ in range(cfg.n_layers)])
|
| self.final_norm = RMSNorm(cfg.d_model)
|
|
|
| self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
|
| if cfg.tie_embeddings:
|
| self.lm_head.weight = self.tok_emb.weight
|
|
|
| self.critique_head = nn.Sequential(
|
| nn.Linear(cfg.d_model, cfg.d_model // 2),
|
| nn.GELU(),
|
| nn.Linear(cfg.d_model // 2, 1),
|
| )
|
|
|
| def _causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
|
| mask = torch.full((seq_len, seq_len), float("-inf"), device=device)
|
| return torch.triu(mask, diagonal=1)
|
|
|
| def _embed(self, input_ids: torch.Tensor) -> torch.Tensor:
|
| bsz, seq_len = input_ids.shape
|
| if seq_len > self.cfg.max_seq_len:
|
| raise ValueError(f"Girdi uzunluğu max_seq_len={self.cfg.max_seq_len} sınırını aşıyor.")
|
| pos = torch.arange(seq_len, device=input_ids.device).unsqueeze(0).expand(bsz, seq_len)
|
| return self.drop(self.tok_emb(input_ids) + self.pos_emb(pos))
|
|
|
| def forward_recursive(
|
| self,
|
| input_ids: torch.Tensor,
|
| steps: Optional[int] = None,
|
| stop_on_critique: bool = True,
|
| return_trace: bool = False,
|
| ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
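| """Run the layer stack for up to `steps` recursion passes, carrying per-layer recurrent state; optionally stop early once every sequence's critique score drops below critique_threshold, and optionally return the per-step logits trace."""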
|
| steps = steps or self.cfg.recurse_steps
|
| x = self._embed(input_ids)
|
|
|
| bsz, seq_len, d_model = x.shape
|
| states = [x.new_zeros((bsz, seq_len, d_model)) for _ in range(self.cfg.n_layers)]
|
| mask = self._causal_mask(seq_len, x.device)
|
|
|
| logits_trace: List[torch.Tensor] = []
|
| critique_trace: List[torch.Tensor] = []
|
| moe_aux_total = x.new_zeros(())
|
| layer_skip_aux_total = x.new_zeros(())
|
|
|
| for _ in range(steps):
|
| h = x
|
| new_states = []
|
| for layer, st in zip(self.layers, states):
|
| h, st_new, moe_aux, skip_aux, _ = layer(h, st, attn_mask=mask)
|
| new_states.append(st_new)
|
| moe_aux_total = moe_aux_total + moe_aux
|
| layer_skip_aux_total = layer_skip_aux_total + skip_aux
|
| states = new_states
|
|
|
| h_norm = self.final_norm(h)
|
| logits = self.lm_head(h_norm)
|
| pooled = h_norm[:, -1, :]
|
| critique = torch.sigmoid(self.critique_head(pooled)).squeeze(-1)
|
|
|
| logits_trace.append(logits)
|
| critique_trace.append(critique)
|
| x = h
|
|
|
| if stop_on_critique and torch.all(critique < self.cfg.critique_threshold):
|
| break
|
|
|
| denom = max(1, len(logits_trace) * len(self.layers))
|
| self._last_moe_aux_loss = moe_aux_total / denom
|
| self._last_layer_skip_aux_loss = layer_skip_aux_total / denom
|
|
|
| final_logits = logits_trace[-1]
|
| if return_trace:
|
| return final_logits, logits_trace, critique_trace
|
| return final_logits, [], critique_trace
|
|
|
| def training_loss(
|
| self,
|
| input_ids: torch.Tensor,
|
| target_ids: torch.Tensor,
|
| steps: Optional[int] = None,
|
| alpha_iterative: float = 0.30,
|
| beta_correction: float = 0.10,
|
| ) -> torch.Tensor:
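| """Cross-entropy on the last recursion step plus a weighted average of intermediate-step losses, a penalty when the critique score rises across steps, and the MoE / layer-skip auxiliary terms."""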
|
| final_logits, trace, critique = self.forward_recursive(
|
| input_ids, steps=steps, stop_on_critique=False, return_trace=True
|
| )
|
|
|
| final_loss = F.cross_entropy(
|
| final_logits.view(-1, final_logits.size(-1)),
|
| target_ids.view(-1),
|
| ignore_index=-100,
|
| )
|
|
|
| if trace:
|
| iterative = 0.0
|
| for logits in trace[:-1]:
|
| iterative = iterative + F.cross_entropy(
|
| logits.view(-1, logits.size(-1)),
|
| target_ids.view(-1),
|
| ignore_index=-100,
|
| )
|
| iterative = iterative / max(1, len(trace) - 1)
|
| else:
|
| iterative = final_loss.new_tensor(0.0)
|
|
|
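| # Despite its name, correction_bonus acts as a penalty: it is positive when the critique score rises from the first to the last recursion step.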
| correction_bonus = 0.0
|
| if len(critique) > 1:
|
| start = critique[0].mean()
|
| end = critique[-1].mean()
|
| correction_bonus = torch.relu(end - start)
|
|
|
| total_loss = final_loss + alpha_iterative * iterative + beta_correction * correction_bonus
|
| if self.cfg.use_moe:
|
| total_loss = total_loss + self.cfg.moe_aux_loss_weight * self._last_moe_aux_loss
|
| if self.cfg.use_layer_skip:
|
| total_loss = total_loss + self.cfg.layer_skip_aux_weight * self._last_layer_skip_aux_loss
|
| return total_loss
|
|
|
| @torch.no_grad()
|
| def generate(
|
| self,
|
| input_ids: torch.Tensor,
|
| max_new_tokens: int = 64,
|
| temperature: float = 0.8,
|
| top_k: int = 50,
|
| steps: Optional[int] = None,
|
| ) -> torch.Tensor:
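| """Autoregressive sampling with temperature and top-k filtering; each new token re-runs the recursive forward pass on the trailing context window."""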
|
| self.eval()
|
| out = input_ids
|
|
|
| for _ in range(max_new_tokens):
|
| context = out[:, -self.cfg.max_seq_len :]
|
| logits, _, _ = self.forward_recursive(context, steps=steps, stop_on_critique=True, return_trace=False)
|
| next_logits = logits[:, -1, :] / max(temperature, 1e-5)
|
|
|
| if top_k > 0:
|
| values, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
|
| cutoff = values[:, [-1]]
|
| next_logits = torch.where(next_logits < cutoff, torch.full_like(next_logits, -1e9), next_logits)
|
|
|
| probs = F.softmax(next_logits, dim=-1)
|
| next_token = torch.multinomial(probs, num_samples=1)
|
| out = torch.cat([out, next_token], dim=1)
|
|
|
| return out
|
|
|
| def generate_text(
|
| self,
|
| tokenizer: TextTokenizer,
|
| prompt: str,
|
| max_new_tokens: int = 128,
|
| temperature: float = 0.7,
|
| top_k: int = 50,
|
| steps: Optional[int] = None,
|
| device: Optional[torch.device] = None,
|
| ) -> str:
|
| device = device or next(self.parameters()).device
|
| input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
|
| output_ids = self.generate(
|
| input_ids,
|
| max_new_tokens=max_new_tokens,
|
| temperature=temperature,
|
| top_k=top_k,
|
| steps=steps,
|
| )
|
| new_tokens = output_ids[0, input_ids.shape[1] :].tolist()
|
| return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
|
|
|
| def chat(
|
| self,
|
| tokenizer: TextTokenizer,
|
| history: List[ChatTurn],
|
| user_message: str,
|
| lang: str = "auto",
|
| max_new_tokens: int = 192,
|
| temperature: float = 0.7,
|
| top_k: int = 50,
|
| steps: Optional[int] = None,
|
| device: Optional[torch.device] = None,
|
| ) -> Tuple[str, List[ChatTurn]]:
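| """Build the chat prompt, generate the assistant reply, and return (reply, history extended with both new turns)."""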
|
| prompt = build_chat_prompt(history, user_message, lang=lang)
|
| assistant_reply = self.generate_text(
|
| tokenizer=tokenizer,
|
| prompt=prompt,
|
| max_new_tokens=max_new_tokens,
|
| temperature=temperature,
|
| top_k=top_k,
|
| steps=steps,
|
| device=device,
|
| )
|
| updated = history + [ChatTurn(role="user", content=user_message), ChatTurn(role="assistant", content=assistant_reply)]
|
| return assistant_reply, updated
|
|
|
| def outer_sleep_phase_step(
|
| self,
|
| optimizer: torch.optim.Optimizer,
|
| input_ids: torch.Tensor,
|
| target_ids: torch.Tensor,
|
| steps: Optional[int] = None,
|
| ) -> float:
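| """One outer-loop ("sleep phase") update: recursive forward, backward, gradient clipping at 1.0, optimizer step; returns the scalar loss."""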
|
| self.train()
|
| optimizer.zero_grad(set_to_none=True)
|
| loss = self.training_loss(input_ids, target_ids, steps=steps)
|
| loss.backward()
|
| nn.utils.clip_grad_norm_(self.parameters(), 1.0)
|
| optimizer.step()
|
| return float(loss.detach().item())
|
|
|
|
|
| def estimate_parameters(cfg: RLMConfig) -> int:
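| """Rough analytic parameter count mirroring the module definitions above; ignores any backend-specific extras (e.g. quantization scales)."""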
|
| d = cfg.d_model
|
| total = cfg.vocab_size * d + cfg.max_seq_len * d
|
| attn_params = (4 * d * d) + (4 * d)
|
| state_params = (5 * d * d) + (3 * d)
|
| router_params = 0
|
| layer_skip_params = 0
|
| ff_params = (2 * d * d * cfg.ff_mult) + (d * cfg.ff_mult) + d
|
| if cfg.use_moe:
|
| router_params = (d * cfg.moe_num_experts) + cfg.moe_num_experts
|
| expert_params = (2 * d * cfg.moe_expert_hidden) + cfg.moe_expert_hidden + d
|
| ff_params = cfg.moe_num_experts * expert_params
|
| if cfg.use_layer_skip:
|
| layer_skip_params = d + 1
|
| per_layer = attn_params + state_params + router_params + layer_skip_params + ff_params + (2 * d)
|
| total += cfg.n_layers * per_layer
|
| total += d * (d // 2) + (d // 2) + (d // 2) + 1 + d
|
| if not cfg.tie_embeddings:
|
| total += d * cfg.vocab_size
|
| return total
|
|
|
|
|
| def estimate_active_parameters(cfg: RLMConfig) -> int:
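| """Like estimate_parameters, but counts only the top-k routed experts and scales skip-gated layers by layer_skip_target to approximate the parameters active per token."""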
|
| d = cfg.d_model
|
| total = cfg.vocab_size * d + cfg.max_seq_len * d
|
| attn_params = (4 * d * d) + (4 * d)
|
| state_params = (5 * d * d) + (3 * d)
|
| router_params = 0
|
| layer_skip_params = 0
|
| ff_params = (2 * d * d * cfg.ff_mult) + (d * cfg.ff_mult) + d
|
| if cfg.use_moe:
|
| router_params = (d * cfg.moe_num_experts) + cfg.moe_num_experts
|
| expert_params = (2 * d * cfg.moe_expert_hidden) + cfg.moe_expert_hidden + d
|
| ff_params = cfg.moe_top_k * expert_params
|
| if cfg.use_layer_skip:
|
| layer_skip_params = d + 1
|
| routed_layer = attn_params + state_params + router_params + ff_params + (2 * d)
|
| routed_layer = cfg.layer_skip_target * routed_layer
|
| per_layer = layer_skip_params + routed_layer
|
| total += cfg.n_layers * per_layer
|
| total += d * (d // 2) + (d // 2) + (d // 2) + 1 + d
|
| if not cfg.tie_embeddings:
|
| total += d * cfg.vocab_size
|
| return int(total)
|
|
|
|
|
| def language_system_prompt(lang: str) -> str:
|
| base = (
|
| "You are Rubi-RLM assistant. Reason step-by-step internally, be concise in final answer, "
|
| "self-correct if needed."
|
| )
|
| if lang == "tr":
|
| return base + " Yanıtlarını Türkçe ver."
|
| if lang == "en":
|
| return base + " Reply in English."
|
| return base + " Reply in the user's language (Turkish or English)."
|
|
|
|
|
| def build_chat_prompt(history: List[ChatTurn], user_message: str, lang: str = "auto") -> str:
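| """Render the history plus the new user message into the <|system|>/<|user|>/<|assistant|> template; the result ends with an open assistant tag so generation continues as the assistant turn."""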
|
| lines = [f"<|system|>\n{language_system_prompt(lang)}"]
|
| for turn in history:
|
| role = "user" if turn.role.lower() == "user" else "assistant"
|
| lines.append(f"<|{role}|>\n{turn.content}")
|
| lines.append(f"\n{user_message}")
|
| lines.append("<|assistant|>\n")
|
| return "\n".join(lines)
|
|
|
|
|
| def load_hf_tokenizer(tokenizer_name: str):
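| """Lazily import transformers and load an AutoTokenizer; reuse the EOS token as pad token when none is set."""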
|
| if importlib.util.find_spec("transformers") is None:
|
| raise RuntimeError("transformers yüklü değil. `pip install transformers` ile kurun.")
|
| transformers = importlib.import_module("transformers")
|
| tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name)
|
| if getattr(tokenizer, "pad_token_id", None) is None and getattr(tokenizer, "eos_token", None) is not None:
|
| tokenizer.pad_token = tokenizer.eos_token
|
| return tokenizer
|
|
|
|
|
| def demo() -> None:
|
| cfg = RLMConfig(
|
| vocab_size=4096,
|
| max_seq_len=128,
|
| d_model=256,
|
| n_layers=4,
|
| n_heads=8,
|
| ff_mult=4,
|
| recurse_steps=4,
|
| use_moe=True,
|
| moe_num_experts=8,
|
| moe_top_k=2,
|
| moe_expert_hidden=384,
|
| )
|
| model = RubiRLM(cfg)
|
| x = torch.randint(0, cfg.vocab_size, (2, 32))
|
| y = torch.randint(0, cfg.vocab_size, (2, 32))
|
|
|
| loss = model.training_loss(x, y)
|
| print(f"demo_loss={loss.item():.4f}")
|
|
|
| out = model.generate(x[:, :8], max_new_tokens=8, steps=3)
|
| print("generated_shape=", tuple(out.shape))
|
|
|
|
|
| def resolve_config(scale: str) -> RLMConfig:
|
| if scale == "1b":
|
| return RLMConfig.scale_1b()
|
| return RLMConfig(d_model=512, n_layers=8, n_heads=8, vocab_size=50_257, max_seq_len=512)
|
|
|
|
|
| def runtime_torch_compile_available() -> bool:
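| """torch.compile is only used when it is available and, on CUDA, when Triton is installed."""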
|
| if not hasattr(torch, "compile"):
|
| return False
|
| if torch.cuda.is_available() and importlib.util.find_spec("triton") is None:
|
| return False
|
| return True
|
|
|
|
|
| def apply_runtime_config_overrides(cfg: RLMConfig, args: argparse.Namespace) -> RLMConfig:
|
| cfg.moe_backend = getattr(args, "moe_backend", cfg.moe_backend)
|
| cfg.moe_ep_size = getattr(args, "moe_ep_size", cfg.moe_ep_size)
|
| requested_compile = bool(getattr(args, "use_torch_compile", cfg.use_torch_compile))
|
| cfg.use_torch_compile = requested_compile and runtime_torch_compile_available()
|
| return cfg
|
|
|
|
|
| def maybe_load_checkpoint(model: RubiRLM, checkpoint: Optional[str], device: torch.device) -> None:
|
| if not checkpoint:
|
| return
|
| state = torch.load(checkpoint, map_location=device)
|
| if isinstance(state, dict) and "model_state_dict" in state:
|
| model.load_state_dict(state["model_state_dict"])
|
| return
|
| model.load_state_dict(state)
|
|
|
|
|
| def run_single_chat(args: argparse.Namespace) -> None:
|
| cfg = apply_runtime_config_overrides(resolve_config(args.scale), args)
|
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| model = RubiRLM(cfg).to(device)
|
| maybe_load_checkpoint(model, args.checkpoint, device)
|
| tokenizer = load_hf_tokenizer(args.tokenizer_name)
|
|
|
| history: List[ChatTurn] = []
|
| if args.interactive:
|
| print("Interactive chat başladı. Çıkmak için /exit yaz.")
|
| while True:
|
| user_msg = input("You> ").strip()
|
| if not user_msg:
|
| continue
|
| if user_msg.lower() in {"/exit", "exit", "quit"}:
|
| break
|
| reply, history = model.chat(
|
| tokenizer=tokenizer,
|
| history=history,
|
| user_message=user_msg,
|
| lang=args.lang,
|
| max_new_tokens=args.max_new_tokens,
|
| temperature=args.temperature,
|
| top_k=args.top_k,
|
| steps=args.steps,
|
| device=device,
|
| )
|
| print(f"Rubi> {reply}")
|
| return
|
|
|
| if not args.prompt:
|
| raise ValueError("--chat modunda --prompt veya --interactive gerekli.")
|
|
|
| reply, _ = model.chat(
|
| tokenizer=tokenizer,
|
| history=[],
|
| user_message=args.prompt,
|
| lang=args.lang,
|
| max_new_tokens=args.max_new_tokens,
|
| temperature=args.temperature,
|
| top_k=args.top_k,
|
| steps=args.steps,
|
| device=device,
|
| )
|
| print(reply)
|
|
|
|
|
| def print_stack_report() -> None:
|
| report = detect_xqs_backends()
|
| print(format_backend_report(report))
|
|
|
|
|
| def run_train_demo(args: argparse.Namespace) -> None:
|
| cfg = apply_runtime_config_overrides(resolve_config(args.scale), args)
|
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| model = RubiRLM(cfg).to(device)
|
| maybe_load_checkpoint(model, args.checkpoint, device)
|
|
|
| train_cfg = TrainStackConfig(
|
| optimizer_name=args.optimizer_name,
|
| learning_rate=args.learning_rate,
|
| weight_decay=args.weight_decay,
|
| batch_size=args.batch_size,
|
| num_workers=args.num_workers,
|
| pin_memory=not args.disable_pin_memory,
|
| prefetch_factor=args.prefetch_factor,
|
| persistent_workers=not args.disable_persistent_workers,
|
| max_seq_len=cfg.max_seq_len,
|
| dataset_dir=args.dataset_dir,
|
| use_bf16=not args.disable_bf16,
|
| )
|
| dataset = build_dataset(
|
| dataset_dir=train_cfg.dataset_dir,
|
| vocab_size=cfg.vocab_size,
|
| max_seq_len=min(cfg.max_seq_len, args.train_seq_len),
|
| synthetic_samples=max(args.train_steps * args.batch_size * 2, 32),
|
| )
|
| dataloader = build_dataloader(dataset, train_cfg, shuffle=True)
|
| optimizer = build_optimizer(model, train_cfg)
|
| mean_loss, total_tokens = train_demo_steps(
|
| model=model,
|
| optimizer=optimizer,
|
| dataloader=dataloader,
|
| device=device,
|
| steps=args.train_steps,
|
| use_bf16=train_cfg.use_bf16,
|
| )
|
| print(
|
| f"train_demo optimizer={optimizer.__class__.__name__} steps={args.train_steps} "
|
| f"mean_loss={mean_loss:.4f} tokens={total_tokens:,} device={device}"
|
| )
|
|
|
|
|
| def main() -> None:
|
| parser = argparse.ArgumentParser(description="Rubi-RLM recursive language model")
|
| parser.add_argument("--scale", choices=["1b", "tiny"], default="1b")
|
| parser.add_argument("--estimate-only", action="store_true")
|
| parser.add_argument("--demo", action="store_true")
|
| parser.add_argument("--train-demo", action="store_true")
|
| parser.add_argument("--stack-report", action="store_true")
|
|
|
| parser.add_argument("--chat", action="store_true", help="Türkçe/İngilizce sohbet modunu açar")
|
| parser.add_argument("--interactive", action="store_true", help="Interactive chat loop")
|
| parser.add_argument("--prompt", type=str, default="")
|
| parser.add_argument("--lang", choices=["auto", "tr", "en"], default="auto")
|
| parser.add_argument("--tokenizer-name", type=str, default="gpt2")
|
| parser.add_argument("--checkpoint", type=str, default=None)
|
| parser.add_argument("--steps", type=int, default=None)
|
| parser.add_argument("--max-new-tokens", type=int, default=192)
|
| parser.add_argument("--temperature", type=float, default=0.7)
|
| parser.add_argument("--top-k", type=int, default=50)
|
| parser.add_argument("--optimizer-name", type=str, default="auto")
|
| parser.add_argument("--moe-backend", choices=["auto", "native", "deepspeed"], default="auto")
|
| parser.add_argument("--moe-ep-size", type=int, default=1)
|
| parser.add_argument("--use-torch-compile", action="store_true")
|
| parser.add_argument("--learning-rate", type=float, default=3e-4)
|
| parser.add_argument("--weight-decay", type=float, default=0.01)
|
| parser.add_argument("--batch-size", type=int, default=2)
|
| parser.add_argument("--num-workers", type=int, default=2)
|
| parser.add_argument("--prefetch-factor", type=int, default=4)
|
| parser.add_argument("--dataset-dir", type=str, default="")
|
| parser.add_argument("--train-steps", type=int, default=2)
|
| parser.add_argument("--train-seq-len", type=int, default=256)
|
| parser.add_argument("--disable-pin-memory", action="store_true")
|
| parser.add_argument("--disable-persistent-workers", action="store_true")
|
| parser.add_argument("--disable-bf16", action="store_true")
|
| args = parser.parse_args()
|
|
|
| if args.chat:
|
| run_single_chat(args)
|
| return
|
|
|
| if args.stack_report:
|
| print_stack_report()
|
| return
|
|
|
| if args.train_demo:
|
| run_train_demo(args)
|
| return
|
|
|
| if args.demo:
|
| demo()
|
| return
|
|
|
| cfg = apply_runtime_config_overrides(resolve_config(args.scale), args)
|
| n_params = estimate_parameters(cfg)
|
| active_params = estimate_active_parameters(cfg)
|
| print(f"Scale={args.scale}, estimated_params={n_params:,}, estimated_active_params={active_params:,}")
|
| if not args.estimate_only:
|
| model = RubiRLM(cfg)
|
| actual = sum(p.numel() for p in model.parameters())
|
| print(f"actual_params={actual:,}")
|
|
|
|
|
| if __name__ == "__main__":
|
| main()
|
|
|