| |
| """ |
| CF-HoT Universal Probe Loader |
| |
| Load any probe from this repo and run it on a model's hidden states. |
| Works with all suppression probes (LLaMA 8B) and cognitive enhancement |
| probes (Qwen, Mamba, Mistral). |
| |
| Usage: |
| python inference.py --probe suppression/hedging_168x |
| python inference.py --probe cognitive/mistral/depth |
| python inference.py --probe suppression/repetition_125x --prompt "Tell me about AI" |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import argparse |
| import os |
| import glob |
|
|
|
|
| |
|
|
class FiberProjection(nn.Module):
    """Blend per-layer linear projections into one fiber-space vector.

    Holds one ``nn.Linear(hidden_dim, fiber_dim)`` per probed layer together
    with one learnable mixing logit per layer.
    """

    def __init__(self, hidden_dim, fiber_dim=16, num_layers=3, bias=True):
        super().__init__()
        # All logits start equal, so the softmax mix is uniform at init.
        initial_logits = torch.ones(num_layers) / num_layers
        self.layer_weights = nn.Parameter(initial_logits)

        per_layer = []
        for _ in range(num_layers):
            per_layer.append(nn.Linear(hidden_dim, fiber_dim, bias=bias))
        self.projections = nn.ModuleList(per_layer)

    def forward(self, hidden_states_list):
        """Return the softmax-weighted sum of the projected hidden states.

        Args:
            hidden_states_list: sequence of tensors, one per projection; each
                is cast to float32 before being projected. ``zip`` stops at
                the shortest of the weights, inputs and projections.

        Returns:
            The accumulated tensor, or the integer ``0`` when nothing was
            combined (same as the builtin ``sum`` over an empty iterable).
        """
        mix = torch.softmax(self.layer_weights, dim=0)
        combined = 0
        for weight, hidden, projection in zip(mix, hidden_states_list, self.projections):
            combined = combined + weight * projection(hidden.float())
        return combined
|
|
|
|
class ProbeHead(nn.Module):
    """Small MLP that maps a fiber-space vector to a score in (0, 1).

    The network is Linear -> GELU -> Linear -> GELU -> Linear(…, 1), stored
    under the attribute ``classifier``; ``forward`` applies a sigmoid to its
    single output.
    """

    def __init__(self, fiber_dim=16, hidden_dim=64):
        super().__init__()
        stages = [
            nn.Linear(fiber_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.classifier = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``sigmoid(classifier(x))``; the last dimension has size 1."""
        logits = self.classifier(x)
        return logits.sigmoid()
|
|
|
|
class RiskPredictor(nn.Module):
    """All-layer risk predictor (the variant used by repetition_125x).

    Combines a bias-free linear projection per layer, mixed by softmaxed
    learnable logits, with a three-layer MLP whose output is passed through
    a sigmoid.
    """

    def __init__(self, hidden_dim=4096, fiber_dim=16, n_layers=32):
        super().__init__()
        # Equal logits: every layer contributes the same weight at init.
        self.layer_weights = nn.Parameter(torch.ones(n_layers) / n_layers)

        layer_projections = []
        for _ in range(n_layers):
            layer_projections.append(nn.Linear(hidden_dim, fiber_dim, bias=False))
        self.fiber_projs = nn.ModuleList(layer_projections)

        mlp_stages = [
            nn.Linear(fiber_dim, 64),
            nn.GELU(),
            nn.Linear(64, 64),
            nn.GELU(),
            nn.Linear(64, 1),
        ]
        self.predictor = nn.Sequential(*mlp_stages)

    def forward(self, hidden_states_list):
        """Score a list of per-layer hidden states.

        Each tensor is cast to float32, projected by the matching entry of
        ``fiber_projs`` and scaled by its softmax weight. ``zip`` stops at the
        shortest sequence, so surplus hidden states are ignored.
        """
        mix = torch.softmax(self.layer_weights, dim=0)
        fiber = 0
        for weight, hidden, projection in zip(mix, hidden_states_list, self.fiber_projs):
            fiber = fiber + weight * projection(hidden.float())
        logits = self.predictor(fiber)
        return logits.sigmoid()
|
|
|
|
| |
|
|
| |
# Per-architecture settings, keyed by the short architecture name.
# Each row is (name, model_id, hidden_dim, n_layers, probe_layers).
MODEL_CONFIGS = {
    name: {
        "model_id": model_id,
        "hidden_dim": hidden_dim,
        "n_layers": n_layers,
        "probe_layers": probe_layers,
    }
    for name, model_id, hidden_dim, n_layers, probe_layers in (
        ("llama", "meta-llama/Llama-3.1-8B-Instruct", 4096, 32, [10, 20, 30]),
        ("qwen", "Qwen/Qwen2.5-7B-Instruct", 3584, 28, [9, 18, 27]),
        ("mamba", "tiiuae/falcon-mamba-7b-instruct", 4096, 64, [16, 32, 48]),
        ("mistral", "mistralai/Mistral-7B-Instruct-v0.3", 4096, 32, [8, 16, 24]),
    )
}
|
|
|
|
def detect_probe_type(probe_path):
    """Classify a probe directory by the checkpoint files it contains.

    Returns:
        ``"risk_predictor"`` if ``risk_predictor.pt`` is present; otherwise
        ``"suppression"`` when a ``*_head.pt`` file sits next to
        ``fiber_proj.pt``, ``"cognitive"`` when a ``*_head.pt`` file exists
        without ``fiber_proj.pt``, and ``"unknown"`` in every other case
        (including when *probe_path* is not a directory).
    """
    if not os.path.isdir(probe_path):
        return "unknown"

    names = set(os.listdir(probe_path))
    if "risk_predictor.pt" in names:
        return "risk_predictor"

    if not any(name.endswith("_head.pt") for name in names):
        return "unknown"

    # A head file exists: a separate fiber file marks the suppression layout.
    return "suppression" if "fiber_proj.pt" in names else "cognitive"
|
|
|
|
def detect_architecture(probe_path):
    """Guess the base-model architecture from the probe path.

    The lower-cased path is searched for ``"qwen"``, ``"mamba"`` and
    ``"mistral"`` in that order; the first match wins. Paths containing none
    of them are treated as ``"llama"``.
    """
    lowered = probe_path.lower()
    for candidate in ("qwen", "mamba", "mistral"):
        if candidate in lowered:
            return candidate
    return "llama"
|
|
|
|
def _load_checkpoint(path, device):
    """Load a checkpoint file onto *device*."""
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # Python objects, so only point this at checkpoints from a trusted source.
    return torch.load(path, map_location=device, weights_only=False)


def _strip_prefix(state, prefix):
    """Return the entries of *state* whose key starts with *prefix*, prefix removed."""
    # Slice instead of str.replace so only the leading prefix is dropped.
    return {key[len(prefix):]: value
            for key, value in state.items() if key.startswith(prefix)}


def _find_head_file(probe_path):
    """Return the ``*_head.pt`` file name in *probe_path* (alphabetically first)."""
    # os.listdir order is arbitrary; sorting makes the choice reproducible.
    return sorted(f for f in os.listdir(probe_path) if f.endswith("_head.pt"))[0]


def load_probe(probe_path, device="cuda"):
    """
    Load any CF-HoT probe from a directory.

    The probe type comes from detect_probe_type() and the architecture from
    detect_architecture(); the matching MODEL_CONFIGS entry supplies the
    model dimensions. All loaded modules are moved to *device* and put in
    eval mode.

    Args:
        probe_path: directory containing the probe checkpoint file(s).
        device: device passed to ``torch.load(map_location=...)`` and ``.to()``.

    Returns:
        dict with keys:
            - 'type': str ('risk_predictor', 'suppression', 'cognitive' or 'unknown')
            - 'arch': str ('llama', 'qwen', 'mamba', 'mistral')
            - 'config': dict (the MODEL_CONFIGS entry for 'arch')
            - 'fiber': FiberProjection or None
            - 'head': ProbeHead or None
            - 'risk_predictor': RiskPredictor or None
            - 'probe_layers': list[int] (or whatever the checkpoint stores
              under 'probe_layers' for cognitive probes)
            - 'metadata': dict (step, separation, etc. when present)

        For type 'unknown' nothing is loaded and the three module entries
        stay None.
    """
    probe_type = detect_probe_type(probe_path)
    arch = detect_architecture(probe_path)
    config = MODEL_CONFIGS[arch]

    result = {
        "type": probe_type,
        "arch": arch,
        "config": config,
        "fiber": None,
        "head": None,
        "risk_predictor": None,
        # Copy so a caller editing the list cannot alter MODEL_CONFIGS.
        "probe_layers": list(config["probe_layers"]),
        "metadata": {},
    }

    if probe_type == "risk_predictor":
        ckpt = _load_checkpoint(os.path.join(probe_path, "risk_predictor.pt"), device)
        rp = RiskPredictor(
            hidden_dim=config["hidden_dim"],
            fiber_dim=16,
            n_layers=config["n_layers"],
        ).to(device)
        # The predictor's weights are stored under a "risk_predictor." prefix.
        rp.load_state_dict(_strip_prefix(ckpt, "risk_predictor."))
        rp.eval()
        result["risk_predictor"] = rp
        # This predictor has one projection per model layer.
        result["probe_layers"] = list(range(config["n_layers"]))
        if "step" in ckpt:
            result["metadata"]["step"] = ckpt["step"]

    elif probe_type == "suppression":
        # Two files: a bare head state dict and a bare fiber state dict.
        head_ckpt = _load_checkpoint(
            os.path.join(probe_path, _find_head_file(probe_path)), device)
        fiber_ckpt = _load_checkpoint(
            os.path.join(probe_path, "fiber_proj.pt"), device)

        # Build the projections with or without bias to match the checkpoint.
        has_bias = any("bias" in k for k in fiber_ckpt)

        fiber = FiberProjection(
            hidden_dim=config["hidden_dim"], fiber_dim=16,
            num_layers=3, bias=has_bias,
        ).to(device)
        fiber.load_state_dict(fiber_ckpt)
        fiber.eval()

        head = ProbeHead(fiber_dim=16, hidden_dim=64).to(device)
        head.load_state_dict(head_ckpt)
        head.eval()

        result["fiber"] = fiber
        result["head"] = head

    elif probe_type == "cognitive":
        # One file holding metadata plus "fiber_projection." / "head_state." weights.
        ckpt = _load_checkpoint(
            os.path.join(probe_path, _find_head_file(probe_path)), device)

        for key in ["step", "separation", "loss", "probe_name",
                    "hidden_dim", "probe_layers", "architecture"]:
            if key in ckpt:
                result["metadata"][key] = ckpt[key]

        # Values stored in the checkpoint take precedence over MODEL_CONFIGS.
        if "probe_layers" in ckpt:
            result["probe_layers"] = ckpt["probe_layers"]
        hidden_dim = ckpt.get("hidden_dim", config["hidden_dim"])

        fiber_state = _strip_prefix(ckpt, "fiber_projection.")
        # Decide on bias from the keys that will actually be loaded.
        has_bias = any("bias" in k for k in fiber_state)

        fiber = FiberProjection(
            hidden_dim=hidden_dim, fiber_dim=16,
            num_layers=3, bias=has_bias,
        ).to(device)
        fiber.load_state_dict(fiber_state)
        fiber.eval()

        head = ProbeHead(fiber_dim=16, hidden_dim=64).to(device)
        head_state = {}
        for key, value in _strip_prefix(ckpt, "head_state.").items():
            # Checkpoints name the MLP "net"; ProbeHead calls it "classifier".
            if key.startswith("net."):
                key = "classifier." + key[len("net."):]
            head_state[key] = value
        head.load_state_dict(head_state)
        head.eval()

        result["fiber"] = fiber
        result["head"] = head

    return result
|
|
|
|
def score_hidden_states(probe, hidden_states, position=-1):
    """
    Score hidden states using a loaded probe.

    Args:
        probe: dict returned by load_probe()
        hidden_states: tuple of tensors from model(output_hidden_states=True);
            each is indexed as ``[:, position, :]``
        position: token position to score (default: last token)

    Returns:
        float: the probe output converted with ``.item()``, so the selected
        slice must reduce to a single value (i.e. a batch of one).

    Raises:
        ValueError: if the probe carries no loaded module for its type
            (e.g. a probe of type 'unknown').
    """
    layers = probe["probe_layers"]

    if probe["type"] == "risk_predictor":
        predictor = probe["risk_predictor"]
        if predictor is None:
            raise ValueError("probe has type 'risk_predictor' but no risk predictor is loaded")
        # Skip layer indices the model did not return instead of raising IndexError.
        hs = [hidden_states[i][:, position, :]
              for i in layers if i < len(hidden_states)]
        with torch.no_grad():
            return predictor(hs).item()

    fiber = probe["fiber"]
    head = probe["head"]
    if fiber is None or head is None:
        raise ValueError(
            f"probe of type {probe['type']!r} has no fiber/head modules loaded")

    hs = [hidden_states[i][:, position, :] for i in layers]
    with torch.no_grad():
        fiber_vec = fiber(hs)
        return head(fiber_vec).item()
|
|
|
|
| |
|
|
def main():
    """Command-line entry point: load a probe, print its info, score a prompt.

    With ``--info-only`` the function returns after printing the probe
    summary. Otherwise it loads the probe's base model in 4-bit, runs the
    prompt through it once and prints the probe score for the last token.

    Raises:
        SystemExit: if the probe directory holds no recognizable checkpoint
            and a score was requested.
    """
    parser = argparse.ArgumentParser(description="CF-HoT Probe Inference")
    parser.add_argument("--probe", required=True,
                        help="Path to probe directory (e.g. suppression/hedging_168x)")
    parser.add_argument("--prompt", default="Can you explain quantum computing?",
                        help="Text prompt to analyze")
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--info-only", action="store_true",
                        help="Just print probe info, don't load base model")
    args = parser.parse_args()

    print(f"Loading probe from: {args.probe}")
    probe = load_probe(args.probe, device=args.device)

    print(f" Type: {probe['type']}")
    print(f" Architecture: {probe['arch']}")
    print(f" Base model: {probe['config']['model_id']}")
    print(f" Probe layers: {probe['probe_layers']}")
    if probe["metadata"]:
        for k, v in probe["metadata"].items():
            print(f" {k}: {v}")

    if args.info_only:
        return

    # Fail before the expensive base-model load when there is nothing to score with.
    if probe["type"] == "unknown":
        raise SystemExit(f"No recognizable probe checkpoint found in: {args.probe}")

    # Imported lazily so --info-only works without transformers installed.
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    model_id = probe["config"]["model_id"]
    print(f"\nLoading {model_id}...")

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
        ),
        device_map="auto",
        output_hidden_states=True,
    )
    model.eval()

    # device_map="auto" decides where the model lives, so follow the model
    # rather than assuming it ended up on args.device.
    inputs = tokenizer(args.prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    # The probe modules were placed on args.device; bring the activations there.
    hidden_states = tuple(h.to(args.device) for h in outputs.hidden_states)

    score = score_hidden_states(probe, hidden_states)
    print(f"\nPrompt: {args.prompt}")
    print(f"Score: {score:.4f}")
    print(" (>0.5 = behavioral pattern detected, <0.5 = normal)")


# Run the CLI only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|