File size: 2,384 Bytes
3aeaf3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from __future__ import annotations

import json
import math
import os
import pickle
import random
from typing import Any


# Names of the direction vectors. When no library file exists,
# DirectionLibrary._init_random_vectors creates one vector per label listed here.
DIRECTION_LABELS = [
    "authority",
    "refusal_suppression",
    "trust_escalation",
    "danger",
    "compliance",
    "distraction",
]

# Intent labels. DirectionLibrary.run_intent_probe indexes this list with the
# position of the highest probability returned by a probe, so the order matters.
# NOTE(review): presumably this order must match the class order the probes were
# trained with -- confirm against the probe training code.
INTENT_CLASSES = ["benign", "escalation", "extraction", "distraction"]


class DirectionLibrary:
    """Store of named direction vectors plus optional per-layer intent probes.

    Vectors are read from a JSON file when it exists; otherwise one seeded
    random unit vector is generated per entry of ``DIRECTION_LABELS``.
    Probes are read from a pickle file when it exists; otherwise none are
    registered and ``run_intent_probe`` returns its fallback result.
    """

    def __init__(
        self,
        library_path: str = "data/direction_library.json",
        probe_path: str = "data/intent_probes.pkl",
        hidden_size: int = 1024,
    ) -> None:
        """Load vectors and probes from disk, falling back to random vectors.

        Args:
            library_path: JSON file mapping label -> list of numbers.
            probe_path: Pickle file holding a mapping of layer index -> probe.
            hidden_size: Length of each vector generated by the random
                fallback. It is not applied to vectors loaded from disk.
        """
        self.hidden_size = hidden_size
        self._vectors: dict[str, list[float]] = {}
        self._intent_probes: dict[int, Any] = {}
        if os.path.exists(library_path):
            self._load_vectors(library_path)
        else:
            self._init_random_vectors()
        if os.path.exists(probe_path):
            # SECURITY: unpickling can execute arbitrary code embedded in the
            # file. Only point probe_path at files produced by a trusted
            # source (e.g. this class's own save()).
            with open(probe_path, "rb") as handle:
                self._intent_probes = pickle.load(handle)

    def _load_vectors(self, path: str) -> None:
        """Replace the stored vectors with the label -> vector map in *path*."""
        # Explicit encoding so the result does not depend on the host locale.
        with open(path, encoding="utf-8") as handle:
            data = json.load(handle)
        self._vectors = {label: [float(x) for x in vec] for label, vec in data.items()}

    def _init_random_vectors(self) -> None:
        """Fill the store with one unit-length Gaussian vector per label.

        The generator is seeded with a fixed value, so repeated runs with the
        same ``hidden_size`` produce the same vectors.
        """
        rng = random.Random(1337)
        for label in DIRECTION_LABELS:
            vec = [rng.gauss(0.0, 1.0) for _ in range(self.hidden_size)]
            # "or 1.0" avoids dividing by zero for an all-zero (or empty) vector.
            norm = math.sqrt(sum(x * x for x in vec)) or 1.0
            self._vectors[label] = [x / norm for x in vec]

    def get_vector(self, label: str) -> list[float] | None:
        """Return the vector stored under *label*, or ``None`` if absent."""
        return self._vectors.get(label)

    def all_vectors(self) -> dict[str, list[float]]:
        """Return a shallow copy of the label -> vector mapping."""
        return dict(self._vectors)

    def run_intent_probe(self, activation: list[float], layer: int) -> tuple[str, float]:
        """Classify *activation* with the probe registered for *layer*.

        Args:
            activation: Feature vector passed to the probe as a single sample.
            layer: Key used to look up the probe.

        Returns:
            ``(intent_label, probability)`` for the highest-probability class,
            or ``("benign", 0.5)`` when no probe is registered for *layer*.
        """
        probe = self._intent_probes.get(layer)
        if probe is None:
            return ("benign", 0.5)
        # The probe must expose predict_proba(batch) returning one row of class
        # probabilities per sample; only the first (and only) row is used.
        probs = probe.predict_proba([activation])[0]
        # First index wins on ties, as with a plain argmax.
        # NOTE(review): assumes the probe's class order matches INTENT_CLASSES.
        idx, best = max(enumerate(probs), key=lambda pair: pair[1])
        return (INTENT_CLASSES[idx], float(best))

    def save(self, library_path: str, probe_path: str) -> None:
        """Write the vectors as JSON and the probes as a pickle.

        Parent directories of both targets are created when missing. A bare
        filename (no directory component) is written to the current working
        directory.

        Args:
            library_path: Destination of the JSON vector file.
            probe_path: Destination of the pickled probe mapping.
        """
        for target in (library_path, probe_path):
            parent = os.path.dirname(target)
            # dirname() is "" for a bare filename and os.makedirs("") raises,
            # so only create a directory when there is one to create.
            if parent:
                os.makedirs(parent, exist_ok=True)
        with open(library_path, "w", encoding="utf-8") as handle:
            json.dump(self._vectors, handle)
        with open(probe_path, "wb") as handle:
            pickle.dump(self._intent_probes, handle)