---
license: mit
tags:
  - prompt-injection
  - hrm-text
  - hierarchical-reasoning-model
  - bordair-multimodal
  - security
---

# HRM-Text Prompt Injection Detector

**Parameters:** 46,206,722  
**Architecture:** HRM-Text (classification port) | d=768, H=3, L=3, cycles=2×3  
**Context window:** 2,048 tokens (NTK-scaled RoPE)  
**Training data:** Bordair/bordair-multimodal (503K samples, balanced 1:1)  

Evaluation on a stratified 10% holdout set:

| Metric | Value |
|--------|-------|
| Accuracy | 0.9893 |
| Precision | 0.9934 |
| Recall | 0.9838 |
| F1 | 0.9886 |
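
As a quick consistency check, F1 is the harmonic mean of precision and recall, so the reported value can be recomputed from the two rows above. The helper name `f1_score` is only illustrative and is not part of the training code.

```python
def f1_score(precision, recall):
    """Harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)

# Values copied from the table above.
precision, recall = 0.9934, 0.9838
print(round(f1_score(precision, recall), 4))  # 0.9886
```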

## Architecture

HRM-Text (arXiv:2506.21734) with a classification head. The model uses a recurrent cascade of two transformer modules (H and L) that exchange information across cycles:

- **L module** (3 layers, low-level): processes detailed token patterns
- **H module** (3 layers, high-level): integrates across cycles
- **Recurrence**: 3 L-steps per H-cycle and 2 H-cycles in total, giving 6 recurrent passes
- **Classification**: last-token pooling + LayerNorm + Linear(2)
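
The recurrence schedule can be sketched in plain Python. This is a minimal illustration, not the model's code: it assumes the update order described for HRM (the L module runs its steps conditioned on the current H state, then the H module updates once per cycle), and `hrm_recurrence`, `f_L`, and `f_H` are hypothetical stand-ins for the real transformer modules.

```python
def hrm_recurrence(x, f_L, f_H, z_H, z_L, h_cycles=2, l_steps=3):
    """Run the nested H/L loop and count how many L passes were made.

    f_L(z_L, z_H, x) and f_H(z_H, z_L) are placeholder update functions;
    in the real model they would be stacks of transformer layers.
    """
    l_passes = 0
    for _ in range(h_cycles):
        for _ in range(l_steps):
            # Low-level update, conditioned on the current high-level state.
            z_L = f_L(z_L, z_H, x)
            l_passes += 1
        # High-level update, once per cycle, from the latest low-level state.
        z_H = f_H(z_H, z_L)
    return z_H, l_passes

# Toy scalar updates, just to make the schedule observable.
toy_f_L = lambda z_L, z_H, x: z_L + 1
toy_f_H = lambda z_H, z_L: z_H + z_L

final_state, passes = hrm_recurrence(0, toy_f_L, toy_f_H, z_H=0, z_L=0)
print(passes)  # 6
```

With the defaults (`h_cycles=2`, `l_steps=3`) the inner loop runs 2 × 3 = 6 times, matching the count given in the list above.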

The byte-level tokenizer (vocab 256) works on raw UTF-8 bytes, so any input text can be encoded without out-of-vocabulary tokens. RoPE uses NTK-aware scaling (θ=10000.0, factor=1.0) for the 2,048-token context.
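
A small sketch of both points, under stated assumptions. `byte_tokenize` mirrors the encoding step used in the usage code and shows that one token is one byte, so non-ASCII characters consume several positions. `ntk_scaled_base` uses one common formulation of NTK-aware scaling (base multiplied by `factor ** (d / (d - 2))`); the model's own implementation may differ, and both function names are hypothetical.

```python
def byte_tokenize(text, max_length=2048):
    """Encode text as UTF-8 and return the byte values (0-255) as token ids."""
    return list(text.encode("utf-8", errors="replace")[:max_length])

def ntk_scaled_base(theta=10000.0, factor=1.0, head_dim=64):
    """Assumed NTK-aware rule: scale the RoPE base by factor ** (d / (d - 2))."""
    return theta * factor ** (head_dim / (head_dim - 2))

print(byte_tokenize("é"))  # [195, 169]
print(ntk_scaled_base())   # 10000.0
```

Under this formulation, `factor=1.0` leaves the base at θ=10000.0.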

## Usage
```python
import torch
from train_hrm_text_pi import HrmTextClassifier

model = HrmTextClassifier(
    hidden_size=768,
    num_heads=12,
    head_dim=64,
    n_layers_H=3,
    n_layers_L=3,
)
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
# Strip a leading DDP "module." prefix if present; other occurrences are left alone
state_dict = {
    (k[len("module."):] if k.startswith("module.") else k): v
    for k, v in state_dict.items()
}
model.load_state_dict(state_dict)
model.eval()

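@torch.no_grad()  # inference only, so skip autograd bookkeeping
def detect(text, max_length=2048):
    """Return 0 (safe) or 1 (injection) for a single string."""
    # One token per UTF-8 byte, truncated to the 2,048-token context window
    byte_ids = list(text.encode("utf-8", errors="replace")[:max_length])
    input_ids = torch.tensor([byte_ids], dtype=torch.long)
    attention_mask = torch.ones_like(input_ids)
    logits = model.inference(input_ids, attention_mask)
    pred = logits.argmax(-1).item()  # 0=safe, 1=injection
    return pred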
```
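
`detect` returns only the argmax class. If a tunable decision threshold is preferred, the two logits can be converted to a probability with a softmax. The sketch below is plain Python and assumes the logits come as a pair `[safe_logit, injection_logit]`, following the `0=safe, 1=injection` comment above; `injection_probability`, `is_injection`, and `threshold` are hypothetical names, not part of the released code.

```python
import math

def injection_probability(logits):
    """Softmax probability of class 1 given [safe_logit, injection_logit]."""
    safe_logit, injection_logit = logits
    # Subtract the max before exponentiating for numerical stability.
    m = max(safe_logit, injection_logit)
    e_safe = math.exp(safe_logit - m)
    e_inj = math.exp(injection_logit - m)
    return e_inj / (e_safe + e_inj)

def is_injection(logits, threshold=0.5):
    """Flag the input when the injection probability reaches the threshold."""
    return injection_probability(logits) >= threshold

print(injection_probability([0.0, 0.0]))  # 0.5
```

With a torch tensor of shape `(1, 2)`, `logits[0].tolist()` gives the pair expected here. Raising `threshold` above 0.5 trades recall for precision.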