SKwra commited on
Commit
89da923
·
verified ·
1 Parent(s): 167a3e0

Add SAE model definition

Browse files
Files changed (1) hide show
  1. sae_model.py +213 -0
sae_model.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SAE Model - TopK Sparse Autoencoder
3
+ """
4
+
5
+ import math
6
+ from dataclasses import dataclass
7
+ from typing import Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ @dataclass
15
+ class SAEConfig:
16
+ """SAE 配置"""
17
+ input_dim: int = 4096
18
+ dict_size: int = 32768
19
+ k: int = 128
20
+ device: str = "cuda"
21
+ dtype: str = "bfloat16"
22
+
23
+ def get_torch_dtype(self) -> torch.dtype:
24
+ dtype_map = {
25
+ "float32": torch.float32,
26
+ "float16": torch.float16,
27
+ "bfloat16": torch.bfloat16,
28
+ }
29
+ return dtype_map.get(self.dtype, torch.bfloat16)
30
+
31
+
32
+ class TopKSAE(nn.Module):
33
+ """
34
+ TopK Sparse Autoencoder
35
+ """
36
+ def __init__(self, config: SAEConfig):
37
+ super().__init__()
38
+ self.config = config
39
+
40
+ self.decoder = nn.Linear(config.dict_size, config.input_dim, bias=False)
41
+ self._normalize_decoder()
42
+ self.encoder = nn.Linear(config.input_dim, config.dict_size, bias=False)
43
+ self.encoder.weight.data = self.decoder.weight.data.T.clone()
44
+ self.pre_bias = nn.Parameter(torch.zeros(config.input_dim))
45
+
46
+ self.to(config.device)
47
+ self.to(config.get_torch_dtype())
48
+
49
+ def _normalize_decoder(self):
50
+ """归一化 decoder 权重(每列归一化)"""
51
+ with torch.no_grad():
52
+ self.decoder.weight.data = F.normalize(
53
+ self.decoder.weight.data, dim=0
54
+ )
55
+
56
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
57
+ """编码:input -> latent activations
58
+ Args:
59
+ x: [batch, input_dim]
60
+ Returns:
61
+ latent: [batch, dict_size] 稀疏激活
62
+ """
63
+ centered_x = x - self.pre_bias
64
+ pre_activation = self.encoder(centered_x)
65
+
66
+ topk_values, topk_indices = torch.topk(
67
+ pre_activation, k=self.config.k, dim=-1
68
+ )
69
+
70
+ latents = torch.zeros_like(pre_activation)
71
+ latents.scatter_(-1, topk_indices, F.relu(topk_values))
72
+
73
+ return latents
74
+
75
+ def decode(self, latents: torch.Tensor) -> torch.Tensor:
76
+ """解码:latent -> reconstruction
77
+ Args:
78
+ latent: [batch, dict_size]
79
+ Returns:
80
+ reconstruction: [batch, input_dim]
81
+ """
82
+ return self.decoder(latents) + self.pre_bias
83
+
84
+ def forward(
85
+ self,
86
+ x: torch.Tensor,
87
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
88
+ latents = self.encode(x)
89
+ x_hat = self.decode(latents)
90
+ return x_hat, latents
91
+
92
+ def compute_loss(
93
+ self,
94
+ x: torch.Tensor,
95
+ ) -> Tuple[torch.Tensor, dict]:
96
+ """计算损失
97
+ Args:
98
+ x: [batch, input_dim]
99
+ Returns:
100
+ (loss, loss_dict)
101
+ """
102
+ x_hat, latent = self.forward(x)
103
+ loss = (((x_hat - x) ** 2).mean(dim=-1) / (x**2).mean(dim=-1)).mean()
104
+ loss_dict = {
105
+ "mean_activation": latent[latent > 0].mean().item() if (latent > 0).any() else 0,
106
+ }
107
+ return loss, loss_dict
108
+
109
+ def get_feature_activations(
110
+ self,
111
+ x: torch.Tensor,
112
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
113
+ """获取特征激活及其索引
114
+ Args:
115
+ x: [batch, input_dim]
116
+ Returns:
117
+ values: [batch, k] TopK 激活值
118
+ indices: [batch, k] TopK 特征索引
119
+ """
120
+ centered_x = x - self.pre_bias
121
+ pre_activation = self.encoder(centered_x)
122
+ topk_values, topk_indices = torch.topk(
123
+ pre_activation, k=self.config.k, dim=-1
124
+ )
125
+ return F.relu(topk_values), topk_indices
126
+
127
+ def get_decoder_vectors(self, indices: torch.Tensor) -> torch.Tensor:
128
+ """获取指定特征的 decoder 向量
129
+ Args:
130
+ indices: [batch, k] 特征索引
131
+ Returns:
132
+ vectors: [batch, k, input_dim] decoder 向量
133
+ """
134
+ return self.decoder.weight[:, indices].permute(1, 2, 0)
135
+
136
+ def steer(
137
+ self,
138
+ x: torch.Tensor,
139
+ feature_idx: int,
140
+ strength: float,
141
+ ) -> torch.Tensor:
142
+ """在原始 activation 上直接加减 decoder 方向来 steer 指定特征。
143
+
144
+ h_new = h + (strength - 1) * z_i * d_i
145
+
146
+ 其中 z_i 是 feature_idx 的 TopK 编码激活值,d_i 是其 decoder 列向量。
147
+ 不走完整 encode→decode,避免 TopK 重建误差污染其他特征。
148
+ strength=1.0 时变化量为 0(真正的 baseline)。
149
+
150
+ Args:
151
+ x: [batch, input_dim] 输入激活
152
+ feature_idx: 要 steering 的特征索引
153
+ strength: 目标特征缩放系数(1.0 不变,>1 增强,<1 抑制,0 完全消除)
154
+ Returns:
155
+ steered_x: [batch, input_dim] 调整后的激活
156
+ """
157
+ with torch.no_grad():
158
+ latents = self.encode(x) # [batch, dict_size]
159
+ z_i = latents[:, feature_idx] # [batch]
160
+ d_i = self.decoder.weight[:, feature_idx] # [input_dim]
161
+ delta = (strength - 1) * z_i.unsqueeze(-1) * d_i.unsqueeze(0)
162
+ return x + delta
163
+
164
+ def steer_multi(
165
+ self,
166
+ x: torch.Tensor,
167
+ feature_indices: list,
168
+ strengths: list,
169
+ ) -> torch.Tensor:
170
+ """在原始 activation 上直接加减 decoder 方向来同时 steer 多个特征。
171
+
172
+ h_new = h + sum_i (strength_i - 1) * z_i * d_i
173
+
174
+ 不走完整 encode→decode,避免 TopK 重建误差污染其他特征。
175
+ strength=1.0 时对应特征的变化量为 0(真正的 baseline)。
176
+
177
+ Args:
178
+ x: [batch, input_dim] 输入激活
179
+ feature_indices: 要 steering 的特征索引列表
180
+ strengths: 对应的缩放系数列表
181
+ Returns:
182
+ steered_x: [batch, input_dim] 调整后的激活
183
+ """
184
+ with torch.no_grad():
185
+ latents = self.encode(x) # [batch, dict_size]
186
+ delta = torch.zeros_like(x)
187
+ for feat_idx, strength in zip(feature_indices, strengths):
188
+ z_i = latents[:, feat_idx] # [batch]
189
+ d_i = self.decoder.weight[:, feat_idx] # [input_dim]
190
+ delta += (strength - 1) * z_i.unsqueeze(-1) * d_i.unsqueeze(0)
191
+ return x + delta
192
+
193
+ def save(self, path: str):
194
+ """保存模型"""
195
+ torch.save({
196
+ "config": self.config,
197
+ "state_dict": self.state_dict(),
198
+ }, path)
199
+
200
+ @classmethod
201
+ def load(cls, path: str, device: str = "cuda") -> "TopKSAE":
202
+ """加载模型"""
203
+ checkpoint = torch.load(path, map_location=device, weights_only=False)
204
+ config = checkpoint["config"]
205
+ config.device = device
206
+
207
+ model = cls(config)
208
+ model.load_state_dict(checkpoint["state_dict"])
209
+ model.to(config.get_torch_dtype())
210
+
211
+ return model
212
+
213
+