LoganResearch committed on
Commit
6b8163e
·
verified ·
1 Parent(s): ef7c3e0

add universal inference loader — works with all probes

Browse files
Files changed (1) hide show
  1. inference.py +342 -0
inference.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ CF-HoT Universal Probe Loader
4
+
5
+ Load any probe from this repo and run it on a model's hidden states.
6
+ Works with all suppression probes (LLaMA 8B) and cognitive enhancement
7
+ probes (Qwen, Mamba, Mistral).
8
+
9
+ Usage:
10
+ python inference.py --probe suppression/hedging_168x
11
+ python inference.py --probe cognitive/mistral/depth
12
+ python inference.py --probe suppression/repetition_125x --prompt "Tell me about AI"
13
+ """
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import argparse
18
+ import os
19
+ import glob
20
+
21
+
22
+ # ─── Architecture definitions ───────────────────────────────────────
23
+
24
class FiberProjection(nn.Module):
    """Projects hidden states from multiple layers into fiber space.

    One linear projection per tapped layer; a learned softmax mixture
    over the layers combines the projections into a single vector.
    """

    def __init__(self, hidden_dim, fiber_dim=16, num_layers=3, bias=True):
        super().__init__()
        # Start from a uniform mixture across the tapped layers.
        self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.projections = nn.ModuleList(
            nn.Linear(hidden_dim, fiber_dim, bias=bias)
            for _ in range(num_layers)
        )

    def forward(self, hidden_states_list):
        mix = torch.softmax(self.layer_weights, dim=0)
        # Cast to float32 before projecting (hidden states may arrive in fp16).
        terms = [
            weight * proj(states.float())
            for weight, states, proj in zip(mix, hidden_states_list, self.projections)
        ]
        return sum(terms)
38
+
39
+
40
class ProbeHead(nn.Module):
    """Classifies fiber-space vectors into behavioral risk scores."""

    def __init__(self, fiber_dim=16, hidden_dim=64):
        super().__init__()
        # Two GELU hidden layers feeding a single-logit output.
        layers = [
            nn.Linear(fiber_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        # Sigmoid squashes the logit into a [0, 1] score.
        logits = self.classifier(x)
        return torch.sigmoid(logits)
52
+
53
+
54
class RiskPredictor(nn.Module):
    """Full risk predictor (used by repetition_125x). All-layer version."""

    def __init__(self, hidden_dim=4096, fiber_dim=16, n_layers=32):
        super().__init__()
        # Uniform initial mixture over every layer of the base model.
        self.layer_weights = nn.Parameter(torch.ones(n_layers) / n_layers)
        self.fiber_projs = nn.ModuleList(
            nn.Linear(hidden_dim, fiber_dim, bias=False)
            for _ in range(n_layers)
        )
        self.predictor = nn.Sequential(
            nn.Linear(fiber_dim, 64), nn.GELU(),
            nn.Linear(64, 64), nn.GELU(),
            nn.Linear(64, 1),
        )

    def forward(self, hidden_states_list):
        mix = torch.softmax(self.layer_weights, dim=0)
        # Weighted sum of per-layer fiber projections (fp32 for stability).
        contributions = [
            weight * proj(states.float())
            for weight, states, proj in zip(mix, hidden_states_list, self.fiber_projs)
        ]
        fiber = sum(contributions)
        return torch.sigmoid(self.predictor(fiber))
74
+
75
+
76
+ # ─── Loader ─────────────────────────────────────────────────────────
77
+
78
# Base models and their configs
#
# Maps each architecture key (as returned by detect_architecture) to the
# Hugging Face model id, hidden-state width, transformer layer count, and
# the default layer indices tapped by the 3-layer fiber probes.
MODEL_CONFIGS = {
    "llama": {
        "model_id": "meta-llama/Llama-3.1-8B-Instruct",
        "hidden_dim": 4096,
        "n_layers": 32,
        "probe_layers": [10, 20, 30],  # default for 3-layer probes
    },
    "qwen": {
        "model_id": "Qwen/Qwen2.5-7B-Instruct",
        "hidden_dim": 3584,  # Qwen2.5-7B uses a narrower hidden size
        "n_layers": 28,
        "probe_layers": [9, 18, 27],
    },
    "mamba": {
        "model_id": "tiiuae/falcon-mamba-7b-instruct",
        "hidden_dim": 4096,
        "n_layers": 64,  # state-space model: many thin layers
        "probe_layers": [16, 32, 48],
    },
    "mistral": {
        "model_id": "mistralai/Mistral-7B-Instruct-v0.3",
        "hidden_dim": 4096,
        "n_layers": 32,
        "probe_layers": [8, 16, 24],
    },
}
105
+
106
+
107
def detect_probe_type(probe_path):
    """Auto-detect what kind of probe checkpoint this is.

    Returns one of 'risk_predictor', 'suppression', 'cognitive', or
    'unknown', based purely on which files sit in *probe_path*.
    """
    entries = os.listdir(probe_path) if os.path.isdir(probe_path) else []

    # Repetition uses risk_predictor.pt
    if "risk_predictor.pt" in entries:
        return "risk_predictor"

    has_head = any(name.endswith("_head.pt") for name in entries)
    if has_head:
        # Suppression probes ship a separate fiber_proj.pt alongside the
        # head; cognitive probes bundle everything into the head file.
        return "suppression" if "fiber_proj.pt" in entries else "cognitive"

    return "unknown"
125
+
126
+
127
def detect_architecture(probe_path):
    """Detect which base model architecture a probe targets.

    The family is inferred from a marker substring in the probe's path;
    anything without a marker is treated as a LLaMA suppression probe.
    """
    lowered = probe_path.lower()
    for family in ("qwen", "mamba", "mistral"):
        if family in lowered:
            return family
    return "llama"  # suppression probes default to LLaMA
138
+
139
+
140
def load_probe(probe_path, device="cuda"):
    """
    Load any CF-HoT probe from a directory.

    Args:
        probe_path: directory holding the probe checkpoint file(s).
        device: torch device string the weights are loaded onto.

    Returns:
        dict with keys:
            - 'type': str ('risk_predictor', 'suppression', or 'cognitive')
            - 'arch': str ('llama', 'qwen', 'mamba', 'mistral')
            - 'config': dict (model config)
            - 'fiber': FiberProjection or None
            - 'head': ProbeHead or None
            - 'risk_predictor': RiskPredictor or None
            - 'probe_layers': list[int]
            - 'metadata': dict (step, separation, etc.)

    Raises:
        ValueError: if probe_path does not contain a recognizable
            checkpoint layout.
    """
    probe_type = detect_probe_type(probe_path)
    if probe_type == "unknown":
        # Fail fast instead of returning a half-empty dict that would
        # crash later in score_hidden_states with a confusing error.
        raise ValueError(f"No recognizable probe checkpoint in: {probe_path}")

    arch = detect_architecture(probe_path)
    config = MODEL_CONFIGS[arch]

    result = {
        "type": probe_type,
        "arch": arch,
        "config": config,
        "fiber": None,
        "head": None,
        "risk_predictor": None,
        "probe_layers": config["probe_layers"],
        "metadata": {},
    }

    if probe_type == "risk_predictor":
        _load_risk_predictor_probe(probe_path, config, device, result)
    elif probe_type == "suppression":
        _load_suppression_probe(probe_path, config, device, result)
    else:  # "cognitive"
        _load_cognitive_probe(probe_path, config, device, result)

    return result


def _load_ckpt(path, device):
    """Load one checkpoint file onto *device*.

    NOTE(review): weights_only=False unpickles arbitrary objects — only
    load checkpoints from trusted sources.
    """
    return torch.load(path, map_location=device, weights_only=False)


def _load_risk_predictor_probe(probe_path, config, device, result):
    """Populate *result* with the all-layer RiskPredictor (repetition probe)."""
    ckpt = _load_ckpt(os.path.join(probe_path, "risk_predictor.pt"), device)
    rp = RiskPredictor(
        hidden_dim=config["hidden_dim"],
        fiber_dim=16,
        n_layers=config["n_layers"],
    ).to(device)
    # Keys are nested under 'risk_predictor.*'
    state = {k.replace("risk_predictor.", ""): v
             for k, v in ckpt.items() if k.startswith("risk_predictor.")}
    rp.load_state_dict(state)
    rp.eval()
    result["risk_predictor"] = rp
    # This probe taps every layer, not just the 3-layer defaults.
    result["probe_layers"] = list(range(config["n_layers"]))
    if "step" in ckpt:
        result["metadata"]["step"] = ckpt["step"]


def _load_suppression_probe(probe_path, config, device, result):
    """Populate *result* from separate head + fiber_proj checkpoint files."""
    head_file = [f for f in os.listdir(probe_path) if f.endswith("_head.pt")][0]
    head_ckpt = _load_ckpt(os.path.join(probe_path, head_file), device)
    fiber_ckpt = _load_ckpt(os.path.join(probe_path, "fiber_proj.pt"), device)

    # Detect bias from the checkpoint so the module shape matches the weights.
    has_bias = any("bias" in k for k in fiber_ckpt.keys())

    fiber = FiberProjection(
        hidden_dim=config["hidden_dim"], fiber_dim=16,
        num_layers=3, bias=has_bias,
    ).to(device)
    fiber.load_state_dict(fiber_ckpt)
    fiber.eval()

    head = ProbeHead(fiber_dim=16, hidden_dim=64).to(device)
    head.load_state_dict(head_ckpt)
    head.eval()

    result["fiber"] = fiber
    result["head"] = head


def _load_cognitive_probe(probe_path, config, device, result):
    """Populate *result* from a single bundled cognitive checkpoint file."""
    head_file = [f for f in os.listdir(probe_path) if f.endswith("_head.pt")][0]
    ckpt = _load_ckpt(os.path.join(probe_path, head_file), device)

    # Extract metadata stored alongside the weights.
    for key in ["step", "separation", "loss", "probe_name",
                "hidden_dim", "probe_layers", "architecture"]:
        if key in ckpt:
            result["metadata"][key] = ckpt[key]

    # Override probe_layers if stored in checkpoint
    if "probe_layers" in ckpt:
        result["probe_layers"] = ckpt["probe_layers"]

    # Detect hidden_dim / bias from the checkpoint itself when present.
    hidden_dim = ckpt.get("hidden_dim", config["hidden_dim"])
    has_bias = any("bias" in k for k in ckpt if "fiber_projection" in k)

    fiber = FiberProjection(
        hidden_dim=hidden_dim, fiber_dim=16,
        num_layers=3, bias=has_bias,
    ).to(device)
    fiber_state = {k.replace("fiber_projection.", ""): v
                   for k, v in ckpt.items() if k.startswith("fiber_projection.")}
    fiber.load_state_dict(fiber_state)
    fiber.eval()

    head = ProbeHead(fiber_dim=16, hidden_dim=64).to(device)
    # Cognitive probes use either 'classifier' or 'net' naming.
    head_state = {}
    for k, v in ckpt.items():
        if k.startswith("head_state."):
            clean = k.replace("head_state.", "")
            # Normalize 'net.*' to 'classifier.*'
            clean = clean.replace("net.", "classifier.")
            head_state[clean] = v
    head.load_state_dict(head_state)
    head.eval()

    result["fiber"] = fiber
    result["head"] = head
257
+
258
+
259
def score_hidden_states(probe, hidden_states, position=-1):
    """
    Score hidden states using a loaded probe.

    Args:
        probe: dict returned by load_probe()
        hidden_states: tuple of tensors from model(output_hidden_states=True)
        position: token position to score (default: last token)

    Returns:
        float: risk/behavioral score between 0 and 1
    """
    # Both probe families index hidden_states via probe_layers: for
    # risk_predictor probes, load_probe sets probe_layers to every layer
    # index; for 3-layer probes it holds the tapped layer indices.
    # (The previous version enumerated range(len(hidden_states)) with a
    # redundant, always-true bound check in the risk_predictor branch.)
    layers = probe["probe_layers"]
    hs = [hidden_states[i][:, position, :] for i in layers]

    with torch.no_grad():
        if probe["type"] == "risk_predictor":
            return probe["risk_predictor"](hs).item()
        fiber_vec = probe["fiber"](hs)
        return probe["head"](fiber_vec).item()
283
+
284
+
285
+ # ─── CLI demo ───────────────────────────────────────────────────────
286
+
287
def main():
    """CLI entry point: load a probe, print its info, and (unless
    --info-only) run the base model on a prompt and score it."""
    parser = argparse.ArgumentParser(description="CF-HoT Probe Inference")
    parser.add_argument("--probe", required=True,
                        help="Path to probe directory (e.g. suppression/hedging_168x)")
    parser.add_argument("--prompt", default="Can you explain quantum computing?",
                        help="Text prompt to analyze")
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--info-only", action="store_true",
                        help="Just print probe info, don't load base model")
    opts = parser.parse_args()

    print(f"Loading probe from: {opts.probe}")
    probe = load_probe(opts.probe, device=opts.device)

    # Summarize the probe before touching the (heavy) base model.
    print(f"  Type: {probe['type']}")
    print(f"  Architecture: {probe['arch']}")
    print(f"  Base model: {probe['config']['model_id']}")
    print(f"  Probe layers: {probe['probe_layers']}")
    for key, val in probe["metadata"].items():
        print(f"    {key}: {val}")

    if opts.info_only:
        return

    # Deferred import keeps --info-only usable without transformers loaded.
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    model_id = probe["config"]["model_id"]
    print(f"\nLoading {model_id}...")

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    quant_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=quant_cfg,
        device_map="auto",
        output_hidden_states=True,
    )
    model.eval()

    # Tokenize the prompt, run one forward pass, and score the last token.
    encoded = tokenizer(opts.prompt, return_tensors="pt").to(opts.device)
    with torch.no_grad():
        outputs = model(**encoded, output_hidden_states=True)

    score = score_hidden_states(probe, outputs.hidden_states)
    print(f"\nPrompt: {opts.prompt}")
    print(f"Score: {score:.4f}")
    print(f"  (>0.5 = behavioral pattern detected, <0.5 = normal)")


if __name__ == "__main__":
    main()