| |
| """ |
| train_hrm_text_pi.py — Train an HRM-Text prompt injection detector |
| on the Bordair multimodal dataset with ~128k context support. |
| |
| Architecture follows HRM-Text (sapientinc/HRM-Text, arXiv:2506.21734): |
| - ScaledEmbeddingInit: byte-level embedding with lecun scaling |
| - H module: high-level transformer stack (recurrent) |
| - L module: low-level transformer stack (recurrent) |
| - Recurrent loop: H_cycles × (L_cycles × L → H) cascade |
| - Classification head on final z_H (last-token pooling) |
| - RoPE with optional NTK-aware scaling (script defaults: --max_length 2048 with scaling off; raise --max_length / --rope_scaling for contexts up to ~128k) |
| - Backprop warmup: 2→5 recurrent steps over first 20% of training |
| |
| Data: Bordair/bordair-multimodal (503K samples, balanced 1:1) |
| Target: ~36M parameters |
| """ |
| import os |
| os.environ.setdefault("PYTORCH_HIP_ALLOC_CONF", "expandable_segments:True") |
|
|
| import argparse |
| import json |
| import math |
| import glob |
| import time |
| from collections import Counter |
|
|
| import datasets as hf_datasets |
| import evaluate |
| import numpy as np |
| import torch |
| torch.set_float32_matmul_precision('high') |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.checkpoint import checkpoint as torch_checkpoint |
| from datasets import Dataset, concatenate_datasets |
| from huggingface_hub import snapshot_download, HfApi |
| from torch.utils.data import DataLoader |
| from transformers import ( |
| PreTrainedTokenizerFast, |
| Trainer, |
| TrainingArguments, |
| set_seed, |
| ) |
|
|
| |
| |
| |
|
|
class RotaryEmbedding(nn.Module):
    """Precomputed cos/sin lookup tables for rotary position embeddings.

    With ``scaling_factor > 1`` the frequency base is enlarged NTK-style to
    ``base * scaling_factor ** (dim / (dim - 2))``; otherwise ``base`` is used
    unchanged. Tables covering positions ``0 .. max_seq_len - 1`` are built
    once in the constructor and registered as non-persistent buffers, so they
    move with the module but are not written to ``state_dict``.
    """

    def __init__(self, dim, max_seq_len=4096, base=10000.0, scaling_factor=32.0):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base
        self.scaling_factor = scaling_factor

        # NTK-aware scaling only applies for factors strictly above 1.
        effective_base = base
        if scaling_factor > 1.0:
            effective_base = base * scaling_factor ** (dim / (dim - 2))

        # One inverse frequency per pair of channels (dim // 2 in total).
        exponents = torch.arange(0, dim, 2, dtype=torch.float32) / dim
        inv_freq = 1.0 / (effective_base ** exponents)

        positions = torch.arange(max_seq_len, dtype=torch.float32)
        angles = torch.outer(positions, inv_freq)
        # Duplicate the angles so the table's last dimension equals ``dim``.
        table = torch.cat((angles, angles), dim=-1)

        self.register_buffer("cos_cached", table.cos(), persistent=False)
        self.register_buffer("sin_cached", table.sin(), persistent=False)

    def forward(self, position_ids):
        """Return the ``(cos, sin)`` table rows selected by ``position_ids``.

        ``position_ids`` is used directly as an index into the cached tables,
        so every entry must lie in ``[0, max_seq_len)``.
        """
        return self.cos_cached[position_ids], self.sin_cached[position_ids]
|
|
|
|
def apply_rotary_pos_emb(x, cos, sin):
    """Rotate ``x`` by its position-dependent RoPE angles.

    Args:
        x: Tensor whose last two dimensions are ``[heads, head_dim]``
            (``[B, L, H, HD]`` in this file). ``head_dim`` must be even.
        cos, sin: Tables whose last dimension equals ``head_dim`` and whose
            two halves hold the same angles (built by concatenating the
            per-pair frequencies twice). A head axis is inserted before the
            last dimension so they broadcast over ``H``.

    Returns:
        Tensor with the same shape and dtype as ``x``.

    Channel ``i`` is paired with channel ``i + head_dim // 2`` and both use
    the same angle, so each pair undergoes a true 2-D rotation: vector norms
    are preserved and ``q . k`` depends only on the relative position.

    NOTE: the previous implementation rotated only the first half of the
    channels and mixed two different frequencies inside each pair, which is
    not a rotation. Weights trained with the old behaviour are therefore not
    numerically compatible with this function.
    """
    out_dtype = x.dtype
    # Do the trigonometric arithmetic in the table's precision (float32 here).
    x = x.to(cos.dtype)
    cos = cos.unsqueeze(-2)
    sin = sin.unsqueeze(-2)

    half = x.shape[-1] // 2
    # The [-x2, x1] companion needed for the rotation, computed inline so the
    # function has no dependency on a sibling helper.
    rotated = torch.cat((-x[..., half:], x[..., :half]), dim=-1)

    # Cast back so q/k keep the same dtype as v in the attention call.
    return (x * cos + rotated * sin).to(out_dtype)
|
|
|
|
def rotate_half(x):
    """Return ``[-upper, lower]`` for ``x`` split along its last dimension.

    The split point is ``x.shape[-1] // 2``: the upper part is negated and
    moved to the front, the lower part is moved to the back.
    """
    split_at = x.shape[-1] // 2
    lower, upper = x[..., :split_at], x[..., split_at:]
    return torch.cat([upper.neg(), lower], dim=-1)
|
|
|
|
| |
| |
| |
|
|
def trunc_normal_init_(tensor, std=1.0):
    """Fill ``tensor`` in place with bounded, approximately normal values.

    Standard-normal samples are folded into ``(-3, 3)`` with ``fmod`` (a
    wrap-around, not a clamp) and then multiplied by
    ``1.014762601732121 * std``. Runs without autograd tracking and returns
    the same tensor object.
    """
    # NOTE(review): the 1.0147... factor presumably restores the variance
    # lost by folding the tails -- confirm against the reference init.
    gain = 1.014762601732121 * std
    with torch.no_grad():
        tensor.normal_()
        tensor.fmod_(3.0)
        tensor.mul_(gain)
    return tensor
|
|
|
|
class LinearInit(nn.Linear):
    """``nn.Linear`` whose weight is re-initialised by ``trunc_normal_init_``.

    Args:
        in_features, out_features, bias: Forwarded to ``nn.Linear``.
        init_std: Standard deviation passed to the initialiser. When ``None``
            it defaults to ``1 / sqrt(in_features)``.

    The bias, if present, is set to zero.
    """

    def __init__(self, in_features, out_features, bias=True, init_std=None):
        super().__init__(in_features, out_features, bias=bias)
        std = init_std if init_std is not None else 1.0 / math.sqrt(in_features)
        trunc_normal_init_(self.weight, std=std)
        if self.bias is None:
            return
        nn.init.zeros_(self.bias)
|
|
|
|
| |
| |
| |
|
|
class ScaledEmbeddingInit(nn.Embedding):
    """Embedding table with a small init and a compensating output scale.

    Weights are drawn by ``trunc_normal_init_`` with ``init_std`` (default
    ``1 / sqrt(d_model)``), the padding row is zeroed, and every lookup is
    multiplied by ``1 / init_std`` (or by 1 when ``init_std`` is not
    positive).
    """

    def __init__(self, vocab_size, d_model, padding_idx=0, init_std=None):
        super().__init__(vocab_size, d_model, padding_idx=padding_idx)
        std = 1.0 / math.sqrt(d_model) if init_std is None else init_std
        trunc_normal_init_(self.weight, std=std)
        if padding_idx is not None:
            # The re-initialisation above overwrote the zero padding row that
            # nn.Embedding creates, so restore it.
            with torch.no_grad():
                self.weight[padding_idx].zero_()
        self.scale = 1.0 / std if std > 0 else 1.0

    def forward(self, input_ids):
        """Look up ``input_ids`` and multiply the vectors by ``self.scale``."""
        embedded = super().forward(input_ids)
        return embedded * self.scale
|
|
|
|
| |
| |
| |
|
|
class GatedAttention(nn.Module):
    """Causal multi-head self-attention with RoPE and a sigmoid output gate.

    A single fused projection yields four groups of ``num_heads`` vectors
    (gate, query, key, value). Queries and keys receive rotary position
    embeddings, attention runs with a causal mask, and the attention output
    is multiplied element-wise by ``sigmoid(gate)`` before the output
    projection.
    """

    def __init__(self, hidden_size, num_heads, head_dim, init_std):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = head_dim

        inner_dim = num_heads * head_dim
        # Fused projection: gate, q, k and v for every head in one matmul.
        self.gqkv_proj = LinearInit(
            hidden_size, 4 * inner_dim, bias=False, init_std=init_std,
        )
        self.o_proj = LinearInit(
            inner_dim, hidden_size, bias=False, init_std=init_std,
        )

    def forward(self, hidden_states, cos, sin):
        """Apply gated causal attention to ``hidden_states`` ``[B, L, D]``.

        ``cos`` / ``sin`` are handed unchanged to ``apply_rotary_pos_emb``.
        No padding mask is used; only the causal mask is applied.
        """
        batch, seq_len, _ = hidden_states.shape

        fused = self.gqkv_proj(hidden_states)
        fused = fused.view(batch, seq_len, 4 * self.num_heads, self.head_dim)
        # Four consecutive groups of ``num_heads`` heads along the head axis.
        gate, query, key, value = torch.split(fused, self.num_heads, dim=2)

        query = apply_rotary_pos_emb(query, cos, sin)
        key = apply_rotary_pos_emb(key, cos, sin)

        # [B, L, H, HD] -> [B, H, L, HD], the layout SDPA expects.
        gate, query, key, value = (
            t.transpose(1, 2) for t in (gate, query, key, value)
        )
        context = F.scaled_dot_product_attention(query, key, value, is_causal=True)
        gated = torch.sigmoid(gate) * context

        # Back to [B, L, H * HD] for the output projection.
        merged = gated.transpose(1, 2).reshape(batch, seq_len, -1)
        return self.o_proj(merged)
|
|
|
|
| |
| |
| |
|
|
class SwiGLU(nn.Module):
    """Gated feed-forward block: ``down_proj(silu(gate) * up)``.

    ``gate_up_proj`` maps ``hidden_size`` to ``2 * intermediate_size``; its
    output is split in half along the last dimension into the gate and the
    up branch. Neither projection has a bias.
    """

    def __init__(self, hidden_size, intermediate_size, init_std):
        super().__init__()
        self.gate_up_proj = LinearInit(
            hidden_size, 2 * intermediate_size, bias=False, init_std=init_std,
        )
        self.down_proj = LinearInit(
            intermediate_size, hidden_size, bias=False, init_std=init_std,
        )

    def forward(self, x):
        """Return the SwiGLU transform of ``x`` (last dim ``hidden_size``)."""
        gate, up = torch.chunk(self.gate_up_proj(x), 2, dim=-1)
        hidden = F.silu(gate) * up
        return self.down_proj(hidden)
|
|
|
|
| |
| |
| |
|
|
class TransformerBlock(nn.Module):
    """Pre-norm residual block: gated attention followed by a SwiGLU FFN.

    Both LayerNorms are created without learnable scale or shift
    (``elementwise_affine=False``).
    """

    def __init__(self, hidden_size, num_heads, head_dim, intermediate_size, init_std):
        super().__init__()
        self.attn_norm = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.attn = GatedAttention(hidden_size, num_heads, head_dim, init_std)
        self.ffn_norm = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.ffn = SwiGLU(hidden_size, intermediate_size, init_std)

    def forward(self, x, cos, sin):
        """Return ``x`` after the attention and feed-forward residual updates."""
        attn_out = self.attn(self.attn_norm(x), cos, sin)
        x = x + attn_out
        ffn_out = self.ffn(self.ffn_norm(x))
        return x + ffn_out
|
|
|
|
| |
| |
| |
|
|
class RecurrentModule(nn.Module):
    """Stack of ``TransformerBlock``s used as one recurrent module (H or L).

    The module's own state and the other module's state are fused by simple
    addition, passed through every block in order, and normalised by a final
    non-affine LayerNorm.

    Args:
        layers: Number of transformer blocks in the stack.
        hidden_size, num_heads, head_dim, intermediate_size, init_std:
            Forwarded to each ``TransformerBlock``.
        use_checkpoint: When true, blocks are run under activation
            checkpointing while the module is in training mode.
    """

    def __init__(self, layers, hidden_size, num_heads, head_dim, intermediate_size, init_std, use_checkpoint=False):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.blocks = nn.ModuleList([
            TransformerBlock(hidden_size, num_heads, head_dim, intermediate_size, init_std)
            for _ in range(layers)
        ])
        self.final_norm = nn.LayerNorm(hidden_size, elementwise_affine=False)

    def forward(self, z_self, z_other, cos, sin):
        """Run the fused state through all blocks.

        Args:
            z_self: This module's hidden state ``[B, L, D]``.
            z_other: The other module's hidden state ``[B, L, D]``.
            cos, sin: Rotary tables handed to every block.

        Returns:
            The updated hidden state ``[B, L, D]``.
        """
        x = z_self + z_other
        checkpointing = self.use_checkpoint and self.training
        for block in self.blocks:
            if checkpointing:
                # Non-reentrant checkpointing is required here: callers pass
                # detached states during backprop warmup, and the reentrant
                # variant returns a non-differentiable output whenever no
                # input requires grad. That silently dropped the gradients of
                # the block parameters and made training differ from the
                # non-checkpointed path.
                x = torch_checkpoint(block, x, cos, sin, use_reentrant=False)
            else:
                x = block(x, cos, sin)
        return self.final_norm(x)
|
|
|
|
| |
| |
| |
|
|
class HrmTextClassifier(nn.Module):
    """HRM-Text style recurrent classifier for prompt-injection detection.

    Components:
        embed:      ScaledEmbeddingInit over token IDs (byte vocab, 256 by
                    default; ID 0 is the padding index).
        L_module:   low-level RecurrentModule.
        H_module:   high-level RecurrentModule.
        classifier: LayerNorm -> Linear head producing ``num_classes`` logits.

    Recurrence, repeated ``H_cycles`` times:
        ``L_cycles`` times:  z_L = L_module(z_L, z_H, cos, sin)
        then once:           z_H = H_module(z_H, z_L, cos, sin)

    Logits are computed from ``z_H`` at the last non-padded position of each
    sequence.

    Backprop warmup: only the trailing module calls receive non-detached
    inputs. The budget grows from ``bp_min_steps`` to ``bp_max_steps`` over
    the first ``bp_warmup_ratio`` fraction of training.
    """

    def __init__(
        self,
        vocab_size=256,
        hidden_size=768,
        num_heads=12,
        head_dim=64,
        n_layers_H=3,
        n_layers_L=3,
        intermediate_size=2048,
        H_cycles=2,
        L_cycles=3,
        max_seq_len=4096,
        rope_base=10000.0,
        rope_scaling_factor=32.0,
        num_classes=2,
        bp_min_steps=2,
        bp_max_steps=5,
        bp_warmup_ratio=0.2,
        use_gradient_checkpointing=False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.H_cycles = H_cycles
        self.L_cycles = L_cycles
        self.bp_min_steps = bp_min_steps
        self.bp_max_steps = bp_max_steps
        self.bp_warmup_ratio = bp_warmup_ratio
        # Upper bound applied to the backprop budget.
        self.total_steps = H_cycles * L_cycles

        init_std = 1.0 / math.sqrt(hidden_size)

        self.embed = ScaledEmbeddingInit(
            vocab_size, hidden_size, padding_idx=0,
            init_std=init_std,
        )

        # Learned initial low-level state, broadcast over batch and sequence.
        self.zL_init = nn.Parameter(torch.zeros(1, 1, hidden_size))

        self.rotary = RotaryEmbedding(
            dim=head_dim,
            max_seq_len=max_seq_len,
            base=rope_base,
            scaling_factor=rope_scaling_factor,
        )

        self.L_module = RecurrentModule(
            layers=n_layers_L,
            hidden_size=hidden_size,
            num_heads=num_heads,
            head_dim=head_dim,
            intermediate_size=intermediate_size,
            init_std=init_std,
            use_checkpoint=use_gradient_checkpointing,
        )
        self.H_module = RecurrentModule(
            layers=n_layers_H,
            hidden_size=hidden_size,
            num_heads=num_heads,
            head_dim=head_dim,
            intermediate_size=intermediate_size,
            init_std=init_std,
            use_checkpoint=use_gradient_checkpointing,
        )

        self.classifier = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, num_classes),
        )

        self._init_weights()

    def _init_weights(self):
        """Zero the initial L state and initialise the classifier head."""
        nn.init.zeros_(self.zL_init)
        for layer in self.classifier:
            if isinstance(layer, nn.Linear):
                trunc_normal_init_(layer.weight, std=0.02)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

    def _get_bp_steps(self, training_step_ratio=1.0):
        """Return the backprop budget for the given training progress.

        Args:
            training_step_ratio: Fraction of training completed, in
                ``[0, 1]``. ``None`` means "no schedule" and yields the full
                budget.

        Returns:
            An int in ``[0, total_steps]``: ``bp_min_steps`` at progress 0,
            rising linearly to ``bp_max_steps`` at ``bp_warmup_ratio``, and
            capped at ``total_steps``.
        """
        full_budget = min(self.bp_max_steps, self.total_steps)
        if training_step_ratio is None or training_step_ratio >= 1.0:
            return full_budget
        if self.bp_warmup_ratio <= 0:
            # No warmup window configured; avoid dividing by zero.
            return full_budget
        progress = min(1.0, max(0.0, training_step_ratio) / self.bp_warmup_ratio)
        budget = self.bp_min_steps + progress * (self.bp_max_steps - self.bp_min_steps)
        return min(int(budget), self.total_steps)

    def _prepare_states(self, input_ids, attention_mask):
        """Build the mask, the initial H/L states and the rotary tables."""
        batch, seq_len = input_ids.shape

        if attention_mask is None:
            # Token 0 is the padding ID, so derive the mask from the inputs.
            attention_mask = (input_ids != 0).long()

        position_ids = torch.arange(seq_len, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand(batch, -1)
        # Padded positions are all mapped to position 0.
        position_ids = position_ids * attention_mask

        z_H = self.embed(input_ids)
        z_L = self.zL_init.expand(batch, seq_len, -1)
        cos, sin = self.rotary(position_ids)
        return attention_mask, z_H, z_L, cos, sin

    def _classify(self, z_H, attention_mask):
        """Pool ``z_H`` at each sequence's last valid token and classify it."""
        seq_lengths = attention_mask.sum(dim=1).long()
        # An all-padding row has length 0; clamp so it pools position 0.
        last_token_indices = (seq_lengths - 1).clamp(min=0)
        batch_indices = torch.arange(z_H.shape[0], device=z_H.device)
        pooled = z_H[batch_indices, last_token_indices, :]
        return self.classifier(pooled)

    def forward(self, input_ids, attention_mask=None, labels=None, training_step_ratio=None):
        """Run the recurrent cascade with truncated backpropagation.

        Args:
            input_ids: ``[B, L]`` byte token IDs.
            attention_mask: ``[B, L]``, 1 = valid, 0 = padding. Derived from
                ``input_ids != 0`` when omitted.
            labels: ``[B]`` class indices (0 = safe, 1 = injection), optional.
            training_step_ratio: Float in ``[0, 1]`` driving the backprop
                warmup; ``None`` uses the full budget.

        Returns:
            Dict with ``"logits"`` ``[B, num_classes]`` and ``"loss"``
            (cross-entropy, or ``None`` when ``labels`` is not given).
        """
        attention_mask, z_H, z_L, cos, sin = self._prepare_states(input_ids, attention_mask)

        # Pass the ratio through untouched. An explicit 0.0 is the first
        # training step and must start at ``bp_min_steps``; the previous
        # ``ratio or 1.0`` idiom turned it into the full budget.
        bp_steps = self._get_bp_steps(training_step_ratio)
        H_bp = min(self.H_cycles, max(1, bp_steps - 1))
        L_bp = max(0, bp_steps - H_bp)

        # Only the trailing L_bp L-calls and H_bp H-calls get live inputs.
        first_L_with_grad = self.H_cycles * self.L_cycles - L_bp
        first_H_with_grad = self.H_cycles - H_bp

        L_step = 0
        for h_cycle in range(self.H_cycles):
            for _ in range(self.L_cycles):
                if L_step >= first_L_with_grad:
                    z_L = self.L_module(z_L, z_H, cos, sin)
                else:
                    z_L = self.L_module(z_L.detach(), z_H.detach(), cos, sin)
                L_step += 1

            if h_cycle >= first_H_with_grad:
                z_H = self.H_module(z_H, z_L, cos, sin)
            else:
                z_H = self.H_module(z_H.detach(), z_L.detach(), cos, sin)

        logits = self._classify(z_H, attention_mask)

        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits.float(), labels)

        return {"logits": logits, "loss": loss}

    @torch.no_grad()
    def inference(self, input_ids, attention_mask=None):
        """Return logits ``[B, num_classes]`` with gradients disabled.

        Side effect: switches the module to eval mode via ``self.eval()``.
        """
        self.eval()
        attention_mask, z_H, z_L, cos, sin = self._prepare_states(input_ids, attention_mask)

        for _ in range(self.H_cycles):
            for _ in range(self.L_cycles):
                z_L = self.L_module(z_L, z_H, cos, sin)
            z_H = self.H_module(z_H, z_L, cos, sin)

        return self._classify(z_H, attention_mask)
|
|
|
|
| |
| |
| |
|
|
def _bordair_item_to_example(item):
    """Turn one raw Bordair record into ``(text, label)``, or ``None`` to skip.

    A record is skipped when it has no ``expected_detection`` value or when
    all of its text-bearing fields are empty or whitespace.
    """
    label_val = item.get("expected_detection")
    if label_val is None:
        return None

    # Every modality is represented by its textual content; join in a fixed order.
    fields = ("text", "image_content", "document_content", "audio_content")
    combined = "\n".join(item[key] for key in fields if item.get(key))
    if not combined.strip():
        return None

    return combined, 1 if label_val else 0


def load_bordair_multimodal(cache_dir=None, max_samples=None):
    """Load the Bordair multimodal dataset from the Hugging Face Hub.

    The repository stores raw JSON arrays rather than an HF ``Dataset``, so
    the repo is snapshot-downloaded and the JSON files are read manually.

    Args:
        cache_dir: Optional cache directory passed to ``snapshot_download``.
        max_samples: Optional cap on the number of returned examples. When
            the dataset is larger, an evenly spaced, deterministic subsample
            over the loading order is returned. Files are read class by
            class, so a plain head slice would be heavily skewed towards the
            directories read first. ``None`` returns everything.

    Returns:
        ``Dataset`` with columns ``"text"`` (concatenated modalities) and
        ``"label"`` (0/1).

    Raises:
        ValueError: If ``max_samples`` is negative.
    """
    if max_samples is not None and max_samples < 0:
        raise ValueError(f"max_samples must be non-negative, got {max_samples}")

    print("📦 Downloading Bordair/bordair-multimodal from HF Hub...")
    path = snapshot_download(
        repo_id="Bordair/bordair-multimodal",
        repo_type="dataset",
        cache_dir=cache_dir,
    )
    print(f" Downloaded to: {path}")

    dir_patterns = [
        "benign/*.json",
        "payloads/*/*.json",
        "payloads_v5/*.json",
        "payloads_v5_external/*/*.json",
    ]
    # Bookkeeping files that live next to the sample files.
    skip_names = {"summary.json", "_pool.json", "summary_old.json"}

    all_samples = []
    for pattern in dir_patterns:
        for json_path in sorted(glob.glob(os.path.join(path, pattern))):
            if os.path.basename(json_path) in skip_names:
                continue

            try:
                with open(json_path, "r", encoding="utf-8") as fh:
                    data = json.load(fh)
            except (json.JSONDecodeError, UnicodeDecodeError) as e:
                print(f" ⚠️ Skipping {json_path}: {e}")
                continue

            # Only top-level arrays of objects are sample files.
            if not isinstance(data, list):
                continue
            all_samples.extend(item for item in data if isinstance(item, dict))

        print(f" {pattern}: {len(all_samples)} cumulative")

    print(f"\n✅ Total raw samples loaded: {len(all_samples)}")

    rows = []
    labels = []
    skipped = 0
    for item in all_samples:
        example = _bordair_item_to_example(item)
        if example is None:
            skipped += 1
            continue
        text, label = example
        rows.append(text)
        labels.append(label)

    if skipped:
        print(f" Skipped {skipped} samples (missing label or empty text)")

    if max_samples is not None and max_samples < len(rows):
        total = len(rows)
        picks = [(i * total) // max_samples for i in range(max_samples)]
        rows = [rows[i] for i in picks]
        labels = [labels[i] for i in picks]
        print(f" Subsampled {total} -> {len(rows)} samples (max_samples={max_samples})")

    ds = Dataset.from_dict({"text": rows, "label": labels})
    print(f"✅ Dataset: {len(ds)} samples ({sum(labels)} injection, {len(labels) - sum(labels)} safe)")
    return ds
|
|
|
|
def normalize_label(ex, label_col):
    """Return ``{label_col: 0 or 1}`` for the label stored in ``ex``.

    String labels map to 1 when, lower-cased, they equal one of
    ``"malicious"``, ``"injection"``, ``"yes"`` or ``"1"``, and to 0
    otherwise. Any other value is converted with ``int()``.
    """
    raw = ex[label_col]
    if not isinstance(raw, str):
        return {label_col: int(raw)}
    is_positive = raw.lower() in ("malicious", "injection", "yes", "1")
    return {label_col: int(is_positive)}
|
|
|
|
| |
| |
| |
|
|
class ByteTokenizer:
    """Tokenizer that maps text to its UTF-8 byte values (0-255).

    No padding happens here: every call returns plain variable-length lists
    and padding is left to the collate function. ID 0 doubles as both the
    pad and the EOS token.
    """

    def __init__(self, max_length=131072):
        self.vocab_size = 256
        self.max_length = max_length
        self.pad_token = self.eos_token = "<pad>"
        self.pad_token_id = self.eos_token_id = 0

    @staticmethod
    def _utf8_ids(text):
        """Return the UTF-8 byte values of ``text``; non-strings give ``[]``."""
        if not isinstance(text, str):
            return []
        return list(text.encode("utf-8", errors="replace"))

    def __call__(self, text, truncation=True, max_length=None):
        """Encode one string, cutting to the length limit when ``truncation``.

        The limit is ``max_length`` if truthy, otherwise ``self.max_length``.
        """
        limit = max_length or self.max_length
        ids = self._utf8_ids(text)
        return ids[:limit] if truncation else ids

    def encode_batch(self, texts, max_length=None):
        """Encode each item of ``texts``; truncate when the limit is truthy."""
        limit = max_length or self.max_length
        encoded = [self._utf8_ids(item) for item in texts]
        if limit:
            return [ids[:limit] for ids in encoded]
        return encoded

    def __len__(self):
        """Vocabulary size (256)."""
        return self.vocab_size
|
|
|
|
def collate_hrm_text(batch, max_length=131072):
    """Collate text examples into right-padded byte-ID tensors.

    Args:
        batch: Sequence of dicts with a ``"text"`` string and an integer
            ``"label"``.
        max_length: Maximum number of UTF-8 bytes kept per example.

    Returns:
        Dict of ``torch.long`` tensors:
            ``input_ids``      ``[B, T]`` byte values, right-padded with 0.
            ``attention_mask`` ``[B, T]`` 1 for real bytes, 0 for padding.
            ``labels``         ``[B]``.
        ``T`` is the longest (truncated) sequence in the batch, not
        ``max_length``, and is never smaller than 1, so a batch of empty
        texts still produces a tensor that can be indexed at position 0.
    """
    labels = torch.tensor([ex["label"] for ex in batch], dtype=torch.long)

    # Truncate on the raw bytes so no per-byte Python list is ever built.
    encoded = [ex["text"].encode("utf-8", errors="replace")[:max_length] for ex in batch]

    padded_len = max(max((len(raw) for raw in encoded), default=0), 1)

    input_ids = torch.zeros((len(encoded), padded_len), dtype=torch.long)
    attention_mask = torch.zeros((len(encoded), padded_len), dtype=torch.long)
    for row, raw in enumerate(encoded):
        length = len(raw)
        if not length:
            continue
        # frombuffer needs a writable buffer; the copy_ into the int64 row
        # widens uint8 -> long.
        input_ids[row, :length] = torch.frombuffer(bytearray(raw), dtype=torch.uint8)
        attention_mask[row, :length] = 1

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }
|
|
|
|
| |
| |
| |
|
|
class HrmTextTrainer(Trainer):
    """``Trainer`` adapted to HRM-Text's forward signature and BP warmup.

    Args:
        total_training_steps: Number of optimizer steps the run is expected
            to take. Used to turn the trainer's global step into the
            ``training_step_ratio`` consumed by the model. When falsy the
            ratio is pinned to 1.0 (no warmup).
    """

    def __init__(self, *args, total_training_steps=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.total_training_steps = total_training_steps
        # Number of training forward passes made through compute_loss. Kept
        # for inspection only; scheduling uses ``state.global_step``.
        self._current_step = 0

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Run the model with labels and the current warmup progress.

        Progress comes from ``self.state.global_step`` (optimizer steps), so
        it stays correct under gradient accumulation and after resuming from
        a checkpoint. A per-call counter over-counts in the first case and
        restarts from zero in the second.
        """
        labels = inputs.pop("labels")

        if self.total_training_steps and self.total_training_steps > 0:
            step_ratio = min(1.0, self.state.global_step / self.total_training_steps)
        else:
            step_ratio = 1.0

        outputs = model(**inputs, labels=labels, training_step_ratio=step_ratio)
        loss = outputs["loss"]
        self._current_step += 1
        return (loss, outputs) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only=False, ignore_keys=None):
        """Evaluate one batch through the model's ``inference`` path.

        Returns:
            ``(loss, logits, labels)``. ``loss`` is the cross-entropy when
            labels are present (so ``eval_loss`` is reported), else ``None``.
            ``logits`` is ``None`` when ``prediction_loss_only`` is set.
        """
        # Same device/dtype placement the stock prediction_step performs.
        inputs = self._prepare_inputs(inputs)
        labels = inputs.pop("labels", None)

        # Wrappers such as DataParallel / DDP expose the real model as
        # ``.module`` and do not forward custom methods like ``inference``.
        target = model if hasattr(model, "inference") else getattr(model, "module", model)

        with torch.no_grad():
            logits = target.inference(**inputs)
            loss = None
            if labels is not None:
                loss = F.cross_entropy(logits.float(), labels).detach()

        if prediction_loss_only:
            return (loss, None, labels)
        return (loss, logits, labels)
|
|
|
|
| |
| |
| |
|
|
def count_params(model):
    """Return the number of scalar elements in ``model``'s trainable parameters.

    Parameters with ``requires_grad=False`` are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|
|
|
|
| |
| |
| |
|
|
| def main(): |
| parser = argparse.ArgumentParser(description="Train HRM-Text prompt injection detector") |
| parser.add_argument("--test", action="store_true", help="Smoke test on 64 samples") |
| parser.add_argument("--lr", type=float, default=2.2e-4) |
| parser.add_argument("--epochs", type=int, default=3) |
| parser.add_argument("--batch_size", type=int, default=32) |
| parser.add_argument("--output_dir", type=str, default="./pi-hrm-text") |
| parser.add_argument("--cpu", action="store_true") |
| parser.add_argument("--max_length", type=int, default=2048) |
| parser.add_argument("--hidden_size", type=int, default=768) |
| parser.add_argument("--num_heads", type=int, default=12) |
| parser.add_argument("--head_dim", type=int, default=64) |
| parser.add_argument("--n_layers_H", type=int, default=3) |
| parser.add_argument("--n_layers_L", type=int, default=3) |
| parser.add_argument("--intermediate_size", type=int, default=2048) |
| parser.add_argument("--H_cycles", type=int, default=2) |
| parser.add_argument("--L_cycles", type=int, default=3) |
| parser.add_argument("--rope_base", type=float, default=10000.0) |
| parser.add_argument("--rope_scaling", type=float, default=1.0) |
| parser.add_argument("--bp_min", type=int, default=2) |
| parser.add_argument("--bp_max", type=int, default=5) |
| parser.add_argument("--push_to_hub", type=str, default="av-codes/prompt-injection-hrm-text") |
| parser.add_argument("--hub_token", type=str, default=None) |
| parser.add_argument("--gradient_checkpointing", action="store_true", default=True) |
| parser.add_argument("--no_gradient_checkpointing", action="store_false", dest="gradient_checkpointing") |
| parser.add_argument("--seed", type=int, default=42) |
| parser.add_argument("--data_cache", type=str, default=None, |
| help="Cache dir for dataset download") |
| parser.add_argument("--max_steps", type=int, default=-1, |
| help="Max training steps (-1 = use epochs)") |
| parser.add_argument("--resume_from_checkpoint", type=str, default=None, |
| help="Path to checkpoint dir to resume from") |
| args = parser.parse_args() |
|
|
| set_seed(args.seed) |
| use_cuda = torch.cuda.is_available() and not args.cpu |
| device = torch.device("cuda" if use_cuda else "cpu") |
| print(f"🖥️ Hardware: {'GPU' if use_cuda else 'CPU'}") |
| if use_cuda: |
| print(f" Device: {torch.cuda.get_device_name(0)}") |
| print(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB") |
| print(f"📐 HRM-Text(H={args.n_layers_H}, L={args.n_layers_L}, " |
| f"d={args.hidden_size}, H_cycles={args.H_cycles}, L_cycles={args.L_cycles})") |
| print(f"📏 Max context: {args.max_length:,} tokens") |
| print(f"🔄 BP warmup: {args.bp_min}→{args.bp_max} steps") |
|
|
| |
| print("\n📦 Loading Bordair multimodal dataset...") |
| |
| if args.test: |
| |
| all_parts = [] |
| |
| test_path = args.data_cache or "/tmp/bordair_test" |
| if not os.path.exists(test_path): |
| snapshot_download( |
| repo_id="Bordair/bordair-multimodal", |
| repo_type="dataset", |
| cache_dir=args.data_cache, |
| local_dir=test_path, |
| local_dir_use_symlinks=False, |
| allow_patterns=["benign/text_only.json", "benign/multimodal_text_image.json", |
| "payloads/text_image/text_image_001.json"], |
| ) |
| for f in glob.glob(f"{test_path}/**/*.json", recursive=True): |
| fname = os.path.basename(f) |
| if fname in ("summary.json", "_pool.json"): |
| continue |
| with open(f) as fh: |
| data = json.load(fh) |
| if isinstance(data, list): |
| for item in data: |
| if isinstance(item, dict) and item.get("expected_detection") is not None: |
| text_parts = [item.get("text", "")] |
| for k in ("image_content", "document_content", "audio_content"): |
| if item.get(k): |
| text_parts.append(item[k]) |
| all_parts.append({ |
| "text": "\n".join(text_parts), |
| "label": 1 if item["expected_detection"] else 0, |
| }) |
| merged = Dataset.from_list(all_parts) |
| print(f" Test mode: {len(merged)} samples loaded") |
| else: |
| merged = load_bordair_multimodal(cache_dir=args.data_cache) |
|
|
| |
| |
| merged = merged.cast_column("label", hf_datasets.ClassLabel(names=["safe", "injection"])) |
|
|
| if args.test: |
| train_dataset = merged.select(range(min(64, len(merged)))) |
| eval_dataset = merged.select(range(min(32, len(merged)))) |
| else: |
| split = merged.train_test_split( |
| test_size=0.05, seed=args.seed, stratify_by_column="label", |
| ) |
| train_dataset = split["train"] |
| eval_dataset = split["test"] |
|
|
| print(f"\n✅ Dataset: {len(merged)} total → {len(train_dataset)} train, {len(eval_dataset)} eval") |
| train_dist = Counter(train_dataset["label"]) |
| eval_dist = Counter(eval_dataset["label"]) |
| print(f" Train label dist: {dict(train_dist)}") |
| print(f" Eval label dist: {dict(eval_dist)}") |
|
|
| |
| train_lengths = [len(t.encode("utf-8", errors="replace")) for t in train_dataset["text"]] |
| print(f" Train text length stats: mean={np.mean(train_lengths):.0f}, " |
| f"median={np.median(train_lengths):.0f}, " |
| f"p95={np.percentile(train_lengths, 95):.0f}, " |
| f"max={max(train_lengths):,}") |
|
|
| |
| model = HrmTextClassifier( |
| vocab_size=256, |
| hidden_size=args.hidden_size, |
| num_heads=args.num_heads, |
| head_dim=args.head_dim, |
| n_layers_H=args.n_layers_H, |
| n_layers_L=args.n_layers_L, |
| intermediate_size=args.intermediate_size, |
| H_cycles=args.H_cycles, |
| L_cycles=args.L_cycles, |
| max_seq_len=args.max_length, |
| rope_base=args.rope_base, |
| rope_scaling_factor=args.rope_scaling, |
| num_classes=2, |
| bp_min_steps=args.bp_min, |
| bp_max_steps=args.bp_max, |
| use_gradient_checkpointing=args.gradient_checkpointing, |
| ) |
| param_count = count_params(model) |
| print(f"\n🧮 Model parameters: {param_count:,}") |
| if args.gradient_checkpointing: |
| print(" Gradient checkpointing: enabled") |
| if not args.test: |
| assert 15_000_000 <= param_count <= 55_000_000, \ |
| f"Param count {param_count:,} outside target range [15M, 55M]" |
|
|
| if use_cuda: |
| model = model.cuda() |
|
|
| |
| accuracy = evaluate.load("accuracy") |
| precision = evaluate.load("precision") |
| recall = evaluate.load("recall") |
| f1 = evaluate.load("f1") |
|
|
| def compute_metrics(eval_pred): |
| predictions, labels = eval_pred |
| preds = predictions.argmax(-1) |
| return { |
| "accuracy": accuracy.compute(predictions=preds, references=labels)["accuracy"], |
| "precision": precision.compute(predictions=preds, references=labels, average="binary")["precision"], |
| "recall": recall.compute(predictions=preds, references=labels, average="binary")["recall"], |
| "f1": f1.compute(predictions=preds, references=labels, average="binary")["f1"], |
| } |
|
|
| |
| steps_per_epoch = max(1, len(train_dataset) // args.batch_size) |
| total_training_steps = steps_per_epoch * args.epochs |
| print(f"📊 Steps per epoch: {steps_per_epoch}, total: {total_training_steps}") |
|
|
| |
| run_name = f"hrm-text-pi_d{args.hidden_size}_lr{args.lr}_ep{args.epochs}_bs{args.batch_size}" |
|
|
| training_args = TrainingArguments( |
| output_dir=args.output_dir, |
| run_name=run_name, |
| report_to="none", |
| learning_rate=args.lr, |
| per_device_train_batch_size=args.batch_size, |
| per_device_eval_batch_size=min(args.batch_size * 2, 16), |
| num_train_epochs=args.epochs, |
| max_steps=args.max_steps, |
| weight_decay=0.1, |
| warmup_steps=2000 if not args.test else 0, |
| lr_scheduler_type="constant_with_warmup", |
| eval_strategy="steps", |
| eval_steps=1000, |
| save_strategy="steps", |
| save_steps=1000, |
| load_best_model_at_end=True, |
| metric_for_best_model="f1", |
| greater_is_better=True, |
| save_total_limit=3, |
| logging_strategy="steps", |
| logging_first_step=True, |
| logging_steps=5 if args.test else 20, |
| disable_tqdm=False if args.test else True, |
| fp16=False, |
| bf16=use_cuda, |
| push_to_hub=True, |
| hub_model_id=args.push_to_hub, |
| hub_strategy="every_save", |
| use_cpu=not use_cuda, |
| dataloader_num_workers=4, |
| seed=args.seed, |
| adam_beta2=0.95, |
| save_only_model=True, |
| remove_unused_columns=False, |
| ddp_find_unused_parameters=True, |
| gradient_checkpointing=False, |
| ) |
|
|
| |
| def collate_fn(batch): |
| return collate_hrm_text(batch, max_length=args.max_length) |
|
|
| |
| trainer = HrmTextTrainer( |
| model=model, |
| args=training_args, |
| train_dataset=train_dataset, |
| eval_dataset=eval_dataset, |
| data_collator=collate_fn, |
| compute_metrics=compute_metrics, |
| total_training_steps=total_training_steps, |
| ) |
|
|
| |
| print("\n🚀 Training...") |
| train_start = time.time() |
| trainer.train(resume_from_checkpoint=args.resume_from_checkpoint) |
| train_elapsed = time.time() - train_start |
| print(f"✅ Training complete! ({train_elapsed:.1f}s)") |
| print(f" Best checkpoint: {trainer.state.best_model_checkpoint}") |
|
|
| |
| print("\n📊 Evaluating on eval set...") |
| eval_metrics = trainer.evaluate(eval_dataset) |
| print(f" Eval metrics: {json.dumps(eval_metrics, indent=2)}") |
|
|
| os.makedirs(args.output_dir, exist_ok=True) |
| eval_path = os.path.join(args.output_dir, "eval_metrics.json") |
| with open(eval_path, "w") as f: |
| json.dump(eval_metrics, f, indent=2) |
|
|
| |
| best_model_path = os.path.join(args.output_dir, "best_model") |
| os.makedirs(best_model_path, exist_ok=True) |
| model_path = os.path.join(best_model_path, "hrm_text_model.pt") |
|
|
| |
| best_ckpt = trainer.state.best_model_checkpoint |
| if best_ckpt and os.path.isdir(best_ckpt): |
| print(f"\n💾 Loading best checkpoint from {best_ckpt}") |
| |
| |
| best_model = trainer.model |
| if hasattr(best_model, "module"): |
| best_model = best_model.module |
| torch.save(best_model.state_dict(), model_path) |
| else: |
| |
| best_model = trainer.model |
| if hasattr(best_model, "module"): |
| best_model = best_model.module |
| torch.save(best_model.state_dict(), model_path) |
| print(f"\n💾 Saved final model weights to {model_path}") |
|
|
| |
| config = { |
| "architecture": "HRM-Text (classification)", |
| "reference": "sapientinc/HRM-Text, arXiv:2506.21734", |
| "hidden_size": args.hidden_size, |
| "num_heads": args.num_heads, |
| "head_dim": args.head_dim, |
| "n_layers_H": args.n_layers_H, |
| "n_layers_L": args.n_layers_L, |
| "intermediate_size": args.intermediate_size, |
| "H_cycles": args.H_cycles, |
| "L_cycles": args.L_cycles, |
| "max_seq_len": args.max_length, |
| "rope_base": args.rope_base, |
| "rope_scaling": args.rope_scaling, |
| "bp_min_steps": args.bp_min, |
| "bp_max_steps": args.bp_max, |
| "vocab_size": 256, |
| "param_count": param_count, |
| "id2label": {0: "safe", 1: "injection"}, |
| "label2id": {"safe": 0, "injection": 1}, |
| "training": { |
| "learning_rate": args.lr, |
| "epochs": args.epochs, |
| "batch_size": args.batch_size, |
| "weight_decay": 0.1, |
| "scheduler": "constant_with_warmup", |
| "warmup_steps": 2000 if not args.test else 0, |
| "adam_beta2": 0.95, |
| "precision": "bf16", |
| }, |
| } |
| with open(os.path.join(best_model_path, "config.json"), "w") as f: |
| json.dump(config, f, indent=2) |
| print(f" Saved config to {best_model_path}/config.json") |
|
|
| |
| if args.push_to_hub: |
| hub_model_id = args.push_to_hub |
| api = HfApi(token=args.hub_token) |
|
|
| print(f"\n☁️ Pushing to Hub: {hub_model_id}") |
|
|
| |
| try: |
| api.create_repo(repo_id=hub_model_id, repo_type="model", private=False, exist_ok=True) |
| print(f" Repo ready: {hub_model_id}") |
| except Exception as e: |
| print(f" ⚠️ Could not create repo: {e}") |
|
|
| |
| api.upload_file( |
| path_or_fileobj=model_path, |
| path_in_repo="pytorch_model.bin", |
| repo_id=hub_model_id, |
| repo_type="model", |
| commit_message=f"HRM-Text prompt injection detector — F1={eval_metrics.get('eval_f1', 0):.4f}", |
| ) |
|
|
| |
| api.upload_file( |
| path_or_fileobj=os.path.join(best_model_path, "config.json"), |
| path_in_repo="config.json", |
| repo_id=hub_model_id, |
| repo_type="model", |
| commit_message="Add model config", |
| ) |
|
|
| |
| api.upload_file( |
| path_or_fileobj=eval_path, |
| path_in_repo="eval_metrics.json", |
| repo_id=hub_model_id, |
| repo_type="model", |
| commit_message="Add evaluation metrics", |
| ) |
|
|
| |
| script_path = os.path.abspath(__file__) if "__file__" in dir() else None |
| if script_path and os.path.exists(script_path): |
| api.upload_file( |
| path_or_fileobj=script_path, |
| path_in_repo="train_hrm_text_pi.py", |
| repo_id=hub_model_id, |
| repo_type="model", |
| commit_message="Add training script", |
| ) |
|
|
| |
| readme = f"""--- |
| license: mit |
| tags: |
| - prompt-injection |
| - hrm-text |
| - hierarchical-reasoning-model |
| - bordair-multimodal |
| - security |
| --- |
| |
| # HRM-Text Prompt Injection Detector |
| |
| **Parameters:** {param_count:,} |
| **Architecture:** HRM-Text (classification port) | d={args.hidden_size}, H={args.n_layers_H}, L={args.n_layers_L}, cycles={args.H_cycles}×{args.L_cycles} |
| **Context window:** {args.max_length:,} tokens (NTK-scaled RoPE) |
| **Training data:** Bordair/bordair-multimodal (503K samples, balanced 1:1) |
| |
| Evaluation on stratified 10% holdout: |
| |
| | Metric | Value | |
| |--------|-------| |
| | Accuracy | {eval_metrics.get('eval_accuracy', 0):.4f} | |
| | Precision | {eval_metrics.get('eval_precision', 0):.4f} | |
| | Recall | {eval_metrics.get('eval_recall', 0):.4f} | |
| | F1 | {eval_metrics.get('eval_f1', 0):.4f} | |
| |
| ## Architecture |
| |
| HRM-Text (arXiv:2506.21734) with a classification head. The model uses a recurrent cascade of two transformer modules (H and L) that exchange information across cycles: |
| |
| - **L module** ({args.n_layers_L} layers, low-level): processes detailed token patterns |
| - **H module** ({args.n_layers_H} layers, high-level): integrates across cycles |
| - **Recurrence**: {args.L_cycles} L-steps per H-cycle, {args.H_cycles} H-cycles total = {args.H_cycles * args.L_cycles} recurrent passes |
| - **Classification**: last-token pooling + LayerNorm + Linear(2) |
| |
| The byte-level tokenizer (vocab 256) handles any text encoding. RoPE uses NTK-aware scaling (θ={args.rope_base}, factor={args.rope_scaling}) for {args.max_length:,}-token context. |
| |
| ## Usage |
| ```python |
| import torch |
| from train_hrm_text_pi import HrmTextClassifier |
| |
| model = HrmTextClassifier( |
| hidden_size={args.hidden_size}, |
| num_heads={args.num_heads}, |
| head_dim={args.head_dim}, |
| n_layers_H={args.n_layers_H}, |
| n_layers_L={args.n_layers_L}, |
| ) |
| state_dict = torch.load("pytorch_model.bin", map_location="cpu") |
| # Remove DDP wrapper keys if present |
| state_dict = {{k.replace('module.', ''): v for k, v in state_dict.items()}} |
| model.load_state_dict(state_dict) |
| model.eval() |
| |
| def detect(text, max_length=131072): |
| byte_ids = list(text.encode("utf-8", errors="replace")[:max_length]) |
| input_ids = torch.tensor([byte_ids]) |
| attention_mask = torch.ones_like(input_ids) |
| logits = model.inference(input_ids, attention_mask) |
| pred = logits.argmax(-1).item() # 0=safe, 1=injection |
| return pred |
| ``` |
| """ |
| api.upload_file( |
| path_or_fileobj=readme.encode(), |
| path_in_repo="README.md", |
| repo_id=hub_model_id, |
| repo_type="model", |
| commit_message="Add README", |
| ) |
|
|
| print(f"✅ https://huggingface.co/{hub_model_id}") |
|
|
| print("\n✅ Done!") |
|
|
|
|
# Script entry point: only runs when the file is executed directly,
# not when it is imported (e.g. to reuse the model classes).
if __name__ == "__main__":
    import multiprocessing

    # Per the multiprocessing docs, freeze_support() should be called
    # straight after the __main__ guard; it has no effect when the program
    # is run normally by the interpreter (i.e. is not a frozen executable).
    multiprocessing.freeze_support()
    main()