SKwra committed on
Commit 167a3e0 · verified · 1 Parent(s): 3996dd1

Add model card

Files changed (1)
  1. README.md +135 -0
README.md ADDED
@@ -0,0 +1,135 @@
+ ---
+ license: apache-2.0
+ tags:
+ - sparse-autoencoder
+ - mechanistic-interpretability
+ - tool-calling
+ - gemma
+ - ministral
+ - qwen
+ - safelens
+ - steering-vectors
+ - llm-agents
+ arxiv: 2605.18882
+ ---
+
+ # SafeLens SAE Checkpoints
+
+ Pre-trained **TopK Sparse Autoencoders (SAEs)** for diagnosing and correcting intrinsic tool-calling bias in LLM agents, as described in:
+
+ > **To Call or Not to Call: Diagnosing Intrinsic Over-Calling Bias in LLM Agents**
+ > Wei Shi, Ziheng Peng, Sihang Li, Xiting Wang, Xiang Wang, Mengnan Du, Na Zou
+ > [arXiv:2605.18882](https://arxiv.org/abs/2605.18882)
+
+ ---
+
+ ## What are these checkpoints?
+
+ Each checkpoint is a **TopK SAE** trained on residual stream activations at a specific layer of a base LLM. The SAE learns a sparse dictionary of features, some of which encode the decision boundary between "tool-call" (TC) and "request-for-info" (RFI) responses. These features are then used for:
+
+ - **H1 (feature discovery)**: isolating tool-call-aligned features via mean activation difference and AUROC
+ - **H2 (bias quantification)**: fitting a logistic probe to measure the intrinsic call offset β₀
+ - **H3 (causal steering)**: suppressing TC features or promoting RFI features to shift model decisions
+ - **AMCS** (Adaptive Margin-Calibrated Steering): closed-form inference-time bias correction
+
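As a rough illustration of the H1 scoring step, the sketch below scores a single SAE feature by its mean activation difference between tool-call and request-for-info examples and by its AUROC. The function names and the toy activations are hypothetical, and the paper's pipeline may differ in details such as tie handling or normalization.

```python
from statistics import mean


def mean_diff(tc_acts, rfi_acts):
    """Mean activation on tool-call examples minus mean on request-for-info examples."""
    return mean(tc_acts) - mean(rfi_acts)


def auroc(tc_acts, rfi_acts):
    """Probability that a random tool-call example activates the feature more
    strongly than a random request-for-info example (ties count as 0.5)."""
    wins = 0.0
    for t in tc_acts:
        for r in rfi_acts:
            if t > r:
                wins += 1.0
            elif t == r:
                wins += 0.5
    return wins / (len(tc_acts) * len(rfi_acts))


# Toy activations of one feature (hypothetical values, not from the checkpoints).
tc = [0.9, 0.7, 0.8, 0.0]
rfi = [0.0, 0.1, 0.0, 0.2]

print("mean diff:", mean_diff(tc, rfi))
print("AUROC:", auroc(tc, rfi))
```

A feature with a large positive mean difference and an AUROC close to 1 would be a candidate tool-call-aligned feature. An AUROC near 0.5 means the feature does not separate the two classes.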
+ ---
+
+ ## Available Checkpoints
+
+ | Model | Layer | SAE Dict Size | k | Stage 1 Tokens | Stage 2 Tokens |
+ |-------|-------|--------------|---|----------------|----------------|
+ | gemma-3-1b-it | L17 | 9 216 | 128 | 50M | 5M |
+ | gemma-3-4b-it | L29 | 20 480 | 128 | 50M | 5M |
+ | gemma-4-E2B-it | L30 | 12 288 | 128 | 50M | 5M |
+ | gemma-4-E4B-it | L30 | 20 480 | 128 | 50M | 5M |
+ | Ministral-3-3B-Instruct-2512 | L21 | 24 576 | 128 | 50M | 5M |
+ | Ministral-3-8B-Instruct-2512 | L31 | 32 768 | 128 | 50M | 5M |
+ | Qwen3.5-4B | L25 | 20 480 | 128 | 50M | 5M |
+ | Qwen3.5-9B | L25 | 32 768 | 128 | 50M | 5M |
+
+ - **Stage 1**: general-purpose SAE pre-training on 50M tokens from the base model's residual stream.
+ - **Stage 2**: fine-tuning on 5M tokens of tool-calling-specific activations (When2Call benchmark data).
+
+ All checkpoints use `bfloat16` precision.
+
+ ---
+
+ ## File Structure
+
+ ```
+ gemma-3-1b-it/
+   stage1/
+     gemma-3-1b-it-L17-d9216-50M-stage1.pt
+     gemma-3-1b-it-L17-d9216-50M-stage1_stats.json
+   stage2/
+     gemma-3-1b-it-L17-d9216-5M-stage2.pt
+     gemma-3-1b-it-L17-d9216-5M-stage2_stats.json
+ ...
+ ```
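The naming pattern in the listing above can be turned into a small path builder. This is a sketch only: `checkpoint_filename` is a hypothetical helper, and it assumes that every model follows the same `<model>-L<layer>-d<dict_size>-<tokens>-stage<n>` pattern as the gemma-3-1b-it example, which the listing does not show explicitly for the other models.

```python
def checkpoint_filename(model: str, layer: int, dict_size: int, stage: int) -> str:
    """Build the repo-relative path of a checkpoint file.

    Assumes Stage 1 files are tagged "50M" and Stage 2 files "5M", as in the
    gemma-3-1b-it example; other models are assumed to follow the same pattern.
    """
    if stage not in (1, 2):
        raise ValueError("stage must be 1 or 2")
    tokens = "50M" if stage == 1 else "5M"
    name = f"{model}-L{layer}-d{dict_size}-{tokens}-stage{stage}.pt"
    return f"{model}/stage{stage}/{name}"


def stats_filename(checkpoint_path: str) -> str:
    """Path of the companion *_stats.json file next to a .pt checkpoint."""
    return checkpoint_path[: -len(".pt")] + "_stats.json"


path = checkpoint_filename("gemma-3-1b-it", layer=17, dict_size=9216, stage=2)
print(path)
print(stats_filename(path))
```

The resulting string can be passed as the `filename` argument of `hf_hub_download`. Layer and dictionary size for each model come from the Available Checkpoints table.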
+
+ ---
+
+ ## Usage
+
+ ### Load a checkpoint
+
+ ```python
+ import torch
+ from huggingface_hub import hf_hub_download
+
+ # Download a checkpoint
+ ckpt_path = hf_hub_download(
+     repo_id="SKwra/toolcalling-sae",
+     filename="gemma-3-1b-it/stage2/gemma-3-1b-it-L17-d9216-5M-stage2.pt",
+ )
+
+ # Load (requires sae_model.py from the GitHub repo)
+ from sae_model import TopKSAE
+ sae = TopKSAE.load(ckpt_path, device="cuda")
+ ```
+
+ ### Encode activations
+
+ ```python
+ # activations: [batch, input_dim] residual stream tensor at the target layer
+ latents = sae.encode(activations)      # [batch, dict_size] sparse activations
+ reconstruction = sae.decode(latents)   # [batch, input_dim]
+ ```
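To make the shapes above concrete, here is a dependency-free sketch of what a TopK encode and decode step does for a single token: keep the `k` largest pre-activations, zero the rest, and rebuild the input as a weighted sum of decoder rows. It is an illustration under stated assumptions, not the code in `sae_model.py`; the real `TopKSAE` may also apply a ReLU, subtract a pre-encoder bias, or differ in other ways. All names and values below are hypothetical.

```python
def topk_encode(pre_acts, k):
    """Keep the k largest pre-activations and set all others to zero."""
    order = sorted(range(len(pre_acts)), key=lambda i: pre_acts[i], reverse=True)
    keep = set(order[:k])
    return [v if i in keep else 0.0 for i, v in enumerate(pre_acts)]


def decode(latents, decoder_rows, bias):
    """Reconstruct the input as bias + sum_i latents[i] * decoder_rows[i]."""
    out = list(bias)
    for z, row in zip(latents, decoder_rows):
        if z != 0.0:  # inactive features contribute nothing
            for j, w in enumerate(row):
                out[j] += z * w
    return out


# One token, dict_size = 4, input_dim = 2, k = 2 (toy sizes).
pre_acts = [0.1, 2.0, -0.5, 1.5]
decoder_rows = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, -1.0]]

latents = topk_encode(pre_acts, k=2)
recon = decode(latents, decoder_rows, bias=[0.0, 0.0])
print(latents)
print(recon)
```

With `k = 128` and the dictionary sizes in the table, only a small fraction of features is nonzero for any given token, which is what makes individual features practical to inspect.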
+
+ ### Steer a feature
+
+ ```python
+ # Suppress feature 42 by 80% (strength=0.2 scales it to 20% of its original activation)
+ steered = sae.steer(activations, feature_idx=42, strength=0.2)
+ ```
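One plausible reading of this call is that the chosen feature's contribution to the residual stream is rescaled by `strength` while everything else is left untouched. The sketch below implements that reading for a single activation vector. It is an assumption about the behaviour of `sae.steer`, not a description of its actual implementation, and `steer_activation` and its arguments are hypothetical names.

```python
def steer_activation(x, latent_value, decoder_row, strength):
    """Rescale one feature's contribution to the activation vector x.

    Before steering the feature contributes latent_value * decoder_row;
    afterwards it contributes strength * latent_value * decoder_row.
    """
    delta = (strength - 1.0) * latent_value
    return [xi + delta * di for xi, di in zip(x, decoder_row)]


x = [1.0, 2.0, 3.0]            # toy residual stream activation
decoder_row = [0.0, 1.0, 0.0]  # toy decoder direction of the steered feature
steered = steer_activation(x, latent_value=5.0, decoder_row=decoder_row, strength=0.2)
print(steered)
```

Under this reading, `strength=1.0` leaves the activation unchanged, `strength=0.0` removes the feature entirely, and values above 1 amplify it.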
+
+ ### SAEConfig fields
+
+ ```python
+ from dataclasses import dataclass
+
+ @dataclass
+ class SAEConfig:
+     input_dim: int            # residual stream width of the base model
+     dict_size: int            # SAE dictionary size
+     k: int = 128              # TopK sparsity (active features per token)
+     device: str = "cuda"
+     dtype: str = "bfloat16"
+ ```
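As a usage sketch, the fields above can be filled in for the gemma-3-1b-it checkpoint. The dataclass is repeated so the snippet runs on its own. Note that `input_dim=1152` is an assumption about the residual stream width of gemma-3-1b-it and should be confirmed against the base model's configuration; `dict_size` and `k` come from the Available Checkpoints table.

```python
from dataclasses import dataclass


@dataclass
class SAEConfig:
    input_dim: int
    dict_size: int
    k: int = 128
    device: str = "cuda"
    dtype: str = "bfloat16"


# input_dim=1152 is assumed for gemma-3-1b-it; verify it before relying on it.
cfg = SAEConfig(input_dim=1152, dict_size=9216, device="cpu")

expansion = cfg.dict_size / cfg.input_dim   # dictionary size relative to residual width
active_fraction = cfg.k / cfg.dict_size     # share of features active per token
print(cfg)
print(expansion, active_fraction)
```

If the assumed width is right, the dictionary is 8 times wider than the residual stream, and 128 of 9216 features (about 1.4%) are active per token.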
+
+ ---
+
+ ## Citation
+
+ ```bibtex
+ @article{shi2025call,
+   title={To Call or Not to Call: Diagnosing Intrinsic Over-Calling Bias in LLM Agents},
+   author={Shi, Wei and Peng, Ziheng and Li, Sihang and Wang, Xiting and Wang, Xiang and Du, Mengnan and Zou, Na},
+   journal={arXiv preprint arXiv:2605.18882},
+   year={2025}
+ }
+ ```
+
+ ---
+
+ ## License
+
+ Apache 2.0. See [LICENSE](https://github.com/your-repo/blob/main/LICENSE) for details.