"""Rubi-RLM: 1B-class Recursive Language Model (RLM) prototype. Bu dosya, recursive düşünme + dual-loop öğrenme hedefiyle tasarlanmış bir araÅŸtırma prototipi içerir. Eklenen sohbet katmanı: - İngilizce/Türkçe çift dilli chat ÅŸablonu - HF tokenizer ile metin->id / id->metin köprüsü - Tek mesaj veya interaktif chat CLI """ from __future__ import annotations import argparse import importlib import importlib.util from dataclasses import dataclass from typing import List, Optional, Protocol, Sequence, Tuple import torch import torch.nn as nn import torch.nn.functional as F from rubi_train_stack import ( TrainStackConfig, build_dataloader, build_dataset, build_optimizer, train_demo_steps, ) from xqs_moe import build_deepspeed_moe from xqs_stack import choose_moe_backend, detect_xqs_backends, format_backend_report from x_quantum_sparse_ops import ( build_linear, causal_scaled_dot_product_attention, fused_residual_add, maybe_compile_module, pack_rows, scatter_rows, ) class TextTokenizer(Protocol): def encode(self, text: str, return_tensors: Optional[str] = None): ... def decode(self, token_ids: Sequence[int], skip_special_tokens: bool = True) -> str: ... @dataclass class ChatTurn: role: str content: str @dataclass class RLMConfig: vocab_size: int = 50_257 max_seq_len: int = 2_048 d_model: int = 2_048 n_layers: int = 14 n_heads: int = 16 ff_mult: int = 4 dropout: float = 0.1 recurse_steps: int = 6 critique_threshold: float = 0.20 tie_embeddings: bool = True use_moe: bool = False moe_num_experts: int = 0 moe_top_k: int = 2 moe_expert_hidden: int = 0 moe_router_jitter: float = 0.0 moe_aux_loss_weight: float = 0.01 use_layer_skip: bool = False layer_skip_threshold: float = 0.50 layer_skip_target: float = 1.0 layer_skip_aux_weight: float = 0.01 use_ternary_weights: bool = False use_flash_attention: bool = False use_fused_ops: bool = False packed_execution: bool = False use_torch_compile: bool = False moe_backend: str = "auto" moe_ep_size: int = 1 @classmethod def scale_1b(cls) -> "RLMConfig": return cls( vocab_size=50_257, max_seq_len=2_048, d_model=1_024, n_layers=10, n_heads=16, ff_mult=4, recurse_steps=6, critique_threshold=0.20, use_moe=True, moe_num_experts=32, moe_top_k=1, moe_expert_hidden=1_280, moe_router_jitter=0.01, moe_aux_loss_weight=0.01, use_layer_skip=True, layer_skip_threshold=0.80, layer_skip_target=0.03, layer_skip_aux_weight=0.01, use_ternary_weights=True, use_flash_attention=True, use_fused_ops=True, packed_execution=True, use_torch_compile=False, moe_backend="auto", moe_ep_size=1, ) class RMSNorm(nn.Module): def __init__(self, d_model: int, eps: float = 1e-6): super().__init__() self.scale = nn.Parameter(torch.ones(d_model)) self.eps = eps def forward(self, x: torch.Tensor) -> torch.Tensor: rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt() return self.scale * (x / rms) class DenseFeedForward(nn.Module): def __init__(self, cfg: RLMConfig): super().__init__() hidden = cfg.d_model * cfg.ff_mult self.up_proj = build_linear(cfg.d_model, hidden, ternary=cfg.use_ternary_weights) self.down_proj = build_linear(hidden, cfg.d_model, ternary=cfg.use_ternary_weights) self.dropout = nn.Dropout(cfg.dropout) def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: return self.dropout(self.down_proj(F.gelu(self.up_proj(x)))), x.new_zeros(()) class FastSelfAttention(nn.Module): def __init__(self, cfg: RLMConfig): super().__init__() if cfg.d_model % cfg.n_heads != 0: raise ValueError("d_model must be divisible by n_heads.") self.n_heads = cfg.n_heads self.head_dim = 
class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        return self.scale * (x / rms)


class DenseFeedForward(nn.Module):
    def __init__(self, cfg: RLMConfig):
        super().__init__()
        hidden = cfg.d_model * cfg.ff_mult
        self.up_proj = build_linear(cfg.d_model, hidden, ternary=cfg.use_ternary_weights)
        self.down_proj = build_linear(hidden, cfg.d_model, ternary=cfg.use_ternary_weights)
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.dropout(self.down_proj(F.gelu(self.up_proj(x)))), x.new_zeros(())


class FastSelfAttention(nn.Module):
    def __init__(self, cfg: RLMConfig):
        super().__init__()
        if cfg.d_model % cfg.n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads.")
        self.n_heads = cfg.n_heads
        self.head_dim = cfg.d_model // cfg.n_heads
        self.dropout = cfg.dropout
        self.use_flash_attention = cfg.use_flash_attention
        self.q_proj = build_linear(cfg.d_model, cfg.d_model, bias=False, ternary=cfg.use_ternary_weights)
        self.k_proj = build_linear(cfg.d_model, cfg.d_model, bias=False, ternary=cfg.use_ternary_weights)
        self.v_proj = build_linear(cfg.d_model, cfg.d_model, bias=False, ternary=cfg.use_ternary_weights)
        self.out_proj = build_linear(cfg.d_model, cfg.d_model, bias=False, ternary=cfg.use_ternary_weights)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        bsz, seq_len, _ = x.shape
        q = self.q_proj(x).view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        attn_out = causal_scaled_dot_product_attention(
            q,
            k,
            v,
            dropout_p=self.dropout,
            training=self.training,
        )
        attn_out = attn_out.transpose(1, 2).contiguous().view(bsz, seq_len, self.n_heads * self.head_dim)
        return self.out_proj(attn_out)

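# causal_scaled_dot_product_attention is supplied by x_quantum_sparse_ops and may
# dispatch to flash/fused kernels. A minimal reference sketch of the contract that
# FastSelfAttention assumes (ordinary causal SDPA over (bsz, n_heads, seq, head_dim)
# tensors), written with torch's built-in kernel; this is an assumption about the
# helper's semantics, not its implementation.
def _reference_causal_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    dropout_p: float = 0.0,
    training: bool = False,
) -> torch.Tensor:
    # Dropout is only applied while training; output shape matches the inputs.
    return F.scaled_dot_product_attention(
        q, k, v,
        dropout_p=dropout_p if training else 0.0,
        is_causal=True,
    )
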
class MoEExpert(nn.Module):
    def __init__(self, d_model: int, hidden: int):
        super().__init__()
        self.up_proj = build_linear(d_model, hidden, ternary=True)
        self.down_proj = build_linear(hidden, d_model, ternary=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(F.gelu(self.up_proj(x)))


class MoEFeedForward(nn.Module):
    def __init__(self, cfg: RLMConfig):
        super().__init__()
        if cfg.moe_num_experts <= 0:
            raise ValueError("moe_num_experts must be positive when use_moe=True.")
        if cfg.moe_top_k <= 0 or cfg.moe_top_k > cfg.moe_num_experts:
            raise ValueError("moe_top_k must be in the range [1, moe_num_experts].")
        self.num_experts = cfg.moe_num_experts
        self.top_k = cfg.moe_top_k
        self.router_jitter = cfg.moe_router_jitter
        requested_backend = cfg.moe_backend.lower()
        self.backend = (
            choose_moe_backend(prefer_deepspeed=requested_backend in {"auto", "deepspeed"})
            if requested_backend != "native"
            else "native"
        )
        self.router = build_linear(cfg.d_model, cfg.moe_num_experts, ternary=cfg.use_ternary_weights)
        self.experts = nn.ModuleList(
            [MoEExpert(cfg.d_model, cfg.moe_expert_hidden) for _ in range(cfg.moe_num_experts)]
        )
        self.deepspeed_moe = None
        if self.backend == "deepspeed":
            self.deepspeed_moe = build_deepspeed_moe(
                hidden_size=cfg.d_model,
                expert=MoEExpert(cfg.d_model, cfg.moe_expert_hidden),
                num_experts=cfg.moe_num_experts,
                top_k=cfg.moe_top_k,
                ep_size=cfg.moe_ep_size,
            )
        if self.deepspeed_moe is None:
            self.backend = "native"
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.deepspeed_moe is not None:
            out, aux_loss = self.deepspeed_moe(x)
            return self.dropout(out), aux_loss
        flat_x = x.reshape(-1, x.size(-1))
        router_logits = self.router(flat_x)
        if self.training and self.router_jitter > 0:
            router_logits = router_logits + torch.randn_like(router_logits) * self.router_jitter
        router_probs = F.softmax(router_logits, dim=-1)
        topk_weights, topk_indices = torch.topk(router_probs, self.top_k, dim=-1)
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        mixed = flat_x.new_zeros(flat_x.shape)
        expert_load = router_probs.new_zeros(self.num_experts)
        for expert_id, expert in enumerate(self.experts):
            expert_mask = topk_indices == expert_id
            if not expert_mask.any():
                continue
            token_indices, slot_indices = expert_mask.nonzero(as_tuple=True)
            expert_inputs = flat_x.index_select(0, token_indices)
            expert_outputs = expert(expert_inputs)
            weights = topk_weights[token_indices, slot_indices].unsqueeze(-1)
            mixed.index_add_(0, token_indices, expert_outputs * weights)
            expert_load[expert_id] = float(token_indices.numel())
        mixed = self.dropout(mixed.view_as(x))
        importance = router_probs.mean(dim=0)
        load = expert_load / max(1, flat_x.size(0) * self.top_k)
        aux_loss = self.num_experts * torch.sum(importance * load)
        return mixed, aux_loss

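# The native path above implements standard top-k token routing with a
# load-balancing auxiliary loss (num_experts * sum(importance * load)). A toy,
# self-contained sketch of that routing math on random logits; shapes and the
# importance/load definitions mirror the forward pass, the numbers are illustrative.
def _toy_topk_routing_demo(num_tokens: int = 6, num_experts: int = 4, top_k: int = 2) -> torch.Tensor:
    router_logits = torch.randn(num_tokens, num_experts)
    router_probs = F.softmax(router_logits, dim=-1)
    topk_weights, topk_indices = torch.topk(router_probs, top_k, dim=-1)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    # importance: mean router probability per expert; load: fraction of routed slots it receives.
    importance = router_probs.mean(dim=0)
    load = torch.zeros(num_experts)
    for expert_id in range(num_experts):
        load[expert_id] = (topk_indices == expert_id).sum().float()
    load = load / (num_tokens * top_k)
    aux_loss = num_experts * torch.sum(importance * load)
    return aux_loss  # close to 1.0 when routing is perfectly balanced
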
class RecursiveBlock(nn.Module):
    def __init__(self, cfg: RLMConfig):
        super().__init__()
        self.use_layer_skip = cfg.use_layer_skip
        self.layer_skip_threshold = cfg.layer_skip_threshold
        self.layer_skip_target = cfg.layer_skip_target
        self.use_fused_ops = cfg.use_fused_ops
        self.packed_execution = cfg.packed_execution
        self.norm_attn = RMSNorm(cfg.d_model)
        self.norm_ff = RMSNorm(cfg.d_model)
        self.attn = FastSelfAttention(cfg)
        self.ffn = MoEFeedForward(cfg) if cfg.use_moe else DenseFeedForward(cfg)
        self.skip_router = (
            build_linear(cfg.d_model, 1, ternary=cfg.use_ternary_weights) if cfg.use_layer_skip else None
        )
        self.state_fuse = build_linear(cfg.d_model * 2, cfg.d_model, ternary=cfg.use_ternary_weights)
        self.state_update = build_linear(cfg.d_model, cfg.d_model, ternary=cfg.use_ternary_weights)
        self.state_gate = build_linear(cfg.d_model * 2, cfg.d_model, ternary=cfg.use_ternary_weights)

    def _run_core(
        self,
        x: torch.Tensor,
        state: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        x_norm = self.norm_attn(x)
        attn_out = self.attn(x_norm, attn_mask=attn_mask)
        fuse_input = torch.cat([attn_out, state], dim=-1)
        gate = torch.sigmoid(self.state_gate(fuse_input))
        fused = self.state_fuse(fuse_input)
        fused = gate * fused + (1.0 - gate) * state
        if self.use_fused_ops:
            x = fused_residual_add(x, fused)
        else:
            x = x + fused
        ff_out, moe_aux_loss = self.ffn(self.norm_ff(x))
        if self.use_fused_ops:
            x = fused_residual_add(x, ff_out)
        else:
            x = x + ff_out
        new_state = torch.tanh(self.state_update(x))
        return x, new_state, moe_aux_loss

    def forward(
        self,
        x: torch.Tensor,
        state: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        exec_prob = x.new_ones((x.size(0),))
        skip_aux_loss = x.new_zeros(())
        if self.skip_router is None:
            x, new_state, moe_aux_loss = self._run_core(x, state, attn_mask=attn_mask)
            return x, new_state, moe_aux_loss, skip_aux_loss, exec_prob.mean()
        router_input = x.mean(dim=1)
        exec_prob = torch.sigmoid(self.skip_router(router_input)).squeeze(-1)
        target = exec_prob.new_full(exec_prob.shape, self.layer_skip_target)
        skip_aux_loss = F.mse_loss(exec_prob, target)
        hard_gate = exec_prob >= self.layer_skip_threshold
        if not torch.any(hard_gate):
            return x, state, x.new_zeros(()), skip_aux_loss, exec_prob.mean()
        if torch.all(hard_gate):
            x_exec, state_exec, moe_aux_loss = self._run_core(x, state, attn_mask=attn_mask)
        elif self.packed_execution:
            active_indices = torch.nonzero(hard_gate, as_tuple=False).squeeze(-1)
            x_active, state_active = pack_rows(active_indices, x, state)
            x_active, state_active, moe_aux_loss = self._run_core(x_active, state_active, attn_mask=attn_mask)
            x_exec = scatter_rows(x, active_indices, x_active)
            state_exec = scatter_rows(state, active_indices, state_active)
        else:
            x_exec, state_exec, moe_aux_loss = self._run_core(x, state, attn_mask=attn_mask)
        if self.training:
            exec_gate = exec_prob + (hard_gate.to(exec_prob.dtype) - exec_prob).detach()
            exec_scale = exec_gate.view(-1, 1, 1)
            x_exec = x + exec_scale * (x_exec - x)
            state_exec = state + exec_scale * (state_exec - state)
        return x_exec, state_exec, moe_aux_loss, skip_aux_loss, exec_prob.mean()

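# During training, the block keeps the hard execute/skip decision differentiable
# with a straight-through estimator: the forward value equals the hard gate while
# gradients flow through the soft probability. A minimal standalone illustration
# of that identity; the variable names mirror the forward pass above and the
# router logits here are invented for the example.
def _straight_through_gate_demo() -> None:
    logits = torch.randn(4, requires_grad=True)   # stand-in for skip_router outputs
    exec_prob = torch.sigmoid(logits)
    hard_gate = (exec_prob >= 0.5).to(exec_prob.dtype)
    exec_gate = exec_prob + (hard_gate - exec_prob).detach()
    assert torch.allclose(exec_gate.detach(), hard_gate)  # forward value is the hard gate
    exec_gate.sum().backward()
    assert logits.grad is not None  # gradient still reaches the router parameters
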
class RubiRLM(nn.Module):
    def __init__(self, cfg: RLMConfig):
        super().__init__()
        self.cfg = cfg
        self._last_moe_aux_loss = torch.tensor(0.0)
        self._last_layer_skip_aux_loss = torch.tensor(0.0)
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.pos_emb = nn.Embedding(cfg.max_seq_len, cfg.d_model)
        self.drop = nn.Dropout(cfg.dropout)
        self.layers = nn.ModuleList(
            [maybe_compile_module(RecursiveBlock(cfg), cfg.use_torch_compile) for _ in range(cfg.n_layers)]
        )
        self.final_norm = RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        if cfg.tie_embeddings:
            self.lm_head.weight = self.tok_emb.weight
        self.critique_head = nn.Sequential(
            nn.Linear(cfg.d_model, cfg.d_model // 2),
            nn.GELU(),
            nn.Linear(cfg.d_model // 2, 1),
        )

    def _causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        mask = torch.full((seq_len, seq_len), float("-inf"), device=device)
        return torch.triu(mask, diagonal=1)

    def _embed(self, input_ids: torch.Tensor) -> torch.Tensor:
        bsz, seq_len = input_ids.shape
        if seq_len > self.cfg.max_seq_len:
            raise ValueError(f"Input length exceeds the max_seq_len={self.cfg.max_seq_len} limit.")
        pos = torch.arange(seq_len, device=input_ids.device).unsqueeze(0).expand(bsz, seq_len)
        return self.drop(self.tok_emb(input_ids) + self.pos_emb(pos))

    def forward_recursive(
        self,
        input_ids: torch.Tensor,
        steps: Optional[int] = None,
        stop_on_critique: bool = True,
        return_trace: bool = False,
    ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
        steps = steps or self.cfg.recurse_steps
        x = self._embed(input_ids)
        bsz, seq_len, d_model = x.shape
        states = [x.new_zeros((bsz, seq_len, d_model)) for _ in range(self.cfg.n_layers)]
        mask = self._causal_mask(seq_len, x.device)
        logits_trace: List[torch.Tensor] = []
        critique_trace: List[torch.Tensor] = []
        moe_aux_total = x.new_zeros(())
        layer_skip_aux_total = x.new_zeros(())
        for _ in range(steps):
            h = x
            new_states = []
            for layer, st in zip(self.layers, states):
                h, st_new, moe_aux, skip_aux, _ = layer(h, st, attn_mask=mask)
                new_states.append(st_new)
                moe_aux_total = moe_aux_total + moe_aux
                layer_skip_aux_total = layer_skip_aux_total + skip_aux
            states = new_states
            h_norm = self.final_norm(h)
            logits = self.lm_head(h_norm)
            pooled = h_norm[:, -1, :]
            critique = torch.sigmoid(self.critique_head(pooled)).squeeze(-1)
            logits_trace.append(logits)
            critique_trace.append(critique)
            x = h
            if stop_on_critique and torch.all(critique < self.cfg.critique_threshold):
                break
        denom = max(1, len(logits_trace) * len(self.layers))
        self._last_moe_aux_loss = moe_aux_total / denom
        self._last_layer_skip_aux_loss = layer_skip_aux_total / denom
        final_logits = logits_trace[-1]
        if return_trace:
            return final_logits, logits_trace, critique_trace
        return final_logits, [], critique_trace

    def training_loss(
        self,
        input_ids: torch.Tensor,
        target_ids: torch.Tensor,
        steps: Optional[int] = None,
        alpha_iterative: float = 0.30,
        beta_correction: float = 0.10,
    ) -> torch.Tensor:
        final_logits, trace, critique = self.forward_recursive(
            input_ids, steps=steps, stop_on_critique=False, return_trace=True
        )
        final_loss = F.cross_entropy(
            final_logits.view(-1, final_logits.size(-1)),
            target_ids.view(-1),
            ignore_index=-100,
        )
        if trace:
            iterative = 0.0
            for logits in trace[:-1]:
                iterative = iterative + F.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    target_ids.view(-1),
                    ignore_index=-100,
                )
            iterative = iterative / max(1, len(trace) - 1)
        else:
            iterative = final_loss.new_tensor(0.0)
        correction_bonus = 0.0
        if len(critique) > 1:
            start = critique[0].mean()
            end = critique[-1].mean()
            correction_bonus = torch.relu(end - start)
        total_loss = final_loss + alpha_iterative * iterative + beta_correction * correction_bonus
        if self.cfg.use_moe:
            total_loss = total_loss + self.cfg.moe_aux_loss_weight * self._last_moe_aux_loss
        if self.cfg.use_layer_skip:
            total_loss = total_loss + self.cfg.layer_skip_aux_weight * self._last_layer_skip_aux_loss
        return total_loss

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 64,
        temperature: float = 0.8,
        top_k: int = 50,
        steps: Optional[int] = None,
    ) -> torch.Tensor:
        self.eval()
        out = input_ids
        for _ in range(max_new_tokens):
            context = out[:, -self.cfg.max_seq_len :]
            logits, _, _ = self.forward_recursive(
                context, steps=steps, stop_on_critique=True, return_trace=False
            )
            next_logits = logits[:, -1, :] / max(temperature, 1e-5)
            if top_k > 0:
                values, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
                cutoff = values[:, [-1]]
                next_logits = torch.where(
                    next_logits < cutoff, torch.full_like(next_logits, -1e9), next_logits
                )
            probs = F.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            out = torch.cat([out, next_token], dim=1)
        return out

    def generate_text(
        self,
        tokenizer: TextTokenizer,
        prompt: str,
        max_new_tokens: int = 128,
        temperature: float = 0.7,
        top_k: int = 50,
        steps: Optional[int] = None,
        device: Optional[torch.device] = None,
    ) -> str:
        device = device or next(self.parameters()).device
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
        output_ids = self.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            steps=steps,
        )
        new_tokens = output_ids[0, input_ids.shape[1] :].tolist()
        return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    def chat(
        self,
        tokenizer: TextTokenizer,
        history: List[ChatTurn],
        user_message: str,
        lang: str = "auto",
        max_new_tokens: int = 192,
        temperature: float = 0.7,
        top_k: int = 50,
        steps: Optional[int] = None,
        device: Optional[torch.device] = None,
    ) -> Tuple[str, List[ChatTurn]]:
        prompt = build_chat_prompt(history, user_message, lang=lang)
        assistant_reply = self.generate_text(
            tokenizer=tokenizer,
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            steps=steps,
            device=device,
        )
        updated = history + [
            ChatTurn(role="user", content=user_message),
            ChatTurn(role="assistant", content=assistant_reply),
        ]
        return assistant_reply, updated

    def outer_sleep_phase_step(
        self,
        optimizer: torch.optim.Optimizer,
        input_ids: torch.Tensor,
        target_ids: torch.Tensor,
        steps: Optional[int] = None,
    ) -> float:
        self.train()
        optimizer.zero_grad(set_to_none=True)
        loss = self.training_loss(input_ids, target_ids, steps=steps)
        loss.backward()
        nn.utils.clip_grad_norm_(self.parameters(), 1.0)
        optimizer.step()
        return float(loss.detach().item())

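# Dual-loop learning in this prototype: the inner loop is the recursive refinement
# inside forward_recursive (no weight updates), and the outer "sleep phase"
# consolidates with gradient steps via outer_sleep_phase_step. A minimal toy sketch
# of that outer loop on random token ids; the tiny config, batch shape, and AdamW
# choice are illustrative assumptions, not part of the original training recipe.
def _toy_sleep_phase_demo(num_outer_steps: int = 2) -> None:
    cfg = RLMConfig(
        vocab_size=512, max_seq_len=64, d_model=128, n_layers=2, n_heads=4, recurse_steps=2
    )
    model = RubiRLM(cfg)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    for _ in range(num_outer_steps):
        input_ids = torch.randint(0, cfg.vocab_size, (2, 16))
        target_ids = torch.randint(0, cfg.vocab_size, (2, 16))
        loss = model.outer_sleep_phase_step(optimizer, input_ids, target_ids, steps=2)
        print(f"sleep_phase_loss={loss:.4f}")
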
def estimate_parameters(cfg: RLMConfig) -> int:
    d = cfg.d_model
    total = cfg.vocab_size * d + cfg.max_seq_len * d
    attn_params = (4 * d * d) + (4 * d)
    state_params = (5 * d * d) + (3 * d)
    router_params = 0
    layer_skip_params = 0
    ff_params = (2 * d * d * cfg.ff_mult) + (d * cfg.ff_mult) + d
    if cfg.use_moe:
        router_params = (d * cfg.moe_num_experts) + cfg.moe_num_experts
        expert_params = (2 * d * cfg.moe_expert_hidden) + cfg.moe_expert_hidden + d
        ff_params = cfg.moe_num_experts * expert_params
    if cfg.use_layer_skip:
        layer_skip_params = d + 1
    per_layer = attn_params + state_params + router_params + layer_skip_params + ff_params + (2 * d)
    total += cfg.n_layers * per_layer
    total += d * (d // 2) + (d // 2) + (d // 2) + 1 + d
    if not cfg.tie_embeddings:
        total += d * cfg.vocab_size
    return total


def estimate_active_parameters(cfg: RLMConfig) -> int:
    d = cfg.d_model
    total = cfg.vocab_size * d + cfg.max_seq_len * d
    attn_params = (4 * d * d) + (4 * d)
    state_params = (5 * d * d) + (3 * d)
    router_params = 0
    layer_skip_params = 0
    ff_params = (2 * d * d * cfg.ff_mult) + (d * cfg.ff_mult) + d
    if cfg.use_moe:
        router_params = (d * cfg.moe_num_experts) + cfg.moe_num_experts
        expert_params = (2 * d * cfg.moe_expert_hidden) + cfg.moe_expert_hidden + d
        ff_params = cfg.moe_top_k * expert_params
    if cfg.use_layer_skip:
        layer_skip_params = d + 1
    routed_layer = attn_params + state_params + router_params + ff_params + (2 * d)
    routed_layer = cfg.layer_skip_target * routed_layer
    per_layer = layer_skip_params + routed_layer
    total += cfg.n_layers * per_layer
    total += d * (d // 2) + (d // 2) + (d // 2) + 1 + d
    if not cfg.tie_embeddings:
        total += d * cfg.vocab_size
    return int(total)


def language_system_prompt(lang: str) -> str:
    base = (
        "You are Rubi-RLM assistant. Reason step-by-step internally, be concise in final answer, "
        "self-correct if needed."
    )
    if lang == "tr":
        return base + " Yanıtlarını Türkçe ver."
    if lang == "en":
        return base + " Reply in English."
    return base + " Reply in the user's language (Turkish or English)."


def build_chat_prompt(history: List[ChatTurn], user_message: str, lang: str = "auto") -> str:
    lines = [f"<|system|>\n{language_system_prompt(lang)}"]
    for turn in history:
        role = "user" if turn.role.lower() == "user" else "assistant"
        lines.append(f"<|{role}|>\n{turn.content}")
    lines.append(f"<|user|>\n{user_message}")
    lines.append("<|assistant|>\n")
    return "\n".join(lines)

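# A short sketch that makes the chat template concrete: build_chat_prompt emits
# <|system|>/<|user|>/<|assistant|> tagged turns and ends with an open assistant
# tag for the model to complete. The example turns and question are invented for
# illustration only.
def _example_chat_prompt() -> str:
    history = [
        ChatTurn(role="user", content="Hi there"),
        ChatTurn(role="assistant", content="Hello! How can I help?"),
    ]
    # Produces roughly:
    #   <|system|>\n<system prompt>
    #   <|user|>\nHi there
    #   <|assistant|>\nHello! How can I help?
    #   <|user|>\nWhat is a recursive language model?
    #   <|assistant|>\n
    return build_chat_prompt(history, "What is a recursive language model?", lang="auto")
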
def load_hf_tokenizer(tokenizer_name: str):
    if importlib.util.find_spec("transformers") is None:
        raise RuntimeError("transformers is not installed. Install it with `pip install transformers`.")
    transformers = importlib.import_module("transformers")
    tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name)
    if getattr(tokenizer, "pad_token_id", None) is None and getattr(tokenizer, "eos_token", None) is not None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def demo() -> None:
    cfg = RLMConfig(
        vocab_size=4096,
        max_seq_len=128,
        d_model=256,
        n_layers=4,
        n_heads=8,
        ff_mult=4,
        recurse_steps=4,
        use_moe=True,
        moe_num_experts=8,
        moe_top_k=2,
        moe_expert_hidden=384,
    )
    model = RubiRLM(cfg)
    x = torch.randint(0, cfg.vocab_size, (2, 32))
    y = torch.randint(0, cfg.vocab_size, (2, 32))
    loss = model.training_loss(x, y)
    print(f"demo_loss={loss.item():.4f}")
    out = model.generate(x[:, :8], max_new_tokens=8, steps=3)
    print("generated_shape=", tuple(out.shape))


def resolve_config(scale: str) -> RLMConfig:
    if scale == "1b":
        return RLMConfig.scale_1b()
    return RLMConfig(d_model=512, n_layers=8, n_heads=8, vocab_size=50_257, max_seq_len=512)


def runtime_torch_compile_available() -> bool:
    if not hasattr(torch, "compile"):
        return False
    if torch.cuda.is_available() and importlib.util.find_spec("triton") is None:
        return False
    return True


def apply_runtime_config_overrides(cfg: RLMConfig, args: argparse.Namespace) -> RLMConfig:
    cfg.moe_backend = getattr(args, "moe_backend", cfg.moe_backend)
    cfg.moe_ep_size = getattr(args, "moe_ep_size", cfg.moe_ep_size)
    requested_compile = bool(getattr(args, "use_torch_compile", cfg.use_torch_compile))
    cfg.use_torch_compile = requested_compile and runtime_torch_compile_available()
    return cfg


def maybe_load_checkpoint(model: RubiRLM, checkpoint: Optional[str], device: torch.device) -> None:
    if not checkpoint:
        return
    state = torch.load(checkpoint, map_location=device)
    if isinstance(state, dict) and "model_state_dict" in state:
        model.load_state_dict(state["model_state_dict"])
        return
    model.load_state_dict(state)

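# maybe_load_checkpoint accepts either a bare state_dict or a dict carrying a
# "model_state_dict" key. A minimal sketch of writing a compatible checkpoint;
# the extra "config" entry is an assumption kept for convenience, not something
# the loader requires or reads.
def _save_checkpoint_sketch(model: RubiRLM, path: str) -> None:
    torch.save(
        {
            "model_state_dict": model.state_dict(),
            "config": vars(model.cfg),  # ignored by maybe_load_checkpoint
        },
        path,
    )
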
def run_single_chat(args: argparse.Namespace) -> None:
    cfg = apply_runtime_config_overrides(resolve_config(args.scale), args)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = RubiRLM(cfg).to(device)
    maybe_load_checkpoint(model, args.checkpoint, device)
    tokenizer = load_hf_tokenizer(args.tokenizer_name)
    history: List[ChatTurn] = []
    if args.interactive:
        print("Interactive chat started. Type /exit to quit.")
        while True:
            user_msg = input("You> ").strip()
            if not user_msg:
                continue
            if user_msg.lower() in {"/exit", "exit", "quit"}:
                break
            reply, history = model.chat(
                tokenizer=tokenizer,
                history=history,
                user_message=user_msg,
                lang=args.lang,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_k=args.top_k,
                steps=args.steps,
                device=device,
            )
            print(f"Rubi> {reply}")
        return
    if not args.prompt:
        raise ValueError("--chat mode requires --prompt or --interactive.")
    reply, _ = model.chat(
        tokenizer=tokenizer,
        history=[],
        user_message=args.prompt,
        lang=args.lang,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        steps=args.steps,
        device=device,
    )
    print(reply)


def print_stack_report() -> None:
    report = detect_xqs_backends()
    print(format_backend_report(report))


def run_train_demo(args: argparse.Namespace) -> None:
    cfg = apply_runtime_config_overrides(resolve_config(args.scale), args)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = RubiRLM(cfg).to(device)
    maybe_load_checkpoint(model, args.checkpoint, device)
    train_cfg = TrainStackConfig(
        optimizer_name=args.optimizer_name,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=not args.disable_pin_memory,
        prefetch_factor=args.prefetch_factor,
        persistent_workers=not args.disable_persistent_workers,
        max_seq_len=cfg.max_seq_len,
        dataset_dir=args.dataset_dir,
        use_bf16=not args.disable_bf16,
    )
    dataset = build_dataset(
        dataset_dir=train_cfg.dataset_dir,
        vocab_size=cfg.vocab_size,
        max_seq_len=min(cfg.max_seq_len, args.train_seq_len),
        synthetic_samples=max(args.train_steps * args.batch_size * 2, 32),
    )
    dataloader = build_dataloader(dataset, train_cfg, shuffle=True)
    optimizer = build_optimizer(model, train_cfg)
    mean_loss, total_tokens = train_demo_steps(
        model=model,
        optimizer=optimizer,
        dataloader=dataloader,
        device=device,
        steps=args.train_steps,
        use_bf16=train_cfg.use_bf16,
    )
    print(
        f"train_demo optimizer={optimizer.__class__.__name__} steps={args.train_steps} "
        f"mean_loss={mean_loss:.4f} tokens={total_tokens:,} device={device}"
    )

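# run_single_chat and run_train_demo only read plain attributes off the parsed
# args, so they can also be driven without the CLI. A minimal sketch using an
# argparse.Namespace whose values are assumptions copied from the parser defaults
# in main() below.
def _programmatic_train_demo_sketch() -> None:
    args = argparse.Namespace(
        scale="tiny",
        checkpoint=None,
        moe_backend="auto",
        moe_ep_size=1,
        use_torch_compile=False,
        optimizer_name="auto",
        learning_rate=3e-4,
        weight_decay=0.01,
        batch_size=2,
        num_workers=2,
        prefetch_factor=4,
        dataset_dir="",
        train_steps=2,
        train_seq_len=256,
        disable_pin_memory=False,
        disable_persistent_workers=False,
        disable_bf16=False,
    )
    run_train_demo(args)
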
parser.add_argument("--learning-rate", type=float, default=3e-4) parser.add_argument("--weight-decay", type=float, default=0.01) parser.add_argument("--batch-size", type=int, default=2) parser.add_argument("--num-workers", type=int, default=2) parser.add_argument("--prefetch-factor", type=int, default=4) parser.add_argument("--dataset-dir", type=str, default="") parser.add_argument("--train-steps", type=int, default=2) parser.add_argument("--train-seq-len", type=int, default=256) parser.add_argument("--disable-pin-memory", action="store_true") parser.add_argument("--disable-persistent-workers", action="store_true") parser.add_argument("--disable-bf16", action="store_true") args = parser.parse_args() if args.chat: run_single_chat(args) return if args.stack_report: print_stack_report() return if args.train_demo: run_train_demo(args) return if args.demo: demo() return cfg = apply_runtime_config_overrides(resolve_config(args.scale), args) n_params = estimate_parameters(cfg) active_params = estimate_active_parameters(cfg) print(f"Scale={args.scale}, estimated_params={n_params:,}, estimated_active_params={active_params:,}") if not args.estimate_only: model = RubiRLM(cfg) actual = sum(p.numel() for p in model.parameters()) print(f"actual_params={actual:,}") if __name__ == "__main__": main()