SKwra committed
Commit 6741591 · verified · 1 Parent(s): 1cdd4de

Add model card

Files changed (1)
  1. README.md +17 -105
README.md CHANGED
@@ -7,129 +7,41 @@ tags:
  - gemma
  - ministral
  - qwen
- - safelens
- - steering-vectors
- - llm-agents
  arxiv: 2605.18882
  ---

- # SafeLens SAE Checkpoints

- Pre-trained **TopK Sparse Autoencoders (SAEs)** for diagnosing and correcting intrinsic tool-calling bias in LLM agents, as described in:

- > **To Call or Not to Call: Diagnosing Intrinsic Over-Calling Bias in LLM Agents**
- > Wei Shi, Ziheng Peng, Sihang Li, Xiting Wang, Xiang Wang, Mengnan Du, Na Zou
- > [arXiv:2605.18882](https://arxiv.org/abs/2605.18882)

- ---
-
- ## What are these checkpoints?
-
- Each checkpoint is a **TopK SAE** trained on residual stream activations at a specific layer of a base LLM. The SAE learns a sparse dictionary of features, among which we identify features encoding the "tool-call" vs. "request-for-info" decision boundary. These features are then used for:
-
- - **H1 Feature discovery**: isolating tool-call-aligned features via mean-diff & AUROC
- - **H2 Bias quantification**: fitting a logistic probe to measure intrinsic call offset β₀
- - **H3 Causal steering**: suppressing TC features / promoting RFI features to shift model decisions
- - **AMCS** (Adaptive Margin-Calibrated Steering): closed-form inference-time bias correction
-
- ---
-
- ## Available Checkpoints
-
- | Model | Layer | SAE Dict Size | k | Stage 1 Tokens | Stage 2 Tokens |
- |-------|-------|--------------|---|----------------|----------------|
- | gemma-3-1b-it | L17 | 9 216 | 128 | 50M | 5M |
- | gemma-3-4b-it | L29 | 20 480 | 128 | 50M | 5M |
- | gemma-4-E2B-it | L30 | 12 288 | 128 | 50M | 5M |
- | gemma-4-E4B-it | L30 | 20 480 | 128 | 50M | 5M |
- | Ministral-3-3B-Instruct-2512 | L21 | 24 576 | 128 | 50M | 5M |
- | Ministral-3-8B-Instruct-2512 | L31 | 32 768 | 128 | 50M | 5M |
- | Qwen3.5-4B | L25 | 20 480 | 128 | 50M | 5M |
- | Qwen3.5-9B | L25 | 32 768 | 128 | 50M | 5M |
-
- **Stage 1**: General-purpose SAE pre-training on 50M tokens from the base model's residual stream.
- **Stage 2**: Fine-tuned on 5M tool-calling-specific activations (When2Call benchmark data).
  All checkpoints use `bfloat16` precision.

- ---
-
- ## File Structure
-
- ```
- gemma-3-1b-it/
-   stage1/
-     gemma-3-1b-it-L17-d9216-50M-stage1.pt
-     gemma-3-1b-it-L17-d9216-50M-stage1_stats.json
-   stage2/
-     gemma-3-1b-it-L17-d9216-5M-stage2.pt
-     gemma-3-1b-it-L17-d9216-5M-stage2_stats.json
- ...
- ```
-
- ---
-
  ## Usage

- ### Load a checkpoint
-
  ```python
- import torch
  from huggingface_hub import hf_hub_download

- # Download a checkpoint
  ckpt_path = hf_hub_download(
      repo_id="SKwra/toolcalling-sae",
      filename="gemma-3-1b-it/stage2/gemma-3-1b-it-L17-d9216-5M-stage2.pt"
  )
-
- # Load (requires sae_model.py from the GitHub repo)
- from sae_model import TopKSAE
  sae = TopKSAE.load(ckpt_path, device="cuda")
  ```

- ### Encode activations
-
- ```python
- # activations: [batch, input_dim] residual stream tensor at the target layer
- latents = sae.encode(activations)  # [batch, dict_size] sparse activations
- reconstruction = sae.decode(latents)  # [batch, input_dim]
- ```
-
- ### Steer a feature
-
- ```python
- # Suppress feature 42 by 80% (strength=0.2 → nearly zero out)
- steered = sae.steer(activations, feature_idx=42, strength=0.2)
- ```
-
- ### SAEConfig fields
-
- ```python
- @dataclass
- class SAEConfig:
-     input_dim: int  # residual stream width of the base model
-     dict_size: int  # SAE dictionary size
-     k: int = 128  # TopK sparsity
-     device: str = "cuda"
-     dtype: str = "bfloat16"
- ```
-
- ---
-
- ## Citation
-
- ```bibtex
- @article{shi2025call,
-   title={To Call or Not to Call: Diagnosing Intrinsic Over-Calling Bias in LLM Agents},
-   author={Shi, Wei and Peng, Ziheng and Li, Sihang and Wang, Xiting and Wang, Xiang and Du, Mengnan and Zou, Na},
-   journal={arXiv preprint arXiv:2605.18882},
-   year={2025}
- }
- ```
-
- ---
-
- ## License
-
- Apache 2.0. See [LICENSE](https://github.com/your-repo/blob/main/LICENSE) for details.

  - gemma
  - ministral
  - qwen
  arxiv: 2605.18882
  ---

+ # toolcalling-sae

+ TopK Sparse Autoencoder checkpoints from [To Call or Not to Call: Diagnosing Intrinsic Over-Calling Bias in LLM Agents](https://arxiv.org/abs/2605.18882).

+ ## Checkpoints

+ | Model | Layer | Dict Size | k | Stage 1 | Stage 2 |
+ |-------|-------|-----------|---|---------|---------|
+ | gemma-3-1b-it | L17 | 9 216 | 128 | 50M tokens | 5M tokens |
+ | gemma-3-4b-it | L29 | 20 480 | 128 | 50M tokens | 5M tokens |
+ | gemma-4-E2B-it | L30 | 12 288 | 128 | 50M tokens | 5M tokens |
+ | gemma-4-E4B-it | L30 | 20 480 | 128 | 50M tokens | 5M tokens |
+ | Ministral-3-3B-Instruct-2512 | L21 | 24 576 | 128 | 50M tokens | 5M tokens |
+ | Ministral-3-8B-Instruct-2512 | L31 | 32 768 | 128 | 50M tokens | 5M tokens |
+ | Qwen3.5-4B | L25 | 20 480 | 128 | 50M tokens | 5M tokens |
+ | Qwen3.5-9B | L25 | 32 768 | 128 | 50M tokens | 5M tokens |

+ **Stage 1**: Pre-trained on [OpenWebText2](https://openwebtext2.readthedocs.io/).
+ **Stage 2**: Fine-tuned on tool-calling activations from the [When2Call](https://arxiv.org/abs/2605.18882) benchmark.
  All checkpoints use `bfloat16` precision.

  ## Usage

  ```python
  from huggingface_hub import hf_hub_download
+ from sae_model import TopKSAE

  ckpt_path = hf_hub_download(
      repo_id="SKwra/toolcalling-sae",
      filename="gemma-3-1b-it/stage2/gemma-3-1b-it-L17-d9216-5M-stage2.pt"
  )
  sae = TopKSAE.load(ckpt_path, device="cuda")
  ```

+ `sae_model.py` is included in this repo. Full code at [GitHub](https://github.com/SKURA502/agent-sae).
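
Note on checkpoint paths: the usage snippet hard-codes one `filename`. A minimal sketch of building that path from a table row follows. The pattern `{model}/stage{n}/{model}-L{layer}-d{dict_size}-{tokens}-stage{n}.pt` is inferred only from the `gemma-3-1b-it` path shown in the README, so whether the other models follow it is an assumption; check the repo's file listing before relying on it. `CheckpointSpec` is a hypothetical helper and is not part of `sae_model.py`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CheckpointSpec:
    """One row of the checkpoint table (hypothetical helper, not part of the repo)."""

    model: str      # e.g. "gemma-3-1b-it"
    layer: int      # e.g. 17 for "L17"
    dict_size: int  # e.g. 9216 for "9 216"

    def filename(self, stage: int) -> str:
        # Assumed naming pattern, inferred from the single path in the README:
        #   {model}/stage{n}/{model}-L{layer}-d{dict_size}-{tokens}-stage{n}.pt
        if stage not in (1, 2):
            raise ValueError("stage must be 1 or 2")
        # Token counts come from the table: 50M for stage 1, 5M for stage 2.
        tokens = "50M" if stage == 1 else "5M"
        return (
            f"{self.model}/stage{stage}/"
            f"{self.model}-L{self.layer}-d{self.dict_size}-{tokens}-stage{stage}.pt"
        )


spec = CheckpointSpec(model="gemma-3-1b-it", layer=17, dict_size=9216)
print(spec.filename(2))
```

The resulting string can be passed as the `filename` argument of `hf_hub_download` in the usage snippet.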