---
license: mit
tags:
- prompt-injection
- hrm-text
- hierarchical-reasoning-model
- bordair-multimodal
- security
---
# HRM-Text Prompt Injection Detector

**Parameters:** 46,206,722

**Architecture:** HRM-Text (classification port) | d=768, 12 heads, H layers=3, L layers=3, cycles=2 H × 3 L

**Context window:** 2,048 tokens (NTK-scaled RoPE)

**Training data:** Bordair/bordair-multimodal (503K samples, balanced 1:1)

Evaluation on a stratified 10% holdout:

| Metric | Value |
|--------|-------|
| Accuracy | 0.9893 |
| Precision | 0.9934 |
| Recall | 0.9838 |
| F1 | 0.9886 |
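As a consistency check on the table, F1 is the harmonic mean of precision $P$ and recall $R$. Plugging in the reported values reproduces the listed F1 up to rounding:

$$
F_1 = \frac{2PR}{P + R} = \frac{2 \times 0.9934 \times 0.9838}{0.9934 + 0.9838} = \frac{1.9546}{1.9772} \approx 0.9886
$$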
## Architecture
HRM-Text (arXiv:2506.21734) with a classification head. The model uses a recurrent cascade of two transformer modules (H and L) that exchange information across cycles:
- **L module** (3 layers, low-level): processes detailed token patterns
- **H module** (3 layers, high-level): integrates across cycles
- **Recurrence**: 3 L-steps per H-cycle × 2 H-cycles = 6 L-steps in total
- **Classification**: last-token pooling + LayerNorm + Linear(2)
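The H/L schedule in the list above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not this model's implementation: `hrm_recurrence`, `f_L`, and `f_H` are hypothetical names, the two functions stand in for the transformer modules, states are plain lists rather than tensors, and starting both states at zero is an assumption. The update order follows the HRM paper: the L module runs several steps conditioned on the current H state and the input, and the H module updates once at the end of each cycle.

```python
from typing import Callable, List, Tuple

# Hypothetical state type: a plain list of floats standing in for a hidden vector.
Vec = List[float]


def hrm_recurrence(
    x: Vec,
    f_L: Callable[[Vec, Vec, Vec], Vec],
    f_H: Callable[[Vec, Vec], Vec],
    n_cycles: int = 2,
    l_steps: int = 3,
) -> Tuple[Vec, List[str]]:
    """Run l_steps L-updates per H-cycle for n_cycles H-cycles; return final H state and call trace."""
    z_L = [0.0] * len(x)  # low-level state (zero init is an assumption)
    z_H = [0.0] * len(x)  # high-level state (zero init is an assumption)
    trace: List[str] = []
    for _ in range(n_cycles):
        for _ in range(l_steps):
            # L module sees its own state, the current H state, and the input
            z_L = f_L(z_L, z_H, x)
            trace.append("L")
        # H module integrates the L state once at the end of each cycle
        z_H = f_H(z_H, z_L)
        trace.append("H")
    return z_H, trace


# Toy additive modules, purely for illustration
def toy_L(z_L: Vec, z_H: Vec, x: Vec) -> Vec:
    return [a + b + c for a, b, c in zip(z_L, z_H, x)]


def toy_H(z_H: Vec, z_L: Vec) -> Vec:
    return [a + b for a, b in zip(z_H, z_L)]


z_H, trace = hrm_recurrence([1.0], toy_L, toy_H)
print("".join(trace))  # LLLHLLLH
print(z_H)  # [18.0]
```

The trace makes the 2 × 3 schedule explicit: six L-module calls interleaved with two H-module updates.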
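
The byte-level tokenizer (vocabulary of 256) maps each UTF-8 byte to one token ID, so any Unicode text can be encoded without out-of-vocabulary tokens. Because tokens are bytes, the 2,048-token window holds 2,048 bytes, which is fewer than 2,048 characters for non-ASCII text. RoPE uses NTK-aware scaling (θ=10000.0, factor=1.0) for the 2,048-token context.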
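A minimal sketch of the two input-side pieces described above. `byte_tokenize` mirrors the encoding used in the Usage snippet. `ntk_scaled_base` assumes the common NTK-aware formulation, which multiplies the RoPE base by `factor ** (d / (d - 2))`; whether this port uses exactly that formula is an assumption, and `head_dim=64` is taken from the Usage snippet. All three function names are hypothetical. Under that formula, factor=1.0 leaves the base at θ=10000.0, i.e. standard RoPE frequencies.

```python
from typing import List


def byte_tokenize(text: str, max_length: int = 2048) -> List[int]:
    """Map text to UTF-8 byte IDs in [0, 255], truncated to max_length bytes."""
    return list(text.encode("utf-8", errors="replace")[:max_length])


def ntk_scaled_base(theta: float, factor: float, head_dim: int) -> float:
    """NTK-aware RoPE (assumed formulation): scale the base by factor ** (d / (d - 2))."""
    return theta * factor ** (head_dim / (head_dim - 2))


def rope_inv_freq(base: float, head_dim: int) -> List[float]:
    """Per-pair rotation frequencies base ** (-2i / d) for i in [0, d/2)."""
    return [base ** (-2 * i / head_dim) for i in range(head_dim // 2)]


ids = byte_tokenize("héllo")
print(ids)  # [104, 195, 169, 108, 108, 111]  ("é" is two UTF-8 bytes)
print(ntk_scaled_base(10000.0, 1.0, 64))  # 10000.0
freqs = rope_inv_freq(ntk_scaled_base(10000.0, 1.0, 64), 64)
print(len(freqs))  # 32
```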
## Usage
```python
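import torch
from train_hrm_text_pi import HrmTextClassifier

# Hyperparameters must match the trained checkpoint (12 heads x 64 dims = 768 hidden).
model = HrmTextClassifier(
    hidden_size=768,
    num_heads=12,
    head_dim=64,
    n_layers_H=3,
    n_layers_L=3,
)

state_dict = torch.load("pytorch_model.bin", map_location="cpu", weights_only=True)
# Strip the DDP "module." prefix from checkpoint keys if present
state_dict = {k.removeprefix("module."): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.eval()


@torch.no_grad()
def detect(text: str, max_length: int = 2048) -> int:
    """Return 0 for safe text, 1 for a prompt injection."""
    # Byte-level tokenization: each UTF-8 byte (0-255) is one token ID,
    # truncated to the 2,048-token context window.
    byte_ids = list(text.encode("utf-8", errors="replace")[:max_length])
    input_ids = torch.tensor([byte_ids], dtype=torch.long)
    attention_mask = torch.ones_like(input_ids)
    logits = model.inference(input_ids, attention_mask)
    return logits.argmax(-1).item()  # 0=safe, 1=injection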
```
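`detect` only sees the first 2,048 bytes, so an injection placed later in a long input is never scored. One hypothetical way to cover long inputs is to split the bytes into overlapping windows and flag the text if any window is classified as an injection. This is a sketch of that policy, not part of the released code: `byte_windows`, `detect_long`, the `classify` callable, and the 1,024-byte stride are all assumptions.

```python
from typing import Callable, List


def byte_windows(text: str, window: int = 2048, stride: int = 1024) -> List[List[int]]:
    """Split UTF-8 byte IDs into overlapping windows that each fit the context."""
    if not 0 < stride <= window:
        raise ValueError("stride must be in (0, window]")
    data = list(text.encode("utf-8", errors="replace"))
    if len(data) <= window:
        return [data]
    windows: List[List[int]] = []
    start = 0
    while True:
        windows.append(data[start:start + window])
        if start + window >= len(data):  # last window reaches the end
            break
        start += stride
    return windows


def detect_long(
    text: str,
    classify: Callable[[List[int]], int],
    window: int = 2048,
    stride: int = 1024,
) -> int:
    """Return 1 if any window is classified as an injection (1), else 0."""
    return int(any(classify(w) == 1 for w in byte_windows(text, window, stride)))


# Stand-in classifier for illustration: flags any window containing "!" (byte 33)
fake_classify = lambda ids: int(33 in ids)
print(detect_long("a" * 4000 + "!", fake_classify))  # 1
```

With the model loaded as in the Usage snippet, `classify` could wrap `model.inference` on a `[1, len(ids)]` tensor of byte IDs. Flagging on any window tends to raise recall at the cost of more false positives; averaging per-window logits is an alternative.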