---
license: mit
tags:
- prompt-injection
- hrm-text
- hierarchical-reasoning-model
- bordair-multimodal
- security
---

# HRM-Text Prompt Injection Detector

- **Parameters:** 46,206,722
- **Architecture:** HRM-Text (classification port); hidden size d = 768, 3 H-layers, 3 L-layers, 2 H-cycles × 3 L-steps
- **Context window:** 2,048 tokens, i.e. 2,048 bytes with the byte-level tokenizer (NTK-scaled RoPE)
- **Training data:** Bordair/bordair-multimodal (503K samples, balanced 1:1)

Evaluation on a stratified 10% holdout split:

| Metric | Value |
|--------|-------|
| Accuracy | 0.9893 |
| Precision | 0.9934 |
| Recall | 0.9838 |
| F1 | 0.9886 |
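
As a quick consistency check, F1 is the harmonic mean of precision and recall, F1 = 2PR / (P + R). Plugging in the reported precision and recall reproduces the reported F1:

```python
def f1_score(precision: float, recall: float) -> float:
    """F1 is the harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)

# Reported holdout precision and recall from the metrics table
f1 = f1_score(0.9934, 0.9838)
print(f"F1 = {f1:.4f}")  # F1 = 0.9886
```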

## Architecture

The backbone is HRM-Text (arXiv:2506.21734) with a classification head added. The model runs a recurrent cascade of two transformer modules, H and L, that exchange information across cycles:

- **L module** (3 layers, low-level): processes fine-grained token patterns
- **H module** (3 layers, high-level): integrates the L module's state once per cycle
- **Recurrence**: 3 L-steps per H-cycle × 2 H-cycles = 6 L-module passes and 2 H-module updates
- **Classification**: last-token pooling + LayerNorm + Linear(2)
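
The update schedule can be sketched as a nested loop. This is a minimal sketch that assumes the update order of the reference HRM implementation: each L-step is conditioned on the current H state plus the input embedding, and H updates once per cycle from the final L state. `hrm_forward`, `f_L`, `f_H`, and `toy_mix` are hypothetical names; the two transformer stacks are replaced by toy element-wise functions on plain lists so the call order can be traced.

```python
from typing import Callable, List, Optional

Vec = List[float]
Module = Callable[[Vec, Vec], Vec]  # (own state, conditioning input) -> new state

def hrm_forward(
    x: Vec,
    z_H: Vec,
    z_L: Vec,
    f_L: Module,
    f_H: Module,
    h_cycles: int = 2,
    l_steps: int = 3,
    trace: Optional[List[str]] = None,
) -> Vec:
    """Run the H/L recurrence and return the final H state (assumed input to the head)."""
    for _ in range(h_cycles):
        for _ in range(l_steps):
            # L step: conditioned on the current H state plus the input embedding
            injected = [h + xi for h, xi in zip(z_H, x)]
            z_L = f_L(z_L, injected)
            if trace is not None:
                trace.append("L")
        # H update: once per cycle, from the L state reached at the end of the cycle
        z_H = f_H(z_H, z_L)
        if trace is not None:
            trace.append("H")
    return z_H

def toy_mix(state: Vec, other: Vec) -> Vec:
    """Toy stand-in for a transformer stack: element-wise average."""
    return [0.5 * (a + b) for a, b in zip(state, other)]

calls: List[str] = []
out = hrm_forward([1.0, 2.0], [0.0, 0.0], [0.0, 0.0], toy_mix, toy_mix, trace=calls)
print(calls.count("L"), calls.count("H"))  # 6 2
```

With the card's settings (2 H-cycles, 3 L-steps) the trace shows six L passes interleaved with two H updates.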
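
The byte-level tokenizer (vocabulary of 256) maps any Unicode text to its UTF-8 bytes, so there are no out-of-vocabulary tokens. RoPE uses NTK-aware scaling (θ = 10000.0, factor = 1.0) over the 2,048-token context; with factor = 1.0 the base is left unchanged, so this is equivalent to standard RoPE.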
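
A minimal sketch of byte tokenization and of the commonly used NTK-aware RoPE base adjustment, θ' = θ · factor^(d / (d - 2)), where d is the head dimension (64 in the usage example). Which NTK variant the model actually implements is an assumption here, and the helper names are hypothetical. With factor = 1.0 the adjusted base equals θ.

```python
from typing import List

def byte_ids(text: str) -> List[int]:
    """Byte-level tokenization: one token id per UTF-8 byte, so ids lie in [0, 255]."""
    return list(text.encode("utf-8", errors="replace"))

def ntk_scaled_base(base: float, factor: float, head_dim: int) -> float:
    """Assumed NTK-aware RoPE base adjustment: base * factor ** (d / (d - 2))."""
    return base * factor ** (head_dim / (head_dim - 2))

def rope_inv_freq(base: float, head_dim: int) -> List[float]:
    """Standard RoPE inverse frequencies base ** (-2i / d) for i in [0, d/2)."""
    return [base ** (-2 * i / head_dim) for i in range(head_dim // 2)]

print(byte_ids("é!"))                     # [195, 169, 33]
print(ntk_scaled_base(10000.0, 1.0, 64))  # 10000.0
print(len(rope_inv_freq(10000.0, 64)))    # 32
```

Because non-ASCII characters expand to several bytes, the 2,048-token window holds fewer than 2,048 characters for such text.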
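
## Usage

```python
import torch
from train_hrm_text_pi import HrmTextClassifier  # model class from the training script

model = HrmTextClassifier(
    hidden_size=768,
    num_heads=12,
    head_dim=64,
    n_layers_H=3,
    n_layers_L=3,
)
# weights_only=True restricts unpickling to tensors, which is enough for a plain state dict
state_dict = torch.load("pytorch_model.bin", map_location="cpu", weights_only=True)
# Strip the DDP "module." prefix if present (only at the start of each key)
state_dict = {k.removeprefix("module."): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.eval()

MAX_LENGTH = 2048  # context window; one token per byte with the byte-level tokenizer

@torch.no_grad()
def detect(text: str, max_length: int = MAX_LENGTH) -> int:
    """Return 1 if `text` is classified as a prompt injection, else 0."""
    # One token id per UTF-8 byte; input is truncated to the first max_length bytes
    byte_ids = list(text.encode("utf-8", errors="replace")[:max_length])
    if not byte_ids:
        return 0  # empty input: nothing to classify, treat as safe
    input_ids = torch.tensor([byte_ids], dtype=torch.long)
    attention_mask = torch.ones_like(input_ids)
    logits = model.inference(input_ids, attention_mask)
    return logits.argmax(-1).item()  # 0 = safe, 1 = injection
```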
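
The context window is 2,048 tokens, i.e. 2,048 bytes with the byte-level tokenizer, so a single forward pass cannot cover a longer document. One option, sketched here as a hypothetical helper that is not part of this repository, is to score overlapping 2,048-byte windows and flag the input if any window is classified as an injection. `chunk_bytes`, `detect_long`, and `classify_bytes` are assumed names; `classify_bytes` stands for a callable that runs the model on one window and returns 0 or 1, and `toy_classifier` is a placeholder for illustration only.

```python
from typing import Callable, List

def chunk_bytes(data: bytes, window: int = 2048, stride: int = 1536) -> List[bytes]:
    """Split `data` into windows of at most `window` bytes, starting every `stride` bytes."""
    if not 0 < stride <= window:
        raise ValueError("require 0 < stride <= window")
    chunks = []
    start = 0
    while True:
        chunks.append(data[start:start + window])
        if start + window >= len(data):  # this window reaches the end of the input
            return chunks
        start += stride

def detect_long(
    text: str,
    classify_bytes: Callable[[bytes], int],  # hypothetical: runs the model on one window
    window: int = 2048,
    stride: int = 1536,
) -> int:
    """Return 1 if any window is classified as an injection, else 0."""
    data = text.encode("utf-8", errors="replace")
    return int(any(classify_bytes(c) == 1 for c in chunk_bytes(data, window, stride)))

# Placeholder classifier for illustration: flags windows containing a marker string
def toy_classifier(chunk: bytes) -> int:
    return int(b"IGNORE PREVIOUS" in chunk)

print(detect_long("x" * 4000 + "IGNORE PREVIOUS", toy_classifier))  # 1
print(detect_long("x" * 4000, toy_classifier))                     # 0
```

With the model loaded as in the usage example, `classify_bytes` could run `model.inference` on `torch.tensor([list(chunk)], dtype=torch.long)` with an all-ones attention mask. Windows are cut at byte boundaries, which a byte-level model tolerates even when a multi-byte character is split, and an overlap of `window - stride` bytes (512 here) guarantees that any injected span up to that length appears whole in at least one window.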