av-codes commited on
Commit
dbfd261
·
verified ·
1 Parent(s): 930beac

upload training script for HF Jobs benchmark

Browse files
Files changed (1) hide show
  1. train_hrm_text_pi.py +1177 -0
train_hrm_text_pi.py ADDED
@@ -0,0 +1,1177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ train_hrm_text_pi.py — Train an HRM-Text prompt injection detector
4
+ on the Bordair multimodal dataset with ~128k context support.
5
+
6
+ Architecture follows HRM-Text (sapientinc/HRM-Text, arXiv:2506.21734):
7
+ - ScaledEmbeddingInit: byte-level embedding with lecun scaling
8
+ - H module: high-level transformer stack (recurrent)
9
+ - L module: low-level transformer stack (recurrent)
10
+ - Recurrent loop: H_cycles × (L_cycles × L → H) cascade
11
+ - Classification head on final z_H (last-token pooling)
12
+ - RoPE with optional NTK-aware scaling (default 8k, configurable up to 128k)
13
+ - Backprop warmup: 2→5 recurrent steps over first 20% of training
14
+
15
+ Data: Bordair/bordair-multimodal (503K samples, balanced 1:1)
16
+ Target: ~36M parameters
17
+ """
18
+ import os
19
+ os.environ.setdefault("PYTORCH_HIP_ALLOC_CONF", "expandable_segments:True")
20
+
21
+ import argparse
22
+ import json
23
+ import math
24
+ import glob
25
+ import time
26
+ from collections import Counter
27
+
28
+ import datasets as hf_datasets
29
+ import evaluate
30
+ import numpy as np
31
+ import torch
32
+ torch.set_float32_matmul_precision('high')
33
+ import torch.nn as nn
34
+ import torch.nn.functional as F
35
+ from torch.utils.checkpoint import checkpoint as torch_checkpoint
36
+ from datasets import Dataset, concatenate_datasets
37
+ from huggingface_hub import snapshot_download, HfApi
38
+ from torch.utils.data import DataLoader
39
+ from transformers import (
40
+ PreTrainedTokenizerFast,
41
+ Trainer,
42
+ TrainingArguments,
43
+ set_seed,
44
+ )
45
+
46
+ # ═══════════════════════════════════════════════════════════════════════════════
47
+ # Rotary Embedding (NTK-aware scaling for 128k)
48
+ # ═══════════════════════════════════════════════════════════════════════════════
49
+
50
+ class RotaryEmbedding(nn.Module):
51
+ """RoPE with optional NTK-aware scaling for long context extension.
52
+
53
+ Standard RoPE: theta=10000.0, max_seq_len=4096.
54
+ NTK scaling: scale_factor = target_len / original_max_len.
55
+ Extends context by redistributing frequencies.
56
+ """
57
+
58
+ def __init__(self, dim, max_seq_len=4096, base=10000.0, scaling_factor=32.0):
59
+ super().__init__()
60
+ self.dim = dim
61
+ self.max_seq_len = max_seq_len
62
+ self.base = base
63
+ self.scaling_factor = scaling_factor
64
+
65
+ if scaling_factor > 1.0:
66
+ # NTK-aware scaling: adjust base instead of interpolating positions
67
+ ntk_base = base * scaling_factor ** (dim / (dim - 2))
68
+ else:
69
+ ntk_base = base
70
+
71
+ inv_freq = 1.0 / (ntk_base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
72
+ t = torch.arange(max_seq_len, dtype=torch.float32)
73
+ freqs = torch.outer(t, inv_freq.float())
74
+ emb = torch.cat((freqs, freqs), dim=-1)
75
+
76
+ self.register_buffer("cos_cached", emb.cos(), persistent=False)
77
+ self.register_buffer("sin_cached", emb.sin(), persistent=False)
78
+
79
+ def forward(self, position_ids):
80
+ cos = self.cos_cached[position_ids] # [B, L, dim]
81
+ sin = self.sin_cached[position_ids]
82
+ return cos, sin
83
+
84
+
85
+ def apply_rotary_pos_emb(x, cos, sin):
86
+ """Apply RoPE to tensor x [B, L, H, HD] using precomputed cos/sin."""
87
+ half = x.shape[-1] // 2
88
+ x_rot = x[..., :half].to(cos.dtype)
89
+ x_pass = x[..., half:].to(cos.dtype)
90
+ cos = cos[..., :half].unsqueeze(-2)
91
+ sin = sin[..., :half].unsqueeze(-2)
92
+ x_rot_out = x_rot * cos + rotate_half(x_rot) * sin
93
+ return torch.cat([x_rot_out, x_pass], dim=-1)
94
+
95
+
96
+ def rotate_half(x):
97
+ x1 = x[..., :x.shape[-1] // 2]
98
+ x2 = x[..., x.shape[-1] // 2:]
99
+ return torch.cat((-x2, x1), dim=-1)
100
+
101
+
102
+ # ═══════════════════════════════════════════════════════════════════════════════
103
+ # Initialization helpers
104
+ # ═══════════════════════════════════════════════════════════════════════════════
105
+
106
+ def trunc_normal_init_(tensor, std=1.0):
107
+ """Truncated normal approximation via 3-sigma clamping."""
108
+ with torch.no_grad():
109
+ return tensor.normal_().fmod_(3.0).mul_(1.014762601732121 * std)
110
+
111
+
112
+ class LinearInit(nn.Linear):
113
+ """Linear layer with lecun-normal-like init."""
114
+
115
+ def __init__(self, in_features, out_features, bias=True, init_std=None):
116
+ super().__init__(in_features, out_features, bias=bias)
117
+ if init_std is None:
118
+ init_std = 1.0 / math.sqrt(in_features)
119
+ trunc_normal_init_(self.weight, std=init_std)
120
+ if self.bias is not None:
121
+ nn.init.zeros_(self.bias)
122
+
123
+
124
+ # ════��══════════════════════════════════════════════════════════════════════════
125
+ # Scaled Embedding
126
+ # ═══════════════════════════════════════════════════════════════════════════════
127
+
128
+ class ScaledEmbeddingInit(nn.Embedding):
129
+ """Embedding with lecun-normal scaling (HRM-Text §2.1)."""
130
+
131
+ def __init__(self, vocab_size, d_model, padding_idx=0, init_std=None):
132
+ super().__init__(vocab_size, d_model, padding_idx=padding_idx)
133
+ if init_std is None:
134
+ init_std = 1.0 / math.sqrt(d_model)
135
+ trunc_normal_init_(self.weight, std=init_std)
136
+ with torch.no_grad():
137
+ if padding_idx is not None:
138
+ self.weight[padding_idx].zero_()
139
+ self.scale = 1.0 / init_std if init_std > 0 else 1.0
140
+
141
+ def forward(self, input_ids):
142
+ return super().forward(input_ids) * self.scale
143
+
144
+
145
+ # ═══════════════════════════════════════════════════════════════════════════════
146
+ # Gated Attention (HRM-Text sigmoid-gated MHA)
147
+ # ═══════════════════════════════════════════════════════════════════════════════
148
+
149
+ class GatedAttention(nn.Module):
150
+ """Sigmoid-gated multi-head attention with RoPE.
151
+
152
+ Single projection: gate + q + k + v → split → gate(sigmoid) × attn → o_proj
153
+ """
154
+
155
+ def __init__(self, hidden_size, num_heads, head_dim, init_std):
156
+ super().__init__()
157
+ self.hidden_size = hidden_size
158
+ self.num_heads = num_heads
159
+ self.head_dim = head_dim
160
+
161
+ # Single projection for gate (G), query (Q), key (K), value (V)
162
+ # G: num_heads, Q: num_heads, K: num_heads, V: num_heads
163
+ n_gqkv = num_heads * 4 # G+Q+K+V
164
+ self.gqkv_proj = LinearInit(
165
+ hidden_size, head_dim * n_gqkv,
166
+ bias=False, init_std=init_std,
167
+ )
168
+ self.o_proj = LinearInit(
169
+ head_dim * num_heads, hidden_size,
170
+ bias=False, init_std=init_std,
171
+ )
172
+
173
+ def forward(self, hidden_states, cos, sin):
174
+ B, L, D = hidden_states.shape
175
+
176
+ gqkv = self.gqkv_proj(hidden_states)
177
+ gqkv = gqkv.view(B, L, self.num_heads * 4, self.head_dim)
178
+
179
+ gate, query, key, value = gqkv.split(self.num_heads, dim=2)
180
+ # gate: [B, L, num_heads, head_dim]
181
+ # query, key, value: [B, L, num_heads, head_dim]
182
+
183
+ # RoPE on Q and K
184
+ query = apply_rotary_pos_emb(query, cos, sin)
185
+ key = apply_rotary_pos_emb(key, cos, sin)
186
+
187
+ # Transpose to [B, num_heads, L, head_dim] for SDPA
188
+ query = query.transpose(1, 2)
189
+ key = key.transpose(1, 2)
190
+ value = value.transpose(1, 2)
191
+
192
+ # SDPA with causal masking (flash attention backend, O(L) memory)
193
+ attn_output = F.scaled_dot_product_attention(
194
+ query, key, value,
195
+ is_causal=True,
196
+ )
197
+
198
+ # Gate: sigmoid(gate) × attn_output (elementwise)
199
+ gate = gate.transpose(1, 2) # [B, num_heads, L, head_dim]
200
+ attn_output = torch.sigmoid(gate) * attn_output
201
+
202
+ # Concatenate heads
203
+ attn_output = attn_output.transpose(1, 2).reshape(B, L, -1)
204
+ return self.o_proj(attn_output)
205
+
206
+
207
+ # ═══════════════════════════════════════════════════════════════════════════════
208
+ # SwiGLU FFN
209
+ # ═══════════════════════════════════════════════════════════════════════════════
210
+
211
+ class SwiGLU(nn.Module):
212
+ """SwiGLU feed-forward network: gate_up_proj → SiLU(gate)*up → down_proj.
213
+
214
+ gate_up_proj maps hidden_size → intermediate_size * 2 (gate + up).
215
+ """
216
+
217
+ def __init__(self, hidden_size, intermediate_size, init_std):
218
+ super().__init__()
219
+ self.gate_up_proj = LinearInit(
220
+ hidden_size, intermediate_size * 2,
221
+ bias=False, init_std=init_std,
222
+ )
223
+ self.down_proj = LinearInit(
224
+ intermediate_size, hidden_size,
225
+ bias=False, init_std=init_std,
226
+ )
227
+
228
+ def forward(self, x):
229
+ gate_up = self.gate_up_proj(x)
230
+ gate, up = gate_up.chunk(2, dim=-1)
231
+ return self.down_proj(F.silu(gate) * up)
232
+
233
+
234
+ # ══════════════════════════════════��════════════════════════════════════════════
235
+ # Transformer Block
236
+ # ═══════════════════════════════════════════════════════════════════════════════
237
+
238
+ class TransformerBlock(nn.Module):
239
+ """Pre-norm Transformer block: Attn → Residual → SwiGLU → Residual."""
240
+
241
+ def __init__(self, hidden_size, num_heads, head_dim, intermediate_size, init_std):
242
+ super().__init__()
243
+ self.attn_norm = nn.LayerNorm(hidden_size, elementwise_affine=False)
244
+ self.attn = GatedAttention(hidden_size, num_heads, head_dim, init_std)
245
+ self.ffn_norm = nn.LayerNorm(hidden_size, elementwise_affine=False)
246
+ self.ffn = SwiGLU(hidden_size, intermediate_size, init_std)
247
+
248
+ def forward(self, x, cos, sin):
249
+ # Pre-norm attention (causal via SDPA)
250
+ x = x + self.attn(self.attn_norm(x), cos, sin)
251
+ # Pre-norm FFN
252
+ x = x + self.ffn(self.ffn_norm(x))
253
+ return x
254
+
255
+
256
+ # ═══════════════════════════════════════════════════════════════════════════════
257
+ # Recurrent Module (H or L)
258
+ # ═══════════════════════════════════════════════════════════════════════════════
259
+
260
+ class RecurrentModule(nn.Module):
261
+ """A stack of TransformerBlocks used as one recurrent module (H or L).
262
+
263
+ In HRM-Text, each module is a full transformer stack. The module receives
264
+ its own hidden state + the other module's hidden state (via additive fusion).
265
+ """
266
+
267
+ def __init__(self, layers, hidden_size, num_heads, head_dim, intermediate_size, init_std, use_checkpoint=False):
268
+ super().__init__()
269
+ self.use_checkpoint = use_checkpoint
270
+ self.blocks = nn.ModuleList([
271
+ TransformerBlock(hidden_size, num_heads, head_dim, intermediate_size, init_std)
272
+ for _ in range(layers)
273
+ ])
274
+ self.final_norm = nn.LayerNorm(hidden_size, elementwise_affine=False)
275
+
276
+ def forward(self, z_self, z_other, cos, sin):
277
+ """Forward through all blocks with additive cross-module fusion.
278
+
279
+ Args:
280
+ z_self: This module's hidden state [B, L, D]
281
+ z_other: Other module's hidden state [B, L, D]
282
+ """
283
+ x = z_self + z_other # Additive fusion
284
+ for block in self.blocks:
285
+ if self.use_checkpoint and self.training:
286
+ x = torch_checkpoint(block, x, cos, sin, use_reentrant=True)
287
+ else:
288
+ x = block(x, cos, sin)
289
+ x = self.final_norm(x)
290
+ return x
291
+
292
+
293
+ # ═══════════════════════════════════════════════════════════════════════════════
294
+ # HRM-Text Classifier
295
+ # ═══════════════════════════════════════════════════════════════════════════════
296
+
297
+ class HrmTextClassifier(nn.Module):
298
+ """
299
+ HRM-Text adapted for binary classification (prompt injection detection).
300
+
301
+ Architecture:
302
+ f_emb: ScaledEmbeddingInit (byte-level, vocab=256)
303
+ L_module: RecurrentModule (low-level transformer stack)
304
+ H_module: RecurrentModule (high-level transformer stack)
305
+ f_cls: Classification head (LayerNorm → Linear)
306
+
307
+ Recurrent loop (H_cycles × L_cycles):
308
+ z_L = L_module(z_L, z_H, cos, sin)
309
+ z_H = H_module(z_H, z_L, cos, sin)
310
+
311
+ Final: logits = f_cls(z_H[:, -1, :]) # last-token pooling for classification
312
+
313
+ Backprop warmup: gradient-track only the last bp_steps recurrent steps.
314
+ """
315
+
316
+ def __init__(
317
+ self,
318
+ vocab_size=256,
319
+ hidden_size=768,
320
+ num_heads=12,
321
+ head_dim=64,
322
+ n_layers_H=3,
323
+ n_layers_L=3,
324
+ intermediate_size=2048,
325
+ H_cycles=2,
326
+ L_cycles=3,
327
+ max_seq_len=4096,
328
+ rope_base=10000.0,
329
+ rope_scaling_factor=32.0,
330
+ num_classes=2,
331
+ bp_min_steps=2,
332
+ bp_max_steps=5,
333
+ bp_warmup_ratio=0.2,
334
+ use_gradient_checkpointing=False,
335
+ ):
336
+ super().__init__()
337
+ self.hidden_size = hidden_size
338
+ self.H_cycles = H_cycles
339
+ self.L_cycles = L_cycles
340
+ self.bp_min_steps = bp_min_steps
341
+ self.bp_max_steps = bp_max_steps
342
+ self.bp_warmup_ratio = bp_warmup_ratio
343
+ self.total_steps = H_cycles * L_cycles # total recurrent steps
344
+
345
+ init_std = 1.0 / math.sqrt(hidden_size) # lecun-normal std
346
+
347
+ # Token embedding (byte-level)
348
+ self.embed = ScaledEmbeddingInit(
349
+ vocab_size, hidden_size, padding_idx=0,
350
+ init_std=init_std,
351
+ )
352
+
353
+ # z_L initial state (learned buffer)
354
+ self.zL_init = nn.Parameter(torch.zeros(1, 1, hidden_size))
355
+
356
+ # Rotary embeddings (NTK-scaled for 128k)
357
+ self.rotary = RotaryEmbedding(
358
+ dim=head_dim,
359
+ max_seq_len=max_seq_len,
360
+ base=rope_base,
361
+ scaling_factor=rope_scaling_factor,
362
+ )
363
+
364
+ # Recurrent modules
365
+ self.L_module = RecurrentModule(
366
+ layers=n_layers_L,
367
+ hidden_size=hidden_size,
368
+ num_heads=num_heads,
369
+ head_dim=head_dim,
370
+ intermediate_size=intermediate_size,
371
+ init_std=init_std,
372
+ use_checkpoint=use_gradient_checkpointing,
373
+ )
374
+ self.H_module = RecurrentModule(
375
+ layers=n_layers_H,
376
+ hidden_size=hidden_size,
377
+ num_heads=num_heads,
378
+ head_dim=head_dim,
379
+ intermediate_size=intermediate_size,
380
+ init_std=init_std,
381
+ use_checkpoint=use_gradient_checkpointing,
382
+ )
383
+
384
+ # Classification head (on last token of z_H)
385
+ self.classifier = nn.Sequential(
386
+ nn.LayerNorm(hidden_size),
387
+ nn.Linear(hidden_size, num_classes),
388
+ )
389
+
390
+ self._init_weights()
391
+
392
+ def _init_weights(self):
393
+ # zL_init small
394
+ nn.init.zeros_(self.zL_init)
395
+ # Classifier init
396
+ for layer in self.classifier:
397
+ if isinstance(layer, nn.Linear):
398
+ trunc_normal_init_(layer.weight, std=0.02)
399
+ if layer.bias is not None:
400
+ nn.init.zeros_(layer.bias)
401
+
402
+ def _get_bp_steps(self, training_step_ratio=1.0):
403
+ """Compute number of backprop steps (warmup from bp_min to bp_max)."""
404
+ if training_step_ratio >= 1.0:
405
+ return min(self.bp_max_steps, self.total_steps)
406
+ warmup_progress = min(1.0, training_step_ratio / self.bp_warmup_ratio)
407
+ bp = self.bp_min_steps + warmup_progress * (self.bp_max_steps - self.bp_min_steps)
408
+ return min(int(bp), self.total_steps)
409
+
410
+ def forward(self, input_ids, attention_mask=None, labels=None, training_step_ratio=None):
411
+ """
412
+ Args:
413
+ input_ids: [B, L] byte token IDs
414
+ attention_mask: [B, L] 1=valid, 0=padding
415
+ labels: [B] binary labels (0=safe, 1=injection)
416
+ training_step_ratio: float in [0, 1] for BP warmup scheduling
417
+ Returns:
418
+ dict with logits and optional loss
419
+ """
420
+ B, L = input_ids.shape
421
+ device = input_ids.device
422
+
423
+ if attention_mask is None:
424
+ attention_mask = (input_ids != 0).long()
425
+
426
+ # Position IDs
427
+ position_ids = torch.arange(L, device=device).unsqueeze(0).expand(B, -1)
428
+ # Apply attention mask to position IDs (positions after padding clamped)
429
+ position_ids = position_ids * attention_mask
430
+
431
+ # ── Embedding ──
432
+ z_H = self.embed(input_ids) # [B, L, D]
433
+ z_L = self.zL_init.expand(B, L, -1) # [B, L, D]
434
+
435
+ # ── RoPE ──
436
+ cos, sin = self.rotary(position_ids) # [B, L, head_dim]
437
+
438
+ # ── Attention: use is_causal=True with flash attention ──
439
+ # At 128k context, explicit attention masks are prohibitively large
440
+ # (17B elements = 34GB). Causal flash attention uses O(L) memory.
441
+ # Padding tokens after the sequence naturally don't affect the
442
+ # last valid token's representation under causal masking.
443
+
444
+ # ── BP warmup ──
445
+ bp_steps = self._get_bp_steps(training_step_ratio or 1.0)
446
+ H_bp = min(self.H_cycles, max(1, bp_steps - 1))
447
+ L_bp = max(0, bp_steps - H_bp)
448
+ # Map: last L_bp L-steps across all cycles, last H_bp H-steps
449
+ # Each H-cycle has L_cycles L-steps inside it
450
+ total_L_steps = self.H_cycles * self.L_cycles
451
+ total_H_steps = self.H_cycles
452
+
453
+ # ── Recurrent loop ──
454
+ # BP warmup: block gradient flow through early steps via .detach()
455
+ # instead of torch.set_grad_enabled (incompatible with torch.compile)
456
+ step_idx = 0
457
+ for i in range(self.H_cycles):
458
+ for k in range(self.L_cycles):
459
+ grad_enabled = (step_idx >= total_L_steps - L_bp)
460
+ z_L = self.L_module(
461
+ z_L if grad_enabled else z_L.detach(),
462
+ z_H if grad_enabled else z_H.detach(),
463
+ cos, sin,
464
+ )
465
+ step_idx += 1
466
+
467
+ H_grad_enabled = (i >= self.H_cycles - H_bp)
468
+ z_H = self.H_module(
469
+ z_H if H_grad_enabled else z_H.detach(),
470
+ z_L if H_grad_enabled else z_L.detach(),
471
+ cos, sin,
472
+ )
473
+
474
+ # ── Classification: pool from the last valid token of each sequence ──
475
+ # Use last-token pooling: grab the final non-padding token's representation
476
+ # Find last valid position
477
+ seq_lengths = attention_mask.sum(dim=1).long() # [B]
478
+ last_token_indices = (seq_lengths - 1).clamp(min=0) # [B]
479
+ batch_indices = torch.arange(B, device=device)
480
+ pooled = z_H[batch_indices, last_token_indices, :] # [B, D]
481
+
482
+ logits = self.classifier(pooled) # [B, num_classes]
483
+
484
+ loss = None
485
+ if labels is not None:
486
+ loss = F.cross_entropy(logits, labels)
487
+
488
+ return {"logits": logits, "loss": loss}
489
+
490
+ @torch.no_grad()
491
+ def inference(self, input_ids, attention_mask=None):
492
+ """Inference-only forward (all steps with gradients disabled)."""
493
+ self.eval()
494
+ B, L = input_ids.shape
495
+ device = input_ids.device
496
+
497
+ if attention_mask is None:
498
+ attention_mask = (input_ids != 0).long()
499
+
500
+ position_ids = torch.arange(L, device=device).unsqueeze(0).expand(B, -1)
501
+ position_ids = position_ids * attention_mask
502
+
503
+ z_H = self.embed(input_ids)
504
+ z_L = self.zL_init.expand(B, L, -1)
505
+ cos, sin = self.rotary(position_ids)
506
+
507
+ for _ in range(self.H_cycles):
508
+ for _ in range(self.L_cycles):
509
+ z_L = self.L_module(z_L, z_H, cos, sin)
510
+ z_H = self.H_module(z_H, z_L, cos, sin)
511
+
512
+ seq_lengths = attention_mask.sum(dim=1).long()
513
+ last_token_indices = (seq_lengths - 1).clamp(min=0)
514
+ pooled = z_H[torch.arange(B, device=device), last_token_indices, :]
515
+ logits = self.classifier(pooled)
516
+ return logits
517
+
518
+
519
+ # ═══════════════════════════════════════════════════════════════════════════════
520
+ # Data pipeline — Bordair multimodal loader
521
+ # ═══════════════════════════════════════════════════════════════════════════════
522
+
523
+ def load_bordair_multimodal(cache_dir=None, max_samples=None):
524
+ """Load the full Bordair multimodal dataset from HF Hub.
525
+
526
+ The dataset is stored as raw JSON arrays (not HF Dataset format).
527
+ We snapshot_download the repo and read all JSON files manually.
528
+
529
+ Returns:
530
+ Dataset with columns: "text" (concatenated modalities), "label" (0/1)
531
+ """
532
+ print("📦 Downloading Bordair/bordair-multimodal from HF Hub...")
533
+ path = snapshot_download(
534
+ repo_id="Bordair/bordair-multimodal",
535
+ repo_type="dataset",
536
+ cache_dir=cache_dir,
537
+ )
538
+ print(f" Downloaded to: {path}")
539
+
540
+ all_samples = []
541
+
542
+ # Pattern: collect all JSON files, skip summary/pool metadata
543
+ dir_patterns = [
544
+ "benign/*.json",
545
+ "payloads/*/*.json",
546
+ "payloads_v5/*.json",
547
+ "payloads_v5_external/*/*.json",
548
+ ]
549
+
550
+ for pattern in dir_patterns:
551
+ files = sorted(glob.glob(os.path.join(path, pattern)))
552
+ for f in files:
553
+ fname = os.path.basename(f)
554
+ if fname in ("summary.json", "_pool.json", "summary_old.json"):
555
+ continue
556
+
557
+ try:
558
+ with open(f, "r", encoding="utf-8") as fh:
559
+ data = json.load(fh)
560
+ except (json.JSONDecodeError, UnicodeDecodeError) as e:
561
+ print(f" ⚠️ Skipping {f}: {e}")
562
+ continue
563
+
564
+ if not isinstance(data, list):
565
+ continue
566
+
567
+ for item in data:
568
+ if not isinstance(item, dict):
569
+ continue
570
+ all_samples.append(item)
571
+
572
+ print(f" {pattern}: {len(all_samples)} cumulative")
573
+
574
+ print(f"\n✅ Total raw samples loaded: {len(all_samples)}")
575
+
576
+ # Build unified dataset
577
+ # Concatenate all text fields: text + image_content + document_content + audio_content
578
+ rows = []
579
+ labels = []
580
+ skipped = 0
581
+ for item in all_samples:
582
+ # Get expected_detection (the boolean label)
583
+ label_val = item.get("expected_detection")
584
+ if label_val is None:
585
+ skipped += 1
586
+ continue
587
+
588
+ text_parts = []
589
+ if item.get("text"):
590
+ text_parts.append(item["text"])
591
+ if item.get("image_content"):
592
+ text_parts.append(item["image_content"])
593
+ if item.get("document_content"):
594
+ text_parts.append(item["document_content"])
595
+ if item.get("audio_content"):
596
+ text_parts.append(item["audio_content"])
597
+
598
+ combined = "\n".join(text_parts)
599
+
600
+ # Skip empty texts
601
+ if not combined.strip():
602
+ skipped += 1
603
+ continue
604
+
605
+ rows.append(combined)
606
+ labels.append(1 if label_val else 0)
607
+
608
+ if skipped:
609
+ print(f" Skipped {skipped} samples (missing label or empty text)")
610
+
611
+ # Convert to HF Dataset
612
+ ds = Dataset.from_dict({"text": rows, "label": labels})
613
+ print(f"✅ Dataset: {len(ds)} samples ({sum(labels)} injection, {len(labels) - sum(labels)} safe)")
614
+ return ds
615
+
616
+
617
+ def normalize_label(ex, label_col):
618
+ """Convert label to int64 0/1 (for compatibility with other datasets)."""
619
+ val = ex[label_col]
620
+ if isinstance(val, str):
621
+ return {label_col: 1 if val.lower() in ("malicious", "injection", "yes", "1") else 0}
622
+ return {label_col: int(val)}
623
+
624
+
625
+ # ═══════════════════════════════════════════════════════════════════════════════
626
+ # Byte-level tokenizer
627
+ # ═══════════════════════════════════════════════════════════════════════════════
628
+
629
+ class ByteTokenizer:
630
+ """Byte-level tokenizer: encodes strings as byte IDs [0-255].
631
+
632
+ Supports variable-length sequences — padded in collation, not here.
633
+ """
634
+
635
+ def __init__(self, max_length=131072):
636
+ self.max_length = max_length
637
+ self.pad_token_id = 0
638
+ self.eos_token_id = 0
639
+ self.pad_token = "<pad>"
640
+ self.eos_token = "<pad>"
641
+ self.vocab_size = 256
642
+
643
+ def __call__(self, text, truncation=True, max_length=None):
644
+ max_len = max_length or self.max_length
645
+ if isinstance(text, str):
646
+ byte_ids = list(text.encode("utf-8", errors="replace"))
647
+ else:
648
+ byte_ids = []
649
+ if truncation:
650
+ byte_ids = byte_ids[:max_len]
651
+ return byte_ids
652
+
653
+ def encode_batch(self, texts, max_length=None):
654
+ """Encode a batch of texts into variable-length byte ID lists."""
655
+ max_len = max_length or self.max_length
656
+ result = []
657
+ for text in texts:
658
+ if isinstance(text, str):
659
+ byte_ids = list(text.encode("utf-8", errors="replace"))
660
+ else:
661
+ byte_ids = []
662
+ if max_len:
663
+ byte_ids = byte_ids[:max_len]
664
+ result.append(byte_ids)
665
+ return result
666
+
667
+ def __len__(self):
668
+ return self.vocab_size
669
+
670
+
671
+ def collate_hrm_text(batch, max_length=131072):
672
+ """Collation for HRM-Text: variable-length byte sequences with padding.
673
+
674
+ Returns dict with input_ids, attention_mask, labels.
675
+ Sequences are padded to the max length in the batch (not to max_length).
676
+ """
677
+ texts = [ex["text"] for ex in batch]
678
+ labels = torch.tensor([ex["label"] for ex in batch], dtype=torch.long)
679
+
680
+ # Encode to byte IDs
681
+ all_ids = []
682
+ for t in texts:
683
+ ids = list(t.encode("utf-8", errors="replace")[:max_length])
684
+ all_ids.append(ids)
685
+
686
+ max_len_in_batch = max(len(ids) for ids in all_ids) if all_ids else 0
687
+ # Clamp to prevent excessive padding
688
+ max_len_in_batch = min(max_len_in_batch, max_length)
689
+
690
+ all_ids_padded = []
691
+ attention_masks = []
692
+ for ids in all_ids:
693
+ length = min(len(ids), max_length)
694
+ ids = ids[:length]
695
+ padded = ids + [0] * (max_len_in_batch - length)
696
+ mask = [1] * length + [0] * (max_len_in_batch - length)
697
+ all_ids_padded.append(padded)
698
+ attention_masks.append(mask)
699
+
700
+ return {
701
+ "input_ids": torch.tensor(all_ids_padded, dtype=torch.long),
702
+ "attention_mask": torch.tensor(attention_masks, dtype=torch.long),
703
+ "labels": labels,
704
+ }
705
+
706
+
707
+ # ═══════════════════════════════════════════════════════════════════════════════
708
+ # Custom Trainer
709
+ # ═══════════════════════════════════════════════════════════════════════════════
710
+
711
+ class HrmTextTrainer(Trainer):
712
+ """Trainer subclass that handles HRM-Text's custom forward signature
713
+ and BP warmup scheduling."""
714
+
715
+ def __init__(self, *args, total_training_steps=None, **kwargs):
716
+ super().__init__(*args, **kwargs)
717
+ self.total_training_steps = total_training_steps
718
+ self._current_step = 0
719
+
720
+ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
721
+ labels = inputs.pop("labels")
722
+ # Compute training step ratio for BP warmup
723
+ if self.total_training_steps and self.total_training_steps > 0:
724
+ step_ratio = min(1.0, self._current_step / self.total_training_steps)
725
+ else:
726
+ step_ratio = 1.0
727
+
728
+ outputs = model(**inputs, labels=labels, training_step_ratio=step_ratio)
729
+ loss = outputs["loss"]
730
+ self._current_step += 1
731
+ return (loss, outputs) if return_outputs else loss
732
+
733
+ def prediction_step(self, model, inputs, prediction_loss_only=False, ignore_keys=None):
734
+ labels = inputs.pop("labels") if "labels" in inputs else None
735
+ with torch.no_grad():
736
+ logits = model.inference(**inputs)
737
+ if prediction_loss_only:
738
+ return (None, None, labels)
739
+ return (None, logits, labels)
740
+
741
+
742
+ # ═══════════════════════════════════════════════════════════════════════════════
743
+ # Parameter counting
744
+ # ═══════════════════════════════════════════════════════════════════════════════
745
+
746
+ def count_params(model):
747
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
748
+
749
+
750
+ # ═══════════════════════════════════════════════════════════════════════════════
751
+ # Main
752
+ # ═══════════════════════════════════════════════════════════════════════════════
753
+
754
+ def main():
755
+ parser = argparse.ArgumentParser(description="Train HRM-Text prompt injection detector")
756
+ parser.add_argument("--test", action="store_true", help="Smoke test on 64 samples")
757
+ parser.add_argument("--lr", type=float, default=5e-4)
758
+ parser.add_argument("--epochs", type=int, default=3)
759
+ parser.add_argument("--batch_size", type=int, default=32)
760
+ parser.add_argument("--output_dir", type=str, default="./pi-hrm-text")
761
+ parser.add_argument("--cpu", action="store_true")
762
+ parser.add_argument("--max_length", type=int, default=2048)
763
+ parser.add_argument("--hidden_size", type=int, default=768)
764
+ parser.add_argument("--num_heads", type=int, default=12)
765
+ parser.add_argument("--head_dim", type=int, default=64)
766
+ parser.add_argument("--n_layers_H", type=int, default=3)
767
+ parser.add_argument("--n_layers_L", type=int, default=3)
768
+ parser.add_argument("--intermediate_size", type=int, default=2048)
769
+ parser.add_argument("--H_cycles", type=int, default=2)
770
+ parser.add_argument("--L_cycles", type=int, default=3)
771
+ parser.add_argument("--rope_base", type=float, default=10000.0)
772
+ parser.add_argument("--rope_scaling", type=float, default=1.0)
773
+ parser.add_argument("--bp_min", type=int, default=2)
774
+ parser.add_argument("--bp_max", type=int, default=5)
775
+ parser.add_argument("--push_to_hub", type=str, default="av-codes/prompt-injection-hrm-text")
776
+ parser.add_argument("--hub_token", type=str, default=None)
777
+ parser.add_argument("--gradient_checkpointing", action="store_true", default=True)
778
+ parser.add_argument("--no_gradient_checkpointing", action="store_false", dest="gradient_checkpointing")
779
+ parser.add_argument("--seed", type=int, default=42)
780
+ parser.add_argument("--data_cache", type=str, default=None,
781
+ help="Cache dir for dataset download")
782
+ parser.add_argument("--max_steps", type=int, default=-1,
783
+ help="Max training steps (-1 = use epochs)")
784
+ args = parser.parse_args()
785
+
786
+ set_seed(args.seed)
787
+ use_cuda = torch.cuda.is_available() and not args.cpu
788
+ device = torch.device("cuda" if use_cuda else "cpu")
789
+ print(f"🖥️ Hardware: {'GPU' if use_cuda else 'CPU'}")
790
+ if use_cuda:
791
+ print(f" Device: {torch.cuda.get_device_name(0)}")
792
+ print(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB")
793
+ print(f"📐 HRM-Text(H={args.n_layers_H}, L={args.n_layers_L}, "
794
+ f"d={args.hidden_size}, H_cycles={args.H_cycles}, L_cycles={args.L_cycles})")
795
+ print(f"📏 Max context: {args.max_length:,} tokens")
796
+ print(f"🔄 BP warmup: {args.bp_min}→{args.bp_max} steps")
797
+
798
+ # ── Load dataset ──────────────────────────────────────────────────────
799
+ print("\n📦 Loading Bordair multimodal dataset...")
800
+ # For test mode, load a small subset directly
801
+ if args.test:
802
+ # Small test with subset
803
+ all_parts = []
804
+ # Try loading just a few files for testing
805
+ test_path = args.data_cache or "/tmp/bordair_test"
806
+ if not os.path.exists(test_path):
807
+ snapshot_download(
808
+ repo_id="Bordair/bordair-multimodal",
809
+ repo_type="dataset",
810
+ cache_dir=args.data_cache,
811
+ local_dir=test_path,
812
+ local_dir_use_symlinks=False,
813
+ allow_patterns=["benign/text_only.json", "benign/multimodal_text_image.json",
814
+ "payloads/text_image/text_image_001.json"],
815
+ )
816
+ for f in glob.glob(f"{test_path}/**/*.json", recursive=True):
817
+ fname = os.path.basename(f)
818
+ if fname in ("summary.json", "_pool.json"):
819
+ continue
820
+ with open(f) as fh:
821
+ data = json.load(fh)
822
+ if isinstance(data, list):
823
+ for item in data:
824
+ if isinstance(item, dict) and item.get("expected_detection") is not None:
825
+ text_parts = [item.get("text", "")]
826
+ for k in ("image_content", "document_content", "audio_content"):
827
+ if item.get(k):
828
+ text_parts.append(item[k])
829
+ all_parts.append({
830
+ "text": "\n".join(text_parts),
831
+ "label": 1 if item["expected_detection"] else 0,
832
+ })
833
+ merged = Dataset.from_list(all_parts)
834
+ print(f" Test mode: {len(merged)} samples loaded")
835
+ else:
836
+ merged = load_bordair_multimodal(cache_dir=args.data_cache)
837
+
838
+ # ── Stratified 90/10 split ────────────────────────────────────────────
839
+ # Cast label to ClassLabel first
840
+ merged = merged.cast_column("label", hf_datasets.ClassLabel(names=["safe", "injection"]))
841
+
842
+ if args.test:
843
+ train_dataset = merged.select(range(min(64, len(merged))))
844
+ eval_dataset = merged.select(range(min(32, len(merged))))
845
+ else:
846
+ split = merged.train_test_split(
847
+ test_size=0.1, seed=args.seed, stratify_by_column="label",
848
+ )
849
+ train_dataset = split["train"]
850
+ eval_dataset = split["test"]
851
+
852
+ print(f"\n✅ Dataset: {len(merged)} total → {len(train_dataset)} train, {len(eval_dataset)} eval")
853
+ train_dist = Counter(train_dataset["label"])
854
+ eval_dist = Counter(eval_dataset["label"])
855
+ print(f" Train label dist: {dict(train_dist)}")
856
+ print(f" Eval label dist: {dict(eval_dist)}")
857
+
858
+ # ── Log token length statistics for context planning ──────────────────
859
+ train_lengths = [len(t.encode("utf-8", errors="replace")) for t in train_dataset["text"]]
860
+ print(f" Train text length stats: mean={np.mean(train_lengths):.0f}, "
861
+ f"median={np.median(train_lengths):.0f}, "
862
+ f"p95={np.percentile(train_lengths, 95):.0f}, "
863
+ f"max={max(train_lengths):,}")
864
+
865
+ # ── Build model ───────────────────────────────────────────────────────
866
+ model = HrmTextClassifier(
867
+ vocab_size=256,
868
+ hidden_size=args.hidden_size,
869
+ num_heads=args.num_heads,
870
+ head_dim=args.head_dim,
871
+ n_layers_H=args.n_layers_H,
872
+ n_layers_L=args.n_layers_L,
873
+ intermediate_size=args.intermediate_size,
874
+ H_cycles=args.H_cycles,
875
+ L_cycles=args.L_cycles,
876
+ max_seq_len=args.max_length,
877
+ rope_base=args.rope_base,
878
+ rope_scaling_factor=args.rope_scaling,
879
+ num_classes=2,
880
+ bp_min_steps=args.bp_min,
881
+ bp_max_steps=args.bp_max,
882
+ use_gradient_checkpointing=args.gradient_checkpointing,
883
+ )
884
+ param_count = count_params(model)
885
+ print(f"\n🧮 Model parameters: {param_count:,}")
886
+ if args.gradient_checkpointing:
887
+ print(" Gradient checkpointing: enabled")
888
+ if not args.test:
889
+ assert 15_000_000 <= param_count <= 55_000_000, \
890
+ f"Param count {param_count:,} outside target range [15M, 55M]"
891
+
892
+ if use_cuda:
893
+ model = model.cuda()
894
+
895
+ # ── Metrics ───────────────────────────────────────────────────────────
896
+ accuracy = evaluate.load("accuracy")
897
+ precision = evaluate.load("precision")
898
+ recall = evaluate.load("recall")
899
+ f1 = evaluate.load("f1")
900
+
901
+ def compute_metrics(eval_pred):
902
+ predictions, labels = eval_pred
903
+ preds = predictions.argmax(-1)
904
+ return {
905
+ "accuracy": accuracy.compute(predictions=preds, references=labels)["accuracy"],
906
+ "precision": precision.compute(predictions=preds, references=labels, average="binary")["precision"],
907
+ "recall": recall.compute(predictions=preds, references=labels, average="binary")["recall"],
908
+ "f1": f1.compute(predictions=preds, references=labels, average="binary")["f1"],
909
+ }
910
+
911
+ # ── Estimate total training steps for BP warmup ───────────────────────
912
+ steps_per_epoch = max(1, len(train_dataset) // args.batch_size)
913
+ total_training_steps = steps_per_epoch * args.epochs
914
+ print(f"📊 Steps per epoch: {steps_per_epoch}, total: {total_training_steps}")
915
+
916
+ # ── Training args ────────────────────────────────���────────────────────
917
+ run_name = f"hrm-text-pi_d{args.hidden_size}_lr{args.lr}_ep{args.epochs}_bs{args.batch_size}"
918
+
919
+ training_args = TrainingArguments(
920
+ output_dir=args.output_dir,
921
+ run_name=run_name,
922
+ report_to="none",
923
+ learning_rate=args.lr,
924
+ per_device_train_batch_size=args.batch_size,
925
+ per_device_eval_batch_size=min(args.batch_size * 2, 16),
926
+ num_train_epochs=args.epochs,
927
+ max_steps=args.max_steps,
928
+ weight_decay=0.01,
929
+ warmup_steps=500 if not args.test else 0,
930
+ lr_scheduler_type="cosine",
931
+ eval_strategy="epoch",
932
+ save_strategy="epoch",
933
+ load_best_model_at_end=True,
934
+ metric_for_best_model="f1",
935
+ greater_is_better=True,
936
+ save_total_limit=2,
937
+ logging_strategy="steps",
938
+ logging_first_step=True,
939
+ logging_steps=5 if args.test else 20,
940
+ disable_tqdm=False if args.test else True,
941
+ fp16=use_cuda,
942
+ bf16=False,
943
+ push_to_hub=False,
944
+ hub_model_id=None,
945
+ use_cpu=not use_cuda,
946
+ dataloader_num_workers=4,
947
+ seed=args.seed,
948
+ save_only_model=True,
949
+ remove_unused_columns=False,
950
+ ddp_find_unused_parameters=True,
951
+ gradient_checkpointing=False,
952
+ )
953
+
954
+ # ── Data collator ─────────────────────────────────────────────────────
955
+ def collate_fn(batch):
956
+ return collate_hrm_text(batch, max_length=args.max_length)
957
+
958
+ # ── Trainer ───────────────────────────────────────────────────────────
959
+ trainer = HrmTextTrainer(
960
+ model=model,
961
+ args=training_args,
962
+ train_dataset=train_dataset,
963
+ eval_dataset=eval_dataset,
964
+ data_collator=collate_fn,
965
+ compute_metrics=compute_metrics,
966
+ total_training_steps=total_training_steps,
967
+ )
968
+
969
+ # ── Train ─────────────────────────────────────────────────────────────
970
+ print("\n🚀 Training...")
971
+ train_start = time.time()
972
+ trainer.train()
973
+ train_elapsed = time.time() - train_start
974
+ print(f"✅ Training complete! ({train_elapsed:.1f}s)")
975
+ print(f" Best checkpoint: {trainer.state.best_model_checkpoint}")
976
+
977
+ # ── Final evaluation ──────────────────────────────────────────────────
978
+ print("\n📊 Evaluating on eval set...")
979
+ eval_metrics = trainer.evaluate(eval_dataset)
980
+ print(f" Eval metrics: {json.dumps(eval_metrics, indent=2)}")
981
+
982
+ os.makedirs(args.output_dir, exist_ok=True)
983
+ eval_path = os.path.join(args.output_dir, "eval_metrics.json")
984
+ with open(eval_path, "w") as f:
985
+ json.dump(eval_metrics, f, indent=2)
986
+
987
+ # ── Save model locally ────────────────────────────────────────────────
988
+ best_model_path = os.path.join(args.output_dir, "best_model")
989
+ os.makedirs(best_model_path, exist_ok=True)
990
+ model_path = os.path.join(best_model_path, "hrm_text_model.pt")
991
+
992
+ # Try loading best checkpoint first
993
+ best_ckpt = trainer.state.best_model_checkpoint
994
+ if best_ckpt and os.path.isdir(best_ckpt):
995
+ print(f"\n💾 Loading best checkpoint from {best_ckpt}")
996
+ # The checkpoint is saved by trainer; load state dict from it
997
+ # The model in trainer might have DDP wrappers, get unwrapped
998
+ best_model = trainer.model
999
+ if hasattr(best_model, "module"):
1000
+ best_model = best_model.module
1001
+ torch.save(best_model.state_dict(), model_path)
1002
+ else:
1003
+ # Save final model
1004
+ best_model = trainer.model
1005
+ if hasattr(best_model, "module"):
1006
+ best_model = best_model.module
1007
+ torch.save(best_model.state_dict(), model_path)
1008
+ print(f"\n💾 Saved final model weights to {model_path}")
1009
+
1010
+ # Save config
1011
+ config = {
1012
+ "architecture": "HRM-Text (classification)",
1013
+ "reference": "sapientinc/HRM-Text, arXiv:2506.21734",
1014
+ "hidden_size": args.hidden_size,
1015
+ "num_heads": args.num_heads,
1016
+ "head_dim": args.head_dim,
1017
+ "n_layers_H": args.n_layers_H,
1018
+ "n_layers_L": args.n_layers_L,
1019
+ "intermediate_size": args.intermediate_size,
1020
+ "H_cycles": args.H_cycles,
1021
+ "L_cycles": args.L_cycles,
1022
+ "max_seq_len": args.max_length,
1023
+ "rope_base": args.rope_base,
1024
+ "rope_scaling": args.rope_scaling,
1025
+ "bp_min_steps": args.bp_min,
1026
+ "bp_max_steps": args.bp_max,
1027
+ "vocab_size": 256,
1028
+ "param_count": param_count,
1029
+ "id2label": {0: "safe", 1: "injection"},
1030
+ "label2id": {"safe": 0, "injection": 1},
1031
+ "training": {
1032
+ "learning_rate": args.lr,
1033
+ "epochs": args.epochs,
1034
+ "batch_size": args.batch_size,
1035
+ "weight_decay": 0.01,
1036
+ "scheduler": "cosine",
1037
+ "warmup_steps": 500 if not args.test else 0,
1038
+ },
1039
+ }
1040
+ with open(os.path.join(best_model_path, "config.json"), "w") as f:
1041
+ json.dump(config, f, indent=2)
1042
+ print(f" Saved config to {best_model_path}/config.json")
1043
+
1044
+ # ── Push to Hub ───────────────────────────────────────────────────────
1045
+ if args.push_to_hub:
1046
+ hub_model_id = args.push_to_hub
1047
+ api = HfApi(token=args.hub_token)
1048
+
1049
+ print(f"\n☁️ Pushing to Hub: {hub_model_id}")
1050
+
1051
+ # Create repo if needed
1052
+ try:
1053
+ api.create_repo(repo_id=hub_model_id, repo_type="model", private=False, exist_ok=True)
1054
+ print(f" Repo ready: {hub_model_id}")
1055
+ except Exception as e:
1056
+ print(f" ⚠️ Could not create repo: {e}")
1057
+
1058
+ # Upload model weights
1059
+ api.upload_file(
1060
+ path_or_fileobj=model_path,
1061
+ path_in_repo="pytorch_model.bin",
1062
+ repo_id=hub_model_id,
1063
+ repo_type="model",
1064
+ commit_message=f"HRM-Text prompt injection detector — F1={eval_metrics.get('eval_f1', 0):.4f}",
1065
+ )
1066
+
1067
+ # Upload config
1068
+ api.upload_file(
1069
+ path_or_fileobj=os.path.join(best_model_path, "config.json"),
1070
+ path_in_repo="config.json",
1071
+ repo_id=hub_model_id,
1072
+ repo_type="model",
1073
+ commit_message="Add model config",
1074
+ )
1075
+
1076
+ # Upload metrics
1077
+ api.upload_file(
1078
+ path_or_fileobj=eval_path,
1079
+ path_in_repo="eval_metrics.json",
1080
+ repo_id=hub_model_id,
1081
+ repo_type="model",
1082
+ commit_message="Add evaluation metrics",
1083
+ )
1084
+
1085
+ # Upload the training script
1086
+ script_path = os.path.abspath(__file__) if "__file__" in dir() else None
1087
+ if script_path and os.path.exists(script_path):
1088
+ api.upload_file(
1089
+ path_or_fileobj=script_path,
1090
+ path_in_repo="train_hrm_text_pi.py",
1091
+ repo_id=hub_model_id,
1092
+ repo_type="model",
1093
+ commit_message="Add training script",
1094
+ )
1095
+
1096
+ # Upload a README
1097
+ readme = f"""---
1098
+ license: mit
1099
+ tags:
1100
+ - prompt-injection
1101
+ - hrm-text
1102
+ - hierarchical-reasoning-model
1103
+ - bordair-multimodal
1104
+ - security
1105
+ ---
1106
+
1107
+ # HRM-Text Prompt Injection Detector
1108
+
1109
+ **Parameters:** {param_count:,}
1110
+ **Architecture:** HRM-Text (classification port) | d={args.hidden_size}, H={args.n_layers_H}, L={args.n_layers_L}, cycles={args.H_cycles}×{args.L_cycles}
1111
+ **Context window:** {args.max_length:,} tokens (NTK-scaled RoPE)
1112
+ **Training data:** Bordair/bordair-multimodal (503K samples, balanced 1:1)
1113
+
1114
+ Evaluation on stratified 10% holdout:
1115
+
1116
+ | Metric | Value |
1117
+ |--------|-------|
1118
+ | Accuracy | {eval_metrics.get('eval_accuracy', 0):.4f} |
1119
+ | Precision | {eval_metrics.get('eval_precision', 0):.4f} |
1120
+ | Recall | {eval_metrics.get('eval_recall', 0):.4f} |
1121
+ | F1 | {eval_metrics.get('eval_f1', 0):.4f} |
1122
+
1123
+ ## Architecture
1124
+
1125
+ HRM-Text (arXiv:2506.21734) with a classification head. The model uses a recurrent cascade of two transformer modules (H and L) that exchange information across cycles:
1126
+
1127
+ - **L module** ({args.n_layers_L} layers, low-level): processes detailed token patterns
1128
+ - **H module** ({args.n_layers_H} layers, high-level): integrates across cycles
1129
+ - **Recurrence**: {args.L_cycles} L-steps per H-cycle, {args.H_cycles} H-cycles total = {args.H_cycles * args.L_cycles} recurrent passes
1130
+ - **Classification**: last-token pooling + LayerNorm + Linear(2)
1131
+
1132
+ The byte-level tokenizer (vocab 256) handles any text encoding. RoPE uses NTK-aware scaling (θ={args.rope_base}, factor={args.rope_scaling}) for {args.max_length:,}-token context.
1133
+
1134
+ ## Usage
1135
+ ```python
1136
+ import torch
1137
+ from train_hrm_text_pi import HrmTextClassifier
1138
+
1139
+ model = HrmTextClassifier(
1140
+ hidden_size={args.hidden_size},
1141
+ num_heads={args.num_heads},
1142
+ head_dim={args.head_dim},
1143
+ n_layers_H={args.n_layers_H},
1144
+ n_layers_L={args.n_layers_L},
1145
+ )
1146
+ state_dict = torch.load("pytorch_model.bin", map_location="cpu")
1147
+ # Remove DDP wrapper keys if present
1148
+ state_dict = {{k.replace('module.', ''): v for k, v in state_dict.items()}}
1149
+ model.load_state_dict(state_dict)
1150
+ model.eval()
1151
+
1152
+ def detect(text, max_length=131072):
1153
+ byte_ids = list(text.encode("utf-8", errors="replace")[:max_length])
1154
+ input_ids = torch.tensor([byte_ids])
1155
+ attention_mask = torch.ones_like(input_ids)
1156
+ logits = model.inference(input_ids, attention_mask)
1157
+ pred = logits.argmax(-1).item() # 0=safe, 1=injection
1158
+ return pred
1159
+ ```
1160
+ """
1161
+ api.upload_file(
1162
+ path_or_fileobj=readme.encode(),
1163
+ path_in_repo="README.md",
1164
+ repo_id=hub_model_id,
1165
+ repo_type="model",
1166
+ commit_message="Add README",
1167
+ )
1168
+
1169
+ print(f"✅ https://huggingface.co/{hub_model_id}")
1170
+
1171
+ print("\n✅ Done!")
1172
+
1173
+
1174
+ if __name__ == "__main__":
1175
+ from multiprocessing import freeze_support
1176
+ freeze_support()
1177
+ main()