DevHunterAI committed
Commit cd16f07 · verified · Parent: ee11e73

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+architecture.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,91 @@
---
language:
- en
- tr
license: apache-2.0
library_name: transformers
tags:
- rubirlm
- causal-lm
- base-model
- text-generation
- 1b
- moe
datasets:
- FineWeb
- UltraChat-200k
pipeline_tag: text-generation
---

# RubiRLM-1B-Base

**RubiRLM-1B-Base** is a **1B-parameter base language model** released by **DevHunterAI**.

**Model size:** 1B parameters

**Training datasets:** FineWeb, UltraChat-200k

**Model type:** Base / pretrained language model

**Important:** This release is a **base model**. It can be used for prompt-based generation and experimental chat-style interaction, but it is **not an instruction-tuned chat assistant**.

## Architecture

![RubiRLM 1B Architecture](architecture.png)

**RubiRLM 1B** uses a recursive language modeling architecture with recurrent state flow, Mixture-of-Experts routing, and conditional block execution.

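The recurrent state flow can be illustrated with a minimal, framework-free sketch: each recursive step re-runs the block stack on the current hidden representation while every block carries its own state forward. The names below (`fuse`, `recursive_forward`) and the fixed 0.5 gate are illustrative assumptions, not the repository's API, which uses attention, MoE feed-forwards, and learned gating.

```python
# Illustrative sketch of recursive state flow (not the repository's code).
def fuse(h, state, gate=0.5):
    # Gated blend of the incoming representation and the carried state.
    return [gate * a + (1.0 - gate) * b for a, b in zip(h, state)]

def recursive_forward(x, n_blocks=3, steps=2):
    # One state vector per block, initialized to zeros.
    states = [[0.0] * len(x) for _ in range(n_blocks)]
    h = x
    for _ in range(steps):          # outer recursive reasoning loop
        new_states = []
        for state in states:        # inner pass over the block stack
            h = fuse(h, state)      # block output fused with its carried state
            new_states.append(h)    # each block updates its own state
        states = new_states         # states persist across recursive steps
    return h

out = recursive_forward([1.0, 2.0], n_blocks=3, steps=2)
```

On the second recursive step the states saved during the first step feed back into the computation, which is the core difference from a plain feed-forward stack.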
## Key Features

- **1B parameters**
- **Recursive Language Model (RLM)** architecture
- **10 recursive blocks**
- **d_model = 1024**
- **16 attention heads**
- **max sequence length = 2048**
- **6 recursive reasoning steps**
- **Mixture-of-Experts: 32 experts, top-1 routing**
- **Layer skip router for conditional execution**
- **Packed execution support**
- **Tied token embedding and LM head**

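The top-1 routing listed above can be sketched without any framework: each token goes to the single expert with the highest router score. The helper names (`top1_route`, `dispatch`) and the toy parity-based scorer are illustrative assumptions; the real model scores tokens with a learned linear router over hidden states.

```python
# Illustrative top-1 MoE routing (toy scorer, not the learned router).
def top1_route(router_scores):
    """router_scores: per-expert scores for one token -> chosen expert index."""
    return max(range(len(router_scores)), key=lambda e: router_scores[e])

def dispatch(tokens, score_fn, num_experts):
    # Group token positions by their selected expert.
    buckets = {e: [] for e in range(num_experts)}
    for i, tok in enumerate(tokens):
        buckets[top1_route(score_fn(tok))].append(i)
    return buckets

# Toy scorer: route by token parity (stand-in for a learned router).
buckets = dispatch([3, 8, 5, 2], lambda t: [t % 2, 1 - t % 2], num_experts=2)
# Odd tokens (positions 0 and 2) land on expert 0, even ones on expert 1.
```

With 32 experts and top-1 routing, only one expert's feed-forward runs per token, which is why the active parameter count is far below the total.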
## Training Data

This model was trained using a mixture of:

- **FineWeb**
- **UltraChat-200k**

## Intended Usage

This model is intended for:

- base language modeling research
- continued pretraining
- experimental prompt-based generation
- architecture experimentation around recursive and MoE-based language models

## Not Intended As

This release should **not** be treated as:

- a fully aligned assistant
- a safety-tuned production chatbot
- an instruction-following model with guaranteed conversational quality

## Loading

Because this repository includes custom model code, loading may require `trust_remote_code=True` depending on your workflow.

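A minimal loading sketch follows. The repo id `DevHunterAI/RubiRLM-1B-Base` is an assumption inferred from the author and model names (verify against the actual repository), and the helper only builds keyword arguments so nothing is downloaded until you uncomment the calls.

```python
# Hedged loading sketch; the repo id below is an assumption, not confirmed by
# this README. trust_remote_code=True lets transformers execute RubiRLM.py.
import importlib.util

def loading_kwargs(revision=None):
    """Keyword arguments typically needed for repos that ship custom code."""
    kwargs = {"trust_remote_code": True}
    if revision is not None:
        kwargs["revision"] = revision  # pin a commit hash for reproducibility
    return kwargs

if importlib.util.find_spec("transformers") is not None:
    from transformers import AutoModelForCausalLM, AutoTokenizer
    # model = AutoModelForCausalLM.from_pretrained("DevHunterAI/RubiRLM-1B-Base", **loading_kwargs())
    # tokenizer = AutoTokenizer.from_pretrained("DevHunterAI/RubiRLM-1B-Base")
```

Pinning `revision` to a specific commit (such as the one this page documents) guards against the custom code changing underneath you.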
## Files

- `pytorch_model.bin`: exported RubiRLM weights
- `training_checkpoint.pt`: original training checkpoint
- `config.json`: Hugging Face-facing config
- `rubirlm_config.json`: full RubiRLM architecture config
- `RubiRLM.py`: model implementation
- `xqs_moe.py`, `xqs_stack.py`, `x_quantum_sparse_ops.py`, `rubi_train_stack.py`: supporting code

## Notes

The exported weights were produced from the final training checkpoint and packaged for Hugging Face publication.
RubiRLM.py ADDED
@@ -0,0 +1,832 @@
"""Rubi-RLM: 1B-class Recursive Language Model (RLM) prototype.

This file contains a research prototype designed around recursive thinking
plus dual-loop learning.

Added chat layer:
- bilingual English/Turkish chat template
- text -> id / id -> text bridge via an HF tokenizer
- single-message or interactive chat CLI
"""

from __future__ import annotations

import argparse
import importlib
import importlib.util
from dataclasses import dataclass
from typing import List, Optional, Protocol, Sequence, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from rubi_train_stack import (
    TrainStackConfig,
    build_dataloader,
    build_dataset,
    build_optimizer,
    train_demo_steps,
)
from xqs_moe import build_deepspeed_moe
from xqs_stack import choose_moe_backend, detect_xqs_backends, format_backend_report
from x_quantum_sparse_ops import (
    build_linear,
    causal_scaled_dot_product_attention,
    fused_residual_add,
    maybe_compile_module,
    pack_rows,
    scatter_rows,
)


class TextTokenizer(Protocol):
    def encode(self, text: str, return_tensors: Optional[str] = None): ...

    def decode(self, token_ids: Sequence[int], skip_special_tokens: bool = True) -> str: ...


@dataclass
class ChatTurn:
    role: str
    content: str


@dataclass
class RLMConfig:
    vocab_size: int = 50_257
    max_seq_len: int = 2_048
    d_model: int = 2_048
    n_layers: int = 14
    n_heads: int = 16
    ff_mult: int = 4
    dropout: float = 0.1
    recurse_steps: int = 6
    critique_threshold: float = 0.20
    tie_embeddings: bool = True
    use_moe: bool = False
    moe_num_experts: int = 0
    moe_top_k: int = 2
    moe_expert_hidden: int = 0
    moe_router_jitter: float = 0.0
    moe_aux_loss_weight: float = 0.01
    use_layer_skip: bool = False
    layer_skip_threshold: float = 0.50
    layer_skip_target: float = 1.0
    layer_skip_aux_weight: float = 0.01
    use_ternary_weights: bool = False
    use_flash_attention: bool = False
    use_fused_ops: bool = False
    packed_execution: bool = False
    use_torch_compile: bool = False
    moe_backend: str = "auto"
    moe_ep_size: int = 1

    @classmethod
    def scale_1b(cls) -> "RLMConfig":
        return cls(
            vocab_size=50_257,
            max_seq_len=2_048,
            d_model=1_024,
            n_layers=10,
            n_heads=16,
            ff_mult=4,
            recurse_steps=6,
            critique_threshold=0.20,
            use_moe=True,
            moe_num_experts=32,
            moe_top_k=1,
            moe_expert_hidden=1_280,
            moe_router_jitter=0.01,
            moe_aux_loss_weight=0.01,
            use_layer_skip=True,
            layer_skip_threshold=0.80,
            layer_skip_target=0.03,
            layer_skip_aux_weight=0.01,
            use_ternary_weights=True,
            use_flash_attention=True,
            use_fused_ops=True,
            packed_execution=True,
            use_torch_compile=False,
            moe_backend="auto",
            moe_ep_size=1,
        )


class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        return self.scale * (x / rms)


class DenseFeedForward(nn.Module):
    def __init__(self, cfg: RLMConfig):
        super().__init__()
        hidden = cfg.d_model * cfg.ff_mult
        self.up_proj = build_linear(cfg.d_model, hidden, ternary=cfg.use_ternary_weights)
        self.down_proj = build_linear(hidden, cfg.d_model, ternary=cfg.use_ternary_weights)
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Returns (output, zero aux loss) to match the MoE feed-forward interface.
        return self.dropout(self.down_proj(F.gelu(self.up_proj(x)))), x.new_zeros(())


class FastSelfAttention(nn.Module):
    def __init__(self, cfg: RLMConfig):
        super().__init__()
        if cfg.d_model % cfg.n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads.")
        self.n_heads = cfg.n_heads
        self.head_dim = cfg.d_model // cfg.n_heads
        self.dropout = cfg.dropout
        self.use_flash_attention = cfg.use_flash_attention
        self.q_proj = build_linear(cfg.d_model, cfg.d_model, bias=False, ternary=cfg.use_ternary_weights)
        self.k_proj = build_linear(cfg.d_model, cfg.d_model, bias=False, ternary=cfg.use_ternary_weights)
        self.v_proj = build_linear(cfg.d_model, cfg.d_model, bias=False, ternary=cfg.use_ternary_weights)
        self.out_proj = build_linear(cfg.d_model, cfg.d_model, bias=False, ternary=cfg.use_ternary_weights)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        bsz, seq_len, _ = x.shape
        q = self.q_proj(x).view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        attn_out = causal_scaled_dot_product_attention(
            q,
            k,
            v,
            dropout_p=self.dropout,
            training=self.training,
        )
        attn_out = attn_out.transpose(1, 2).contiguous().view(bsz, seq_len, self.n_heads * self.head_dim)
        return self.out_proj(attn_out)


class MoEExpert(nn.Module):
    def __init__(self, d_model: int, hidden: int):
        super().__init__()
        self.up_proj = build_linear(d_model, hidden, ternary=True)
        self.down_proj = build_linear(hidden, d_model, ternary=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(F.gelu(self.up_proj(x)))


class MoEFeedForward(nn.Module):
    def __init__(self, cfg: RLMConfig):
        super().__init__()
        if cfg.moe_num_experts <= 0:
            raise ValueError("moe_num_experts must be positive when use_moe=True.")
        if cfg.moe_top_k <= 0 or cfg.moe_top_k > cfg.moe_num_experts:
            raise ValueError("moe_top_k must be in the range [1, moe_num_experts].")

        self.num_experts = cfg.moe_num_experts
        self.top_k = cfg.moe_top_k
        self.router_jitter = cfg.moe_router_jitter
        requested_backend = cfg.moe_backend.lower()
        if requested_backend == "native":
            self.backend = "native"
        else:
            self.backend = choose_moe_backend(prefer_deepspeed=requested_backend in {"auto", "deepspeed"})
        self.router = build_linear(cfg.d_model, cfg.moe_num_experts, ternary=cfg.use_ternary_weights)
        self.experts = nn.ModuleList(
            [MoEExpert(cfg.d_model, cfg.moe_expert_hidden) for _ in range(cfg.moe_num_experts)]
        )
        self.deepspeed_moe = None
        if self.backend == "deepspeed":
            self.deepspeed_moe = build_deepspeed_moe(
                hidden_size=cfg.d_model,
                expert=MoEExpert(cfg.d_model, cfg.moe_expert_hidden),
                num_experts=cfg.moe_num_experts,
                top_k=cfg.moe_top_k,
                ep_size=cfg.moe_ep_size,
            )
        if self.deepspeed_moe is None:
            self.backend = "native"  # fall back when DeepSpeed MoE is unavailable
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.deepspeed_moe is not None:
            out, aux_loss = self.deepspeed_moe(x)
            return self.dropout(out), aux_loss
        flat_x = x.reshape(-1, x.size(-1))
        router_logits = self.router(flat_x)
        if self.training and self.router_jitter > 0:
            router_logits = router_logits + torch.randn_like(router_logits) * self.router_jitter

        router_probs = F.softmax(router_logits, dim=-1)
        topk_weights, topk_indices = torch.topk(router_probs, self.top_k, dim=-1)
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

        mixed = flat_x.new_zeros(flat_x.shape)
        expert_load = router_probs.new_zeros(self.num_experts)

        for expert_id, expert in enumerate(self.experts):
            expert_mask = topk_indices == expert_id
            if not expert_mask.any():
                continue
            token_indices, slot_indices = expert_mask.nonzero(as_tuple=True)
            expert_inputs = flat_x.index_select(0, token_indices)
            expert_outputs = expert(expert_inputs)
            weights = topk_weights[token_indices, slot_indices].unsqueeze(-1)
            mixed.index_add_(0, token_indices, expert_outputs * weights)
            expert_load[expert_id] = float(token_indices.numel())

        mixed = self.dropout(mixed.view_as(x))
        # Load-balancing auxiliary loss: mean router probability ("importance")
        # times realized token load, summed over experts.
        importance = router_probs.mean(dim=0)
        load = expert_load / max(1, flat_x.size(0) * self.top_k)
        aux_loss = self.num_experts * torch.sum(importance * load)
        return mixed, aux_loss


class RecursiveBlock(nn.Module):
    def __init__(self, cfg: RLMConfig):
        super().__init__()

        self.use_layer_skip = cfg.use_layer_skip
        self.layer_skip_threshold = cfg.layer_skip_threshold
        self.layer_skip_target = cfg.layer_skip_target
        self.use_fused_ops = cfg.use_fused_ops
        self.packed_execution = cfg.packed_execution
        self.norm_attn = RMSNorm(cfg.d_model)
        self.norm_ff = RMSNorm(cfg.d_model)
        self.attn = FastSelfAttention(cfg)
        self.ffn = MoEFeedForward(cfg) if cfg.use_moe else DenseFeedForward(cfg)
        self.skip_router = build_linear(cfg.d_model, 1, ternary=cfg.use_ternary_weights) if cfg.use_layer_skip else None

        self.state_fuse = build_linear(cfg.d_model * 2, cfg.d_model, ternary=cfg.use_ternary_weights)
        self.state_update = build_linear(cfg.d_model, cfg.d_model, ternary=cfg.use_ternary_weights)
        self.state_gate = build_linear(cfg.d_model * 2, cfg.d_model, ternary=cfg.use_ternary_weights)

    def _run_core(
        self,
        x: torch.Tensor,
        state: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        x_norm = self.norm_attn(x)
        attn_out = self.attn(x_norm, attn_mask=attn_mask)
        fuse_input = torch.cat([attn_out, state], dim=-1)
        gate = torch.sigmoid(self.state_gate(fuse_input))
        fused = self.state_fuse(fuse_input)
        fused = gate * fused + (1.0 - gate) * state
        if self.use_fused_ops:
            x = fused_residual_add(x, fused)
        else:
            x = x + fused
        ff_out, moe_aux_loss = self.ffn(self.norm_ff(x))
        if self.use_fused_ops:
            x = fused_residual_add(x, ff_out)
        else:
            x = x + ff_out
        new_state = torch.tanh(self.state_update(x))
        return x, new_state, moe_aux_loss

    def forward(
        self,
        x: torch.Tensor,
        state: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        exec_prob = x.new_ones((x.size(0),))
        skip_aux_loss = x.new_zeros(())
        if self.skip_router is None:
            x, new_state, moe_aux_loss = self._run_core(x, state, attn_mask=attn_mask)
            return x, new_state, moe_aux_loss, skip_aux_loss, exec_prob.mean()

        router_input = x.mean(dim=1)
        exec_prob = torch.sigmoid(self.skip_router(router_input)).squeeze(-1)
        target = exec_prob.new_full(exec_prob.shape, self.layer_skip_target)
        skip_aux_loss = F.mse_loss(exec_prob, target)
        hard_gate = exec_prob >= self.layer_skip_threshold
        if not torch.any(hard_gate):
            return x, state, x.new_zeros(()), skip_aux_loss, exec_prob.mean()

        if torch.all(hard_gate):
            x_exec, state_exec, moe_aux_loss = self._run_core(x, state, attn_mask=attn_mask)
        elif self.packed_execution:
            # Run the core only on the batch rows whose gate fired.
            active_indices = torch.nonzero(hard_gate, as_tuple=False).squeeze(-1)
            x_active, state_active = pack_rows(active_indices, x, state)
            x_active, state_active, moe_aux_loss = self._run_core(x_active, state_active, attn_mask=attn_mask)
            x_exec = scatter_rows(x, active_indices, x_active)
            state_exec = scatter_rows(state, active_indices, state_active)
        else:
            x_exec, state_exec, moe_aux_loss = self._run_core(x, state, attn_mask=attn_mask)

        if self.training:
            # Straight-through estimator: hard 0/1 gate in the forward pass,
            # gradient of the soft probability in the backward pass.
            exec_gate = exec_prob + (hard_gate.to(exec_prob.dtype) - exec_prob).detach()
            exec_scale = exec_gate.view(-1, 1, 1)
            x_exec = x + exec_scale * (x_exec - x)
            state_exec = state + exec_scale * (state_exec - state)

        return x_exec, state_exec, moe_aux_loss, skip_aux_loss, exec_prob.mean()


class RubiRLM(nn.Module):
    def __init__(self, cfg: RLMConfig):
        super().__init__()
        self.cfg = cfg
        self._last_moe_aux_loss = torch.tensor(0.0)
        self._last_layer_skip_aux_loss = torch.tensor(0.0)

        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.pos_emb = nn.Embedding(cfg.max_seq_len, cfg.d_model)
        self.drop = nn.Dropout(cfg.dropout)

        self.layers = nn.ModuleList(
            [maybe_compile_module(RecursiveBlock(cfg), cfg.use_torch_compile) for _ in range(cfg.n_layers)]
        )
        self.final_norm = RMSNorm(cfg.d_model)

        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        if cfg.tie_embeddings:
            self.lm_head.weight = self.tok_emb.weight

        self.critique_head = nn.Sequential(
            nn.Linear(cfg.d_model, cfg.d_model // 2),
            nn.GELU(),
            nn.Linear(cfg.d_model // 2, 1),
        )

    def _causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        mask = torch.full((seq_len, seq_len), float("-inf"), device=device)
        return torch.triu(mask, diagonal=1)

    def _embed(self, input_ids: torch.Tensor) -> torch.Tensor:
        bsz, seq_len = input_ids.shape
        if seq_len > self.cfg.max_seq_len:
            raise ValueError(f"Input length exceeds the max_seq_len={self.cfg.max_seq_len} limit.")
        pos = torch.arange(seq_len, device=input_ids.device).unsqueeze(0).expand(bsz, seq_len)
        return self.drop(self.tok_emb(input_ids) + self.pos_emb(pos))

    def forward_recursive(
        self,
        input_ids: torch.Tensor,
        steps: Optional[int] = None,
        stop_on_critique: bool = True,
        return_trace: bool = False,
    ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
        steps = steps or self.cfg.recurse_steps
        x = self._embed(input_ids)

        bsz, seq_len, d_model = x.shape
        states = [x.new_zeros((bsz, seq_len, d_model)) for _ in range(self.cfg.n_layers)]
        mask = self._causal_mask(seq_len, x.device)

        logits_trace: List[torch.Tensor] = []
        critique_trace: List[torch.Tensor] = []
        moe_aux_total = x.new_zeros(())
        layer_skip_aux_total = x.new_zeros(())

        for _ in range(steps):
            h = x
            new_states = []
            for layer, st in zip(self.layers, states):
                h, st_new, moe_aux, skip_aux, _ = layer(h, st, attn_mask=mask)
                new_states.append(st_new)
                moe_aux_total = moe_aux_total + moe_aux
                layer_skip_aux_total = layer_skip_aux_total + skip_aux
            states = new_states

            h_norm = self.final_norm(h)
            logits = self.lm_head(h_norm)
            pooled = h_norm[:, -1, :]
            critique = torch.sigmoid(self.critique_head(pooled)).squeeze(-1)

            logits_trace.append(logits)
            critique_trace.append(critique)
            x = h

            # Early exit once the critique head is confident on every batch row.
            if stop_on_critique and torch.all(critique < self.cfg.critique_threshold):
                break

        denom = max(1, len(logits_trace) * len(self.layers))
        self._last_moe_aux_loss = moe_aux_total / denom
        self._last_layer_skip_aux_loss = layer_skip_aux_total / denom

        final_logits = logits_trace[-1]
        if return_trace:
            return final_logits, logits_trace, critique_trace
        return final_logits, [], critique_trace

    def training_loss(
        self,
        input_ids: torch.Tensor,
        target_ids: torch.Tensor,
        steps: Optional[int] = None,
        alpha_iterative: float = 0.30,
        beta_correction: float = 0.10,
    ) -> torch.Tensor:
        final_logits, trace, critique = self.forward_recursive(
            input_ids, steps=steps, stop_on_critique=False, return_trace=True
        )

        final_loss = F.cross_entropy(
            final_logits.view(-1, final_logits.size(-1)),
            target_ids.view(-1),
            ignore_index=-100,
        )

        if trace:
            iterative = 0.0
            for logits in trace[:-1]:
                iterative = iterative + F.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    target_ids.view(-1),
                    ignore_index=-100,
                )
            iterative = iterative / max(1, len(trace) - 1)
        else:
            iterative = final_loss.new_tensor(0.0)

        correction_bonus = 0.0
        if len(critique) > 1:
            start = critique[0].mean()
            end = critique[-1].mean()
            # Penalize runs where the critique score rises across recursive steps.
            correction_bonus = torch.relu(end - start)

        total_loss = final_loss + alpha_iterative * iterative + beta_correction * correction_bonus
        if self.cfg.use_moe:
            total_loss = total_loss + self.cfg.moe_aux_loss_weight * self._last_moe_aux_loss
        if self.cfg.use_layer_skip:
            total_loss = total_loss + self.cfg.layer_skip_aux_weight * self._last_layer_skip_aux_loss
        return total_loss

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 64,
        temperature: float = 0.8,
        top_k: int = 50,
        steps: Optional[int] = None,
    ) -> torch.Tensor:
        self.eval()
        out = input_ids

        for _ in range(max_new_tokens):
            context = out[:, -self.cfg.max_seq_len :]
            logits, _, _ = self.forward_recursive(context, steps=steps, stop_on_critique=True, return_trace=False)
            next_logits = logits[:, -1, :] / max(temperature, 1e-5)

            if top_k > 0:
                values, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
                cutoff = values[:, [-1]]
                next_logits = torch.where(next_logits < cutoff, torch.full_like(next_logits, -1e9), next_logits)

            probs = F.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            out = torch.cat([out, next_token], dim=1)

        return out

    def generate_text(
        self,
        tokenizer: TextTokenizer,
        prompt: str,
        max_new_tokens: int = 128,
        temperature: float = 0.7,
        top_k: int = 50,
        steps: Optional[int] = None,
        device: Optional[torch.device] = None,
    ) -> str:
        device = device or next(self.parameters()).device
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
        output_ids = self.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            steps=steps,
        )
        new_tokens = output_ids[0, input_ids.shape[1] :].tolist()
        return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    def chat(
        self,
        tokenizer: TextTokenizer,
        history: List[ChatTurn],
        user_message: str,
        lang: str = "auto",
        max_new_tokens: int = 192,
        temperature: float = 0.7,
        top_k: int = 50,
        steps: Optional[int] = None,
        device: Optional[torch.device] = None,
    ) -> Tuple[str, List[ChatTurn]]:
        prompt = build_chat_prompt(history, user_message, lang=lang)
        assistant_reply = self.generate_text(
            tokenizer=tokenizer,
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            steps=steps,
            device=device,
        )
        updated = history + [
            ChatTurn(role="user", content=user_message),
            ChatTurn(role="assistant", content=assistant_reply),
        ]
        return assistant_reply, updated

    def outer_sleep_phase_step(
        self,
        optimizer: torch.optim.Optimizer,
        input_ids: torch.Tensor,
        target_ids: torch.Tensor,
        steps: Optional[int] = None,
    ) -> float:
        self.train()
        optimizer.zero_grad(set_to_none=True)
        loss = self.training_loss(input_ids, target_ids, steps=steps)
        loss.backward()
        nn.utils.clip_grad_norm_(self.parameters(), 1.0)
        optimizer.step()
        return float(loss.detach().item())


def estimate_parameters(cfg: RLMConfig) -> int:
    d = cfg.d_model
    total = cfg.vocab_size * d + cfg.max_seq_len * d
    attn_params = (4 * d * d) + (4 * d)
    state_params = (5 * d * d) + (3 * d)
    router_params = 0
    layer_skip_params = 0
    ff_params = (2 * d * d * cfg.ff_mult) + (d * cfg.ff_mult) + d
    if cfg.use_moe:
        router_params = (d * cfg.moe_num_experts) + cfg.moe_num_experts
        expert_params = (2 * d * cfg.moe_expert_hidden) + cfg.moe_expert_hidden + d
        ff_params = cfg.moe_num_experts * expert_params
    if cfg.use_layer_skip:
        layer_skip_params = d + 1
    per_layer = attn_params + state_params + router_params + layer_skip_params + ff_params + (2 * d)
    total += cfg.n_layers * per_layer
    total += d * (d // 2) + (d // 2) + (d // 2) + 1 + d
    if not cfg.tie_embeddings:
        total += d * cfg.vocab_size
    return total


def estimate_active_parameters(cfg: RLMConfig) -> int:
    # Like estimate_parameters, but counts only the top-k experts and scales
    # the routed layer cost by the layer-skip execution target.
    d = cfg.d_model
    total = cfg.vocab_size * d + cfg.max_seq_len * d
    attn_params = (4 * d * d) + (4 * d)
    state_params = (5 * d * d) + (3 * d)
    router_params = 0
    layer_skip_params = 0
    ff_params = (2 * d * d * cfg.ff_mult) + (d * cfg.ff_mult) + d
    if cfg.use_moe:
        router_params = (d * cfg.moe_num_experts) + cfg.moe_num_experts
        expert_params = (2 * d * cfg.moe_expert_hidden) + cfg.moe_expert_hidden + d
        ff_params = cfg.moe_top_k * expert_params
    if cfg.use_layer_skip:
        layer_skip_params = d + 1
    routed_layer = attn_params + state_params + router_params + ff_params + (2 * d)
    routed_layer = cfg.layer_skip_target * routed_layer
    per_layer = layer_skip_params + routed_layer
    total += cfg.n_layers * per_layer
    total += d * (d // 2) + (d // 2) + (d // 2) + 1 + d
    if not cfg.tie_embeddings:
        total += d * cfg.vocab_size
    return int(total)


def language_system_prompt(lang: str) -> str:
    base = (
        "You are Rubi-RLM assistant. Reason step-by-step internally, be concise in final answer, "
        "self-correct if needed."
    )
    if lang == "tr":
        return base + " Yanıtlarını Türkçe ver."
    if lang == "en":
        return base + " Reply in English."
    return base + " Reply in the user's language (Turkish or English)."


def build_chat_prompt(history: List[ChatTurn], user_message: str, lang: str = "auto") -> str:
    lines = [f"<|system|>\n{language_system_prompt(lang)}"]
    for turn in history:
        role = "user" if turn.role.lower() == "user" else "assistant"
        lines.append(f"<|{role}|>\n{turn.content}")
    lines.append(f"<|user|>\n{user_message}")
    lines.append("<|assistant|>\n")
    return "\n".join(lines)


def load_hf_tokenizer(tokenizer_name: str):
    if importlib.util.find_spec("transformers") is None:
        raise RuntimeError("transformers is not installed. Install it with `pip install transformers`.")
    transformers = importlib.import_module("transformers")
    tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name)
    if getattr(tokenizer, "pad_token_id", None) is None and getattr(tokenizer, "eos_token", None) is not None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def demo() -> None:
    cfg = RLMConfig(
        vocab_size=4096,
        max_seq_len=128,
        d_model=256,
        n_layers=4,
        n_heads=8,
        ff_mult=4,
        recurse_steps=4,
        use_moe=True,
        moe_num_experts=8,
        moe_top_k=2,
        moe_expert_hidden=384,
    )
    model = RubiRLM(cfg)
    x = torch.randint(0, cfg.vocab_size, (2, 32))
    y = torch.randint(0, cfg.vocab_size, (2, 32))

    loss = model.training_loss(x, y)
    print(f"demo_loss={loss.item():.4f}")

    out = model.generate(x[:, :8], max_new_tokens=8, steps=3)
    print("generated_shape=", tuple(out.shape))


def resolve_config(scale: str) -> RLMConfig:
    if scale == "1b":
        return RLMConfig.scale_1b()
    return RLMConfig(d_model=512, n_layers=8, n_heads=8, vocab_size=50_257, max_seq_len=512)


def runtime_torch_compile_available() -> bool:
    if not hasattr(torch, "compile"):
        return False
    if torch.cuda.is_available() and importlib.util.find_spec("triton") is None:
        return False
    return True


def apply_runtime_config_overrides(cfg: RLMConfig, args: argparse.Namespace) -> RLMConfig:
    cfg.moe_backend = getattr(args, "moe_backend", cfg.moe_backend)
    cfg.moe_ep_size = getattr(args, "moe_ep_size", cfg.moe_ep_size)
    requested_compile = bool(getattr(args, "use_torch_compile", cfg.use_torch_compile))
    cfg.use_torch_compile = requested_compile and runtime_torch_compile_available()
    return cfg


def maybe_load_checkpoint(model: RubiRLM, checkpoint: Optional[str], device: torch.device) -> None:
    if not checkpoint:
        return
    state = torch.load(checkpoint, map_location=device)
    if isinstance(state, dict) and "model_state_dict" in state:
        model.load_state_dict(state["model_state_dict"])
        return
    model.load_state_dict(state)


def run_single_chat(args: argparse.Namespace) -> None:
    cfg = apply_runtime_config_overrides(resolve_config(args.scale), args)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = RubiRLM(cfg).to(device)
    maybe_load_checkpoint(model, args.checkpoint, device)
    tokenizer = load_hf_tokenizer(args.tokenizer_name)

    history: List[ChatTurn] = []
    if args.interactive:
        print("Interactive chat started. Type /exit to quit.")
        while True:
            user_msg = input("You> ").strip()
            if not user_msg:
                continue
            if user_msg.lower() in {"/exit", "exit", "quit"}:
                break
            reply, history = model.chat(
                tokenizer=tokenizer,
                history=history,
                user_message=user_msg,
                lang=args.lang,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_k=args.top_k,
                steps=args.steps,
                device=device,
            )
            print(f"Rubi> {reply}")
        return

    if not args.prompt:
        raise ValueError("--chat mode requires --prompt or --interactive.")

    reply, _ = model.chat(
        tokenizer=tokenizer,
        history=[],
        user_message=args.prompt,
        lang=args.lang,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        steps=args.steps,
        device=device,
    )
    print(reply)


def print_stack_report() -> None:
    report = detect_xqs_backends()
    print(format_backend_report(report))


729
+ def run_train_demo(args: argparse.Namespace) -> None:
730
+ cfg = apply_runtime_config_overrides(resolve_config(args.scale), args)
731
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
732
+ model = RubiRLM(cfg).to(device)
733
+ maybe_load_checkpoint(model, args.checkpoint, device)
734
+
735
+ train_cfg = TrainStackConfig(
736
+ optimizer_name=args.optimizer_name,
737
+ learning_rate=args.learning_rate,
738
+ weight_decay=args.weight_decay,
739
+ batch_size=args.batch_size,
740
+ num_workers=args.num_workers,
741
+ pin_memory=not args.disable_pin_memory,
742
+ prefetch_factor=args.prefetch_factor,
743
+ persistent_workers=not args.disable_persistent_workers,
744
+ max_seq_len=cfg.max_seq_len,
745
+ dataset_dir=args.dataset_dir,
746
+ use_bf16=not args.disable_bf16,
747
+ )
748
+ dataset = build_dataset(
749
+ dataset_dir=train_cfg.dataset_dir,
750
+ vocab_size=cfg.vocab_size,
751
+ max_seq_len=min(cfg.max_seq_len, args.train_seq_len),
752
+ synthetic_samples=max(args.train_steps * args.batch_size * 2, 32),
753
+ )
754
+ dataloader = build_dataloader(dataset, train_cfg, shuffle=True)
755
+ optimizer = build_optimizer(model, train_cfg)
756
+ mean_loss, total_tokens = train_demo_steps(
757
+ model=model,
758
+ optimizer=optimizer,
759
+ dataloader=dataloader,
760
+ device=device,
761
+ steps=args.train_steps,
762
+ use_bf16=train_cfg.use_bf16,
763
+ )
764
+ print(
765
+ f"train_demo optimizer={optimizer.__class__.__name__} steps={args.train_steps} "
766
+ f"mean_loss={mean_loss:.4f} tokens={total_tokens:,} device={device}"
767
+ )
768
+
769
+
770
+ def main() -> None:
771
+ parser = argparse.ArgumentParser(description="Rubi-RLM recursive language model")
772
+ parser.add_argument("--scale", choices=["1b", "tiny"], default="1b")
773
+ parser.add_argument("--estimate-only", action="store_true")
774
+ parser.add_argument("--demo", action="store_true")
775
+ parser.add_argument("--train-demo", action="store_true")
776
+ parser.add_argument("--stack-report", action="store_true")
777
+
778
+ parser.add_argument("--chat", action="store_true", help="Türkçe/İngilizce sohbet modunu açar")
779
+ parser.add_argument("--interactive", action="store_true", help="Interactive chat loop")
780
+ parser.add_argument("--prompt", type=str, default="")
781
+ parser.add_argument("--lang", choices=["auto", "tr", "en"], default="auto")
782
+ parser.add_argument("--tokenizer-name", type=str, default="gpt2")
783
+ parser.add_argument("--checkpoint", type=str, default=None)
784
+ parser.add_argument("--steps", type=int, default=None)
785
+ parser.add_argument("--max-new-tokens", type=int, default=192)
786
+ parser.add_argument("--temperature", type=float, default=0.7)
787
+ parser.add_argument("--top-k", type=int, default=50)
788
+ parser.add_argument("--optimizer-name", type=str, default="auto")
789
+ parser.add_argument("--moe-backend", choices=["auto", "native", "deepspeed"], default="auto")
790
+ parser.add_argument("--moe-ep-size", type=int, default=1)
791
+ parser.add_argument("--use-torch-compile", action="store_true")
792
+ parser.add_argument("--learning-rate", type=float, default=3e-4)
793
+ parser.add_argument("--weight-decay", type=float, default=0.01)
794
+ parser.add_argument("--batch-size", type=int, default=2)
795
+ parser.add_argument("--num-workers", type=int, default=2)
796
+ parser.add_argument("--prefetch-factor", type=int, default=4)
797
+ parser.add_argument("--dataset-dir", type=str, default="")
798
+ parser.add_argument("--train-steps", type=int, default=2)
799
+ parser.add_argument("--train-seq-len", type=int, default=256)
800
+ parser.add_argument("--disable-pin-memory", action="store_true")
801
+ parser.add_argument("--disable-persistent-workers", action="store_true")
802
+ parser.add_argument("--disable-bf16", action="store_true")
803
+ args = parser.parse_args()
804
+
805
+ if args.chat:
806
+ run_single_chat(args)
807
+ return
808
+
809
+ if args.stack_report:
810
+ print_stack_report()
811
+ return
812
+
813
+ if args.train_demo:
814
+ run_train_demo(args)
815
+ return
816
+
817
+ if args.demo:
818
+ demo()
819
+ return
820
+
821
+ cfg = apply_runtime_config_overrides(resolve_config(args.scale), args)
822
+ n_params = estimate_parameters(cfg)
823
+ active_params = estimate_active_parameters(cfg)
824
+ print(f"Scale={args.scale}, estimated_params={n_params:,}, estimated_active_params={active_params:,}")
825
+ if not args.estimate_only:
826
+ model = RubiRLM(cfg)
827
+ actual = sum(p.numel() for p in model.parameters())
828
+ print(f"actual_params={actual:,}")
829
+
830
+
831
+ if __name__ == "__main__":
832
+ main()
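`maybe_load_checkpoint` above accepts either a raw state dict or a training checkpoint that wraps the weights under `"model_state_dict"`. A torch-free sketch of that unwrapping rule (the function name here is illustrative, not part of the repository):

```python
def unwrap_state_dict(state):
    """Return model weights whether the checkpoint is raw or wrapped."""
    # Training checkpoints store optimizer state alongside the weights,
    # so the weights live under the "model_state_dict" key.
    if isinstance(state, dict) and "model_state_dict" in state:
        return state["model_state_dict"]
    # Otherwise the checkpoint is already a bare state dict.
    return state
```

The same two-layout rule is why `training_checkpoint.pt` (full training state) and `pytorch_model.bin` (weights only) can both be passed via `--checkpoint`.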
architecture.png ADDED

Git LFS Details

  • SHA256: 0bbfea9207bced66682e03f5e73bfcf6025d269e40fd3d74fbce8c0d6d51c503
  • Pointer size: 132 Bytes
  • Size of remote file: 1.35 MB
config.json ADDED
@@ -0,0 +1,15 @@
+ {
+   "architectures": [
+     "RubiRLM"
+   ],
+   "model_type": "rubirlm",
+   "vocab_size": 50257,
+   "max_position_embeddings": 2048,
+   "hidden_size": 1024,
+   "num_hidden_layers": 10,
+   "num_attention_heads": 16,
+   "intermediate_size": 4096,
+   "tie_word_embeddings": true,
+   "tokenizer_name": "gpt2",
+   "trust_remote_code": true
+ }
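As a quick sanity check on the dimensions in `config.json` (a hypothetical snippet, not shipped with the repo): `hidden_size` must divide evenly by `num_attention_heads`, which gives the per-head dimension.

```python
hidden_size = 1024           # from config.json
num_attention_heads = 16     # from config.json

# Attention splits the hidden dimension evenly across heads.
assert hidden_size % num_attention_heads == 0
head_dim = hidden_size // num_attention_heads
print(head_dim)  # 64
```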
export_manifest.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "checkpoint_path": "D:\\Downloads\\final (1).pt",
+   "step": 99361,
+   "scale": "1b",
+   "tokenizer_name": "gpt2"
+ }
generation_config.json ADDED
@@ -0,0 +1,8 @@
+ {
+   "max_new_tokens": 192,
+   "temperature": 0.7,
+   "top_k": 50,
+   "pad_token_id": 50256,
+   "bos_token_id": 50256,
+   "eos_token_id": 50256
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:94fd068c3b25ebcd5c311c8585a5e3560543c0520dcace9b0dae74cba1918738
+ size 3954306149
rubi_train_stack.py ADDED
@@ -0,0 +1,368 @@
+ from __future__ import annotations
+
+ import bisect
+ import functools
+ import importlib.util
+ import json
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Dict, Iterable, List, Optional, Tuple
+
+ import torch
+ from torch.utils.data import DataLoader, Dataset
+ from xqs_stack import choose_optimizer_backend
+
+ SHARD_INDEX_FILENAME = "shard_index.json"
+ SHARD_INDEX_PROGRESS_EVERY = 256
+
+
+ @dataclass
+ class TrainStackConfig:
+     optimizer_name: str = "adafactor"
+     learning_rate: float = 3e-4
+     weight_decay: float = 0.01
+     batch_size: int = 4
+     grad_accum_steps: int = 1
+     num_workers: int = 2
+     pin_memory: bool = True
+     prefetch_factor: int = 4
+     persistent_workers: bool = True
+     max_seq_len: int = 2048
+     dataset_dir: str = ""
+     use_bf16: bool = True
+
+
+ class PretokenizedShardDataset(Dataset):
+     def __init__(self, dataset_dir: str, max_seq_len: int):
+         self.root = Path(dataset_dir)
+         if not self.root.exists():
+             raise FileNotFoundError(f"Dataset directory not found: {dataset_dir}")
+         self.max_seq_len = max_seq_len
+         self.shard_paths = sorted(self.root.glob("*.pt"))
+         if not self.shard_paths:
+             raise FileNotFoundError(f"No .pt shards found in {dataset_dir}")
+         self.shard_sizes: List[int] = []
+         self.cumulative_sizes: List[int] = []
+         total = 0
+         self._cached_shard_path: Optional[Path] = None
+         self._cached_shard_tensor: Optional[torch.Tensor] = None
+         for shard_path, shard_len in self._load_or_build_shard_index():
+             total += shard_len
+             self.shard_sizes.append(shard_len)
+             self.cumulative_sizes.append(total)
+
+     def _shard_index_path(self) -> Path:
+         return self.root / SHARD_INDEX_FILENAME
+
+     def _read_json_file(self, path: Path) -> Dict[str, object]:
+         try:
+             return json.loads(path.read_text(encoding="utf-8"))
+         except (OSError, json.JSONDecodeError):
+             return {}
+
+     def _extract_index_entries(self, payload: Dict[str, object]) -> Optional[List[Tuple[Path, int]]]:
+         shard_entries = payload.get("shards")
+         if not isinstance(shard_entries, list):
+             return None
+         lengths_by_name: Dict[str, int] = {}
+         for entry in shard_entries:
+             if not isinstance(entry, dict):
+                 return None
+             file_name = entry.get("file")
+             length = entry.get("length")
+             if not isinstance(file_name, str) or not isinstance(length, int):
+                 return None
+             lengths_by_name[file_name] = length
+         resolved: List[Tuple[Path, int]] = []
+         for shard_path in self.shard_paths:
+             length = lengths_by_name.get(shard_path.name)
+             if length is None:
+                 return None
+             resolved.append((shard_path, length))
+         return resolved
+
+     def _load_cached_index(self) -> Optional[List[Tuple[Path, int]]]:
+         for candidate in [self._shard_index_path(), self.root / "metadata.json"]:
+             if not candidate.exists():
+                 continue
+             resolved = self._extract_index_entries(self._read_json_file(candidate))
+             if resolved is not None:
+                 print(
+                     json.dumps(
+                         {
+                             "event": "dataset_index_loaded",
+                             "dataset_dir": str(self.root),
+                             "source": candidate.name,
+                             "shards": len(resolved),
+                             "samples": sum(length for _, length in resolved),
+                         }
+                     ),
+                     flush=True,
+                 )
+                 return resolved
+         return None
+
+     def _infer_shard_len(self, shard_path: Path) -> int:
+         shard = torch.load(shard_path, map_location="cpu")
+         if isinstance(shard, torch.Tensor):
+             if shard.ndim == 2:
+                 return int(shard.size(0))
+             return 1
+         if isinstance(shard, list):
+             return len(shard)
+         raise TypeError(f"Unsupported shard format in {shard_path}")
+
+     def _write_cached_index(self, entries: List[Tuple[Path, int]]) -> None:
+         payload = {
+             "shards": [{"file": path.name, "length": length} for path, length in entries],
+             "total_samples": sum(length for _, length in entries),
+         }
+         self._shard_index_path().write_text(json.dumps(payload, indent=2), encoding="utf-8")
+
+     def _load_or_build_shard_index(self) -> List[Tuple[Path, int]]:
+         cached = self._load_cached_index()
+         if cached is not None:
+             return cached
+         print(
+             json.dumps(
+                 {
+                     "event": "dataset_index_build_start",
+                     "dataset_dir": str(self.root),
+                     "shards": len(self.shard_paths),
+                 }
+             ),
+             flush=True,
+         )
+         entries: List[Tuple[Path, int]] = []
+         running_total = 0
+         for shard_idx, shard_path in enumerate(self.shard_paths, start=1):
+             shard_len = self._infer_shard_len(shard_path)
+             entries.append((shard_path, shard_len))
+             running_total += shard_len
+             if shard_idx % SHARD_INDEX_PROGRESS_EVERY == 0 or shard_idx == len(self.shard_paths):
+                 print(
+                     json.dumps(
+                         {
+                             "event": "dataset_index_build_progress",
+                             "dataset_dir": str(self.root),
+                             "indexed_shards": shard_idx,
+                             "total_shards": len(self.shard_paths),
+                             "samples": running_total,
+                         }
+                     ),
+                     flush=True,
+                 )
+         self._write_cached_index(entries)
+         print(
+             json.dumps(
+                 {
+                     "event": "dataset_index_build_done",
+                     "dataset_dir": str(self.root),
+                     "shards": len(entries),
+                     "samples": running_total,
+                 }
+             ),
+             flush=True,
+         )
+         return entries
+
+     def __len__(self) -> int:
+         return self.cumulative_sizes[-1]
+
+     def _load_shard(self, shard_idx: int) -> torch.Tensor:
+         shard_path = self.shard_paths[shard_idx]
+         if self._cached_shard_path == shard_path and self._cached_shard_tensor is not None:
+             return self._cached_shard_tensor
+         shard = torch.load(shard_path, map_location="cpu")
+         if isinstance(shard, list):
+             shard = torch.stack([torch.as_tensor(item, dtype=torch.long) for item in shard], dim=0)
+         elif isinstance(shard, torch.Tensor):
+             if shard.ndim == 1:
+                 shard = shard.unsqueeze(0)
+         else:
+             raise TypeError(f"Unsupported shard format in {shard_path}")
+         self._cached_shard_path = shard_path
+         self._cached_shard_tensor = shard
+         return shard
+
+     def __getitem__(self, idx: int) -> torch.Tensor:
+         if idx < 0:
+             idx += len(self)
+         shard_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+         shard_start = 0 if shard_idx == 0 else self.cumulative_sizes[shard_idx - 1]
+         item_idx = idx - shard_start
+         tokens = self._load_shard(shard_idx)[item_idx].to(dtype=torch.long)
+         if tokens.numel() < 2:
+             padded = torch.zeros(2, dtype=torch.long)
+             padded[: tokens.numel()] = tokens
+             tokens = padded
+         return tokens[: self.max_seq_len + 1]
+
+
+ class SyntheticTokenDataset(Dataset):
+     def __init__(self, vocab_size: int, max_seq_len: int, num_samples: int = 128):
+         self.vocab_size = vocab_size
+         self.max_seq_len = max_seq_len
+         self.num_samples = num_samples
+
+     def __len__(self) -> int:
+         return self.num_samples
+
+     def __getitem__(self, idx: int) -> torch.Tensor:
+         return torch.randint(0, self.vocab_size, (self.max_seq_len + 1,), dtype=torch.long)
+
+
+ class LayerWiseSGD(torch.optim.Optimizer):
+     def __init__(self, params: Iterable[torch.nn.Parameter], lr: float = 1e-2, momentum: float = 0.9, weight_decay: float = 0.0):
+         defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay)
+         super().__init__(params, defaults)
+
+     @torch.no_grad()
+     def step(self, closure=None):
+         loss = None
+         if closure is not None:
+             with torch.enable_grad():
+                 loss = closure()
+         for group in self.param_groups:
+             lr = group["lr"]
+             momentum = group["momentum"]
+             weight_decay = group["weight_decay"]
+             params_with_grad = [p for p in group["params"] if p.grad is not None]
+             if not params_with_grad:
+                 continue
+             device = params_with_grad[0].device
+             mean_grad_sq = torch.zeros((), device=device)
+             counted = 0
+             for p in params_with_grad:
+                 grad = p.grad
+                 if weight_decay != 0:
+                     grad = grad.add(p, alpha=weight_decay)
+                 mean_grad_sq = mean_grad_sq + grad.pow(2).mean()
+                 counted += 1
+             mean_grad_sq = mean_grad_sq / max(1, counted)
+             velocity = group.get("layer_velocity")
+             if velocity is None:
+                 velocity = torch.zeros((), device=device)
+             velocity = (momentum * velocity) + mean_grad_sq.sqrt()
+             group["layer_velocity"] = velocity
+             scale = lr / velocity.clamp(min=1e-8)
+             for p in params_with_grad:
+                 grad = p.grad
+                 if weight_decay != 0:
+                     grad = grad.add(p, alpha=weight_decay)
+                 p.add_(grad, alpha=-scale)
+         return loss
+
+
+ def _build_adafactor(params: Iterable[torch.nn.Parameter], cfg: TrainStackConfig):
+     if importlib.util.find_spec("transformers") is None:
+         return torch.optim.AdamW(params, lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
+     transformers = __import__("transformers")
+     return transformers.Adafactor(
+         params,
+         lr=cfg.learning_rate,
+         relative_step=False,
+         scale_parameter=False,
+         warmup_init=False,
+         weight_decay=cfg.weight_decay,
+     )
+
+
+ def _build_adam8bit(params: Iterable[torch.nn.Parameter], cfg: TrainStackConfig):
+     if importlib.util.find_spec("bitsandbytes") is None:
+         return torch.optim.AdamW(params, lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
+     bnb = __import__("bitsandbytes")
+     return bnb.optim.Adam8bit(params, lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
+
+
+ def build_optimizer(model: torch.nn.Module, cfg: TrainStackConfig) -> torch.optim.Optimizer:
+     name = cfg.optimizer_name.lower()
+     if name == "auto":
+         name = choose_optimizer_backend(prefer_low_memory=True)
+     if name in {"adamw_fused", "fused_adamw"}:
+         if torch.cuda.is_available():
+             try:
+                 return torch.optim.AdamW(
+                     model.parameters(),
+                     lr=cfg.learning_rate,
+                     weight_decay=cfg.weight_decay,
+                     fused=True,
+                 )
+             except TypeError:
+                 pass
+         return torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
+     if name == "adafactor":
+         return _build_adafactor(model.parameters(), cfg)
+     if name in {"adam8bit", "adam_8bit", "8bit-adam"}:
+         return _build_adam8bit(model.parameters(), cfg)
+     if name in {"layerwisesgd", "lowmemsgd", "sgd"}:
+         return LayerWiseSGD(model.parameters(), lr=cfg.learning_rate, momentum=0.9, weight_decay=cfg.weight_decay)
+     return torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
+
+
+ def collate_token_batch(batch: List[torch.Tensor], fixed_length: Optional[int] = None) -> Dict[str, torch.Tensor]:
+     if fixed_length is not None and all(item.numel() >= fixed_length for item in batch):
+         stacked = torch.stack([item[:fixed_length] for item in batch], dim=0)
+         return {"input_ids": stacked[:, :-1], "target_ids": stacked[:, 1:]}
+     max_len = max(item.numel() for item in batch)
+     padded = torch.zeros((len(batch), max_len), dtype=torch.long)
+     targets = torch.full((len(batch), max_len - 1), -100, dtype=torch.long)
+     inputs = torch.zeros((len(batch), max_len - 1), dtype=torch.long)
+     for i, item in enumerate(batch):
+         padded[i, : item.numel()] = item
+         inputs[i, : item.numel() - 1] = item[:-1]
+         targets[i, : item.numel() - 1] = item[1:]
+     return {"input_ids": inputs, "target_ids": targets}
+
+
+ def build_dataset(dataset_dir: str, vocab_size: int, max_seq_len: int, synthetic_samples: int = 128) -> Dataset:
+     if dataset_dir:
+         return PretokenizedShardDataset(dataset_dir, max_seq_len=max_seq_len)
+     return SyntheticTokenDataset(vocab_size=vocab_size, max_seq_len=max_seq_len, num_samples=synthetic_samples)
+
+
+ def build_dataloader(dataset: Dataset, cfg: TrainStackConfig, shuffle: bool = True) -> DataLoader:
+     kwargs = dict(
+         batch_size=cfg.batch_size,
+         shuffle=shuffle,
+         num_workers=cfg.num_workers,
+         pin_memory=cfg.pin_memory,
+         persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
+         collate_fn=functools.partial(collate_token_batch, fixed_length=cfg.max_seq_len + 1),
+     )
+     if cfg.num_workers > 0:
+         kwargs["prefetch_factor"] = cfg.prefetch_factor
+     return DataLoader(dataset, **kwargs)
+
+
+ def move_batch_to_device(batch: Dict[str, torch.Tensor], device: torch.device, non_blocking: bool = True) -> Dict[str, torch.Tensor]:
+     return {key: value.to(device, non_blocking=non_blocking) for key, value in batch.items()}
+
+
+ def train_demo_steps(
+     model: torch.nn.Module,
+     optimizer: torch.optim.Optimizer,
+     dataloader: DataLoader,
+     device: torch.device,
+     steps: int = 2,
+     use_bf16: bool = True,
+ ) -> Tuple[float, int]:
+     model.train()
+     total_loss = 0.0
+     total_tokens = 0
+     autocast_enabled = use_bf16 and device.type == "cuda"
+     for step_idx, batch in enumerate(dataloader):
+         if step_idx >= steps:
+             break
+         batch = move_batch_to_device(batch, device)
+         optimizer.zero_grad(set_to_none=True)
+         with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=autocast_enabled):
+             loss = model.training_loss(batch["input_ids"], batch["target_ids"])
+         loss.backward()
+         torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+         optimizer.step()
+         total_loss += float(loss.detach().item())
+         total_tokens += int((batch["target_ids"] != -100).sum().item())
+     mean_loss = total_loss / max(1, steps)
+     return mean_loss, total_tokens
rubirlm_config.json ADDED
@@ -0,0 +1,29 @@
+ {
+   "vocab_size": 50257,
+   "max_seq_len": 2048,
+   "d_model": 1024,
+   "n_layers": 10,
+   "n_heads": 16,
+   "ff_mult": 4,
+   "dropout": 0.1,
+   "recurse_steps": 6,
+   "critique_threshold": 0.2,
+   "tie_embeddings": true,
+   "use_moe": true,
+   "moe_num_experts": 32,
+   "moe_top_k": 1,
+   "moe_expert_hidden": 1280,
+   "moe_router_jitter": 0.01,
+   "moe_aux_loss_weight": 0.01,
+   "use_layer_skip": true,
+   "layer_skip_threshold": 0.8,
+   "layer_skip_target": 0.03,
+   "layer_skip_aux_weight": 0.01,
+   "use_ternary_weights": true,
+   "use_flash_attention": true,
+   "use_fused_ops": true,
+   "packed_execution": true,
+   "use_torch_compile": false,
+   "moe_backend": "auto",
+   "moe_ep_size": 1
+ }
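With `moe_num_experts: 32` and `moe_top_k: 1`, only one expert's weights participate per token, so only 1/32 of the expert parameters are active in each forward pass. A back-of-envelope sketch of that ratio, assuming a two-matrix expert MLP without biases (an assumption; the actual expert layout lives in the model code):

```python
# Values taken from rubirlm_config.json above.
d_model, expert_hidden, num_experts, top_k = 1024, 1280, 32, 1

# Assumed expert shape: up-projection (d_model x hidden) + down-projection.
params_per_expert = 2 * d_model * expert_hidden
total_expert_params = num_experts * params_per_expert
active_expert_params = top_k * params_per_expert

ratio = active_expert_params / total_expert_params
print(ratio)  # 0.03125, i.e. 1/32 of expert weights active per token
```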
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,12 @@
+ {
+   "add_prefix_space": false,
+   "backend": "tokenizers",
+   "bos_token": "<|endoftext|>",
+   "eos_token": "<|endoftext|>",
+   "errors": "replace",
+   "is_local": false,
+   "model_max_length": 1024,
+   "pad_token": "<|endoftext|>",
+   "tokenizer_class": "GPT2Tokenizer",
+   "unk_token": "<|endoftext|>"
+ }
training_checkpoint.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:69d776a3d091911fba2c791b74eeacd1f75950da3ac35bfca956d3df1ec6b4ba
+ size 7747395326
x_quantum_sparse_ops.py ADDED
@@ -0,0 +1,130 @@
+ from __future__ import annotations
+
+ import importlib.util
+ from typing import Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from xqs_stack import choose_attention_backend, choose_quant_backend
+ from xqs_triton_ops import triton_ternary_linear
+
+
+ _HAS_FLASH_ATTN = importlib.util.find_spec("flash_attn") is not None
+ if _HAS_FLASH_ATTN:
+     from flash_attn import flash_attn_func
+
+ _ATTN_BACKEND = choose_attention_backend(prefer_flash=True)
+ _QUANT_BACKEND = choose_quant_backend(prefer_triton=True)
+
+
+ def ternary_quantize(weight: torch.Tensor) -> torch.Tensor:
+     scale = weight.detach().abs().mean().clamp(min=1e-6)
+     pos = weight > (0.5 * scale)
+     neg = weight < (-0.5 * scale)
+     quantized = torch.zeros_like(weight)
+     quantized = torch.where(pos, torch.ones_like(weight), quantized)
+     quantized = torch.where(neg, -torch.ones_like(weight), quantized)
+     quantized = quantized * scale
+     return weight + (quantized - weight).detach()
+
+
+ class TernaryLinear(nn.Module):
+     def __init__(self, in_features: int, out_features: int, bias: bool = True):
+         super().__init__()
+         self.in_features = in_features
+         self.out_features = out_features
+         self.backend = _QUANT_BACKEND
+         self.weight = nn.Parameter(torch.empty(out_features, in_features))
+         if bias:
+             self.bias = nn.Parameter(torch.empty(out_features))
+         else:
+             self.register_parameter("bias", None)
+         self.reset_parameters()
+
+     def reset_parameters(self) -> None:
+         nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5)
+         if self.bias is not None:
+             bound = 1 / max(1, self.in_features) ** 0.5
+             nn.init.uniform_(self.bias, -bound, bound)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         if self.backend == "triton":
+             return triton_ternary_linear(x, self.weight, self.bias)
+         return F.linear(x, ternary_quantize(self.weight), self.bias)
+
+
+ def build_linear(in_features: int, out_features: int, bias: bool = True, ternary: bool = False) -> nn.Module:
+     if ternary:
+         return TernaryLinear(in_features, out_features, bias=bias)
+     return nn.Linear(in_features, out_features, bias=bias)
+
+
+ def fused_residual_add(x: torch.Tensor, residual: torch.Tensor, gate: Optional[torch.Tensor] = None) -> torch.Tensor:
+     if gate is None:
+         return x + residual
+     return x + (gate * residual)
+
+
+ def causal_scaled_dot_product_attention(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     dropout_p: float = 0.0,
+     training: bool = False,
+ ) -> torch.Tensor:
+     if _ATTN_BACKEND == "flash_attn" and _HAS_FLASH_ATTN and q.is_cuda and q.dtype in {torch.float16, torch.bfloat16}:
+         q_flash = q.transpose(1, 2).contiguous()
+         k_flash = k.transpose(1, 2).contiguous()
+         v_flash = v.transpose(1, 2).contiguous()
+         out = flash_attn_func(
+             q_flash,
+             k_flash,
+             v_flash,
+             dropout_p=dropout_p if training else 0.0,
+             causal=True,
+         )
+         return out.transpose(1, 2).contiguous()
+
+     if hasattr(F, "scaled_dot_product_attention"):
+         return F.scaled_dot_product_attention(
+             q,
+             k,
+             v,
+             attn_mask=None,
+             dropout_p=dropout_p if training else 0.0,
+             is_causal=True,
+         )
+
+     scale = q.size(-1) ** -0.5
+     scores = torch.matmul(q, k.transpose(-2, -1)) * scale
+     causal_mask = torch.triu(torch.ones(scores.size(-2), scores.size(-1), device=scores.device, dtype=torch.bool), diagonal=1)
+     scores = scores.masked_fill(causal_mask, float("-inf"))
+     probs = torch.softmax(scores, dim=-1)
+     if training and dropout_p > 0:
+         probs = F.dropout(probs, p=dropout_p)
+     return torch.matmul(probs, v)
+
+
+ def pack_rows(indices: torch.Tensor, *tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
+     return tuple(t.index_select(0, indices) for t in tensors)
+
+
+ def scatter_rows(base: torch.Tensor, indices: torch.Tensor, updates: torch.Tensor) -> torch.Tensor:
+     if indices.numel() == 0:
+         return base
+     out = base.clone()
+     out.index_copy_(0, indices, updates)
+     return out
+
+
+ def maybe_compile_module(module: nn.Module, enabled: bool) -> nn.Module:
+     if not enabled:
+         return module
+     compile_fn = getattr(torch, "compile", None)
+     if compile_fn is None:
+         return module
+     try:
+         return compile_fn(module)
+     except Exception:
+         return module
xqs_moe.py ADDED
@@ -0,0 +1,82 @@
+ from __future__ import annotations
+
+ import importlib.util
+ from typing import Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+
+
+ _HAS_DEEPSPEED = importlib.util.find_spec("deepspeed") is not None
+ _DEEPSPEED_MOE_LAYER = None
+ _DEEPSPEED_IMPORT_ATTEMPTED = False
+ _DEEPSPEED_IMPORT_ERROR: Optional[str] = None
+
+
+ def _load_deepspeed_moe_layer():
+     global _DEEPSPEED_MOE_LAYER, _DEEPSPEED_IMPORT_ATTEMPTED, _DEEPSPEED_IMPORT_ERROR
+     if _DEEPSPEED_IMPORT_ATTEMPTED:
+         return _DEEPSPEED_MOE_LAYER
+     _DEEPSPEED_IMPORT_ATTEMPTED = True
+     if not _HAS_DEEPSPEED:
+         return None
+     try:
+         from deepspeed.moe.layer import MoE as deepspeed_moe_layer
+     except Exception as exc:
+         _DEEPSPEED_IMPORT_ERROR = str(exc)
+         _DEEPSPEED_MOE_LAYER = None
+         return None
+     _DEEPSPEED_MOE_LAYER = deepspeed_moe_layer
+     return _DEEPSPEED_MOE_LAYER
+
+
+ class DeepSpeedMoEWrapper(nn.Module):
+     def __init__(
+         self,
+         hidden_size: int,
+         expert: nn.Module,
+         num_experts: int,
+         top_k: int,
+         ep_size: int = 1,
+     ):
+         super().__init__()
+         deepspeed_moe_layer = _load_deepspeed_moe_layer()
+         if deepspeed_moe_layer is None:
+             details = f": {_DEEPSPEED_IMPORT_ERROR}" if _DEEPSPEED_IMPORT_ERROR else ""
+             raise RuntimeError(f"DeepSpeed MoE backend is not available{details}")
+         self.layer = deepspeed_moe_layer(
+             hidden_size=hidden_size,
+             expert=expert,
+             num_experts=num_experts,
+             ep_size=ep_size,
+             k=top_k,
+             use_residual=False,
+         )
+
+     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         out, aux_loss, _ = self.layer(x)
+         if isinstance(aux_loss, torch.Tensor):
+             return out, aux_loss
+         return out, x.new_zeros(())
+
+
+ def build_deepspeed_moe(
+     hidden_size: int,
+     expert: nn.Module,
+     num_experts: int,
+     top_k: int,
+     ep_size: int = 1,
+ ) -> Optional[DeepSpeedMoEWrapper]:
+     if _load_deepspeed_moe_layer() is None:
+         return None
+     return DeepSpeedMoEWrapper(
+         hidden_size=hidden_size,
+         expert=expert,
+         num_experts=num_experts,
+         top_k=top_k,
+         ep_size=ep_size,
+     )
+
+
+ def has_deepspeed_moe() -> bool:
+     return _load_deepspeed_moe_layer() is not None
xqs_stack.py ADDED
@@ -0,0 +1,105 @@
+ from __future__ import annotations
+
+ import importlib.util
+ import shutil
+ from dataclasses import dataclass
+ from typing import Dict
+
+ import torch
+
+
+ @dataclass(frozen=True)
+ class XQSBackendReport:
+     torch_version: str
+     cuda_available: bool
+     cuda_device_name: str
+     bf16_supported: bool
+     torch_compile_available: bool
+     triton_available: bool
+     deepspeed_available: bool
+     bitsandbytes_available: bool
+     flash_attn_available: bool
+     nvcc_available: bool
+
+     def as_dict(self) -> Dict[str, object]:
+         return {
+             "torch_version": self.torch_version,
+             "cuda_available": self.cuda_available,
+             "cuda_device_name": self.cuda_device_name,
+             "bf16_supported": self.bf16_supported,
+             "torch_compile_available": self.torch_compile_available,
+             "triton_available": self.triton_available,
+             "deepspeed_available": self.deepspeed_available,
+             "bitsandbytes_available": self.bitsandbytes_available,
+             "flash_attn_available": self.flash_attn_available,
+             "nvcc_available": self.nvcc_available,
+         }
+
+
+ def _has_module(name: str) -> bool:
+     return importlib.util.find_spec(name) is not None
+
+
+ def detect_xqs_backends() -> XQSBackendReport:
+     cuda_available = torch.cuda.is_available()
+     device_name = torch.cuda.get_device_name(0) if cuda_available else "cpu"
+     bf16_supported = bool(cuda_available and torch.cuda.is_bf16_supported())
+     return XQSBackendReport(
+         torch_version=torch.__version__,
+         cuda_available=cuda_available,
+         cuda_device_name=device_name,
+         bf16_supported=bf16_supported,
+         torch_compile_available=hasattr(torch, "compile"),
+         triton_available=_has_module("triton"),
+         deepspeed_available=_has_module("deepspeed"),
+         bitsandbytes_available=_has_module("bitsandbytes"),
+         flash_attn_available=_has_module("flash_attn"),
+         nvcc_available=shutil.which("nvcc") is not None,
+     )
+
+
+ def choose_attention_backend(prefer_flash: bool = True) -> str:
+     report = detect_xqs_backends()
+     if prefer_flash and report.flash_attn_available and report.cuda_available:
+         return "flash_attn"
+     if report.cuda_available:
+         return "scaled_dot_product_attention"
+     return "eager"
+
+
+ def choose_optimizer_backend(prefer_low_memory: bool = True) -> str:
+     report = detect_xqs_backends()
+     adamw_signature = getattr(torch.optim.AdamW, "__init__", None)
+     fused_supported = bool(adamw_signature and "fused" in adamw_signature.__code__.co_varnames)
+     if report.cuda_available and fused_supported:
+         return "adamw_fused"
+     if prefer_low_memory and report.bitsandbytes_available:
+         return "adam8bit"
+     if _has_module("transformers"):
+         return "adafactor"
+     return "sgd"
+
+
+ def choose_moe_backend(prefer_deepspeed: bool = True) -> str:
+     report = detect_xqs_backends()
+     if prefer_deepspeed and report.deepspeed_available and report.cuda_available:
+         return "deepspeed"
+     return "native"
+
+
+ def choose_quant_backend(prefer_triton: bool = True) -> str:
+     report = detect_xqs_backends()
+     if prefer_triton and report.triton_available and report.cuda_available:
+         return "triton"
+     return "pytorch"
+
+
+ def format_backend_report(report: XQSBackendReport) -> str:
+     ordered = report.as_dict()
+     return "\n".join(f"{key}={value}" for key, value in ordered.items())
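`choose_optimizer_backend` encodes a fixed preference order: fused AdamW when CUDA supports it, then 8-bit Adam via bitsandbytes, then Adafactor via transformers, then plain SGD. A torch-free sketch of that fallback chain with the environment probes lifted into arguments (names here are illustrative, not the shipped API):

```python
def choose_optimizer(cuda, fused_ok, has_bnb, has_transformers, prefer_low_memory=True):
    """Mirror the fused -> 8bit -> adafactor -> sgd preference order."""
    if cuda and fused_ok:
        return "adamw_fused"
    if prefer_low_memory and has_bnb:
        return "adam8bit"
    if has_transformers:
        return "adafactor"
    return "sgd"
```

Lifting the probes into parameters makes the decision table trivially testable, whereas the shipped function re-detects the environment on every call.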