specimba committed on
Commit
375fc30
·
verified ·
1 Parent(s): baea714

Copy nexus_os_v2/ckplug_retriever.py from dataset for module imports

Browse files
Files changed (1) hide show
  1. nexus_os_v2/ckplug_retriever.py +192 -0
nexus_os_v2/ckplug_retriever.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CK-PLUG Integration for NEXUS OS v2
3
+ Implements Confidence Gain (CG) as the concrete μ_ret chemical potential.
4
+
5
+ Paper: arXiv:2503.15888 — Parameters vs. Context: Fine-Grained Control
6
+ of Knowledge Reliance in Language Models
7
+
8
+ Model-specific ε thresholds (from Appendix B):
9
+ LLaMA2-7B: -2 | LLaMA3-8B: -1
10
+ Mistral-0.3-7B: -1 | Qwen2.5-7B: -3
11
+ For general use: default ε = -1
12
+ """
13
import math
import re
from dataclasses import dataclass
from typing import List, Optional, Dict, Tuple, Callable

import torch
17
+
18
@dataclass
class TokenModulation:
    """Metadata describing one CK-PLUG modulation decision.

    Returned by ``CKPLUGCoupling.modulate_token`` next to the (possibly
    modulated) distribution.
    """

    token_id: int  # Index with the largest |p_rag - p_mod|; -1 on pass-through
    original_prob: float  # p_rag[token_id]; 0.0 on pass-through
    modulated_prob: float  # p_mod[token_id]; 0.0 on pass-through
    cg: float  # Confidence Gain, H_para - H_cont
    H_para: float  # Entropy (query-only)
    H_cont: float  # Entropy (query+retrieval)
    was_modulated: bool  # True when a conflict was detected and re-weighting ran
    alpha: float  # Adaptive blending weight; 0.0 on pass-through


class CKPLUGCoupling:
    """
    Concrete implementation of the retrieval chemical potential μ_ret
    from the NEXUS OS Landau-Ginzburg framework.

    μ_ret(x) = μ_0 * grounding_score(x)
    where grounding_score is derived from CK-PLUG Confidence Gain:
    - CG > 0 → retrieval SUPPORTS parametric knowledge (high grounding)
    - CG < 0 → retrieval CONFLICTS with parametric knowledge (low grounding)
    - |CG| → magnitude of confidence shift

    Entropies (and therefore CG and ε) are measured in bits, because
    ``entropy`` uses a base-2 logarithm.
    """

    def __init__(
        self,
        epsilon: float = -1.0,  # Model-specific detection threshold
        top_k: int = 50,  # Union top-k for V_head
        mu_0: float = 0.5,  # Base chemical potential (from LG framework)
        device: str = "cpu",
    ):
        """Store the coupling hyper-parameters.

        Args:
            epsilon: Conflict-detection threshold compared against CG.
            top_k: Number of highest-scoring tokens taken from each score
                vector when building V_head; must be at least 1.
            mu_0: Scale of the chemical potential returned by
                ``compute_chemical_potential``.
            device: Stored as ``self.device``; no method of this class reads it.

        Raises:
            ValueError: If ``top_k`` is smaller than 1 (an empty V_head would
                make the modulated distribution all-NaN).
        """
        if top_k < 1:
            raise ValueError(f"top_k must be >= 1, got {top_k}")
        self.epsilon = epsilon
        self.top_k = top_k
        self.mu_0 = mu_0
        self.device = device

    @staticmethod
    def entropy(probs: torch.Tensor) -> float:
        """Shannon entropy H = -Σ p_i log₂ p_i, in bits.

        Entries that are not strictly positive are dropped before the sum,
        so zeros contribute nothing and never reach ``log2``.
        """
        p = probs[probs > 0]
        return float(-(p * torch.log2(p)).sum().item())

    @staticmethod
    def confidence_gain(
        p_query: torch.Tensor,  # p(x | X_q) — parametric only
        p_rag: torch.Tensor,  # p(x | X_r + X_q) — with retrieval
    ) -> Tuple[float, float, float]:
        """
        Returns: (CG, H_para, H_cont)
        CG = H(p(x|X_q)) - H(p(x|X_r+X_q))
        Positive CG → retrieval supports (reduces entropy)
        Negative CG → retrieval conflicts (increases entropy)
        """
        H_para = CKPLUGCoupling.entropy(p_query)
        H_cont = CKPLUGCoupling.entropy(p_rag)
        CG = H_para - H_cont
        return CG, H_para, H_cont

    def compute_chemical_potential(
        self,
        p_query: torch.Tensor,
        p_rag: torch.Tensor,
        tau: float = 0.5,
    ) -> float:
        """
        Map CK-PLUG Confidence Gain to Landau-Ginzburg chemical potential μ_ret.

        Logic:
        CG >> 0 → retrieval strongly supports → μ_ret ≈ μ_0 (max grounding)
        CG ≈ 0 → neutral → μ_ret ≈ 0 (no coupling)
        CG << 0 → retrieval conflicts → μ_ret ≈ -μ_0 (adversarial)

        A tanh squashing gives a smooth interpolation:
        μ_ret = μ_0 * tanh(CG / τ) where τ controls transition sharpness.

        Args:
            p_query: Query-only distribution.
            p_rag: Query+retrieval distribution.
            tau: Transition width, in the same unit as CG (bits). Must be
                non-zero; the default 0.5 is the previously hard-coded value.

        Returns:
            μ_ret, a float in [-μ_0, μ_0].
        """
        CG, _, _ = self.confidence_gain(p_query, p_rag)
        return self.mu_0 * math.tanh(CG / tau)

    def modulate_token(
        self,
        p_query: torch.Tensor,  # Shape: (vocab_size,)
        p_rag: torch.Tensor,  # Shape: (vocab_size,)
    ) -> Tuple[torch.Tensor, TokenModulation]:
        """
        Apply CK-PLUG token-level modulation (Eq. 7-10 from paper).

        When CG >= ε the RAG distribution is returned unchanged (the same
        tensor object) with ``was_modulated=False``. Otherwise a new
        distribution is built by blending parameter-aware and context-aware
        log scores over V_head and renormalising with softmax; tokens outside
        V_head receive zero probability.

        Returns: (modulated_distribution, modulation_metadata)
        """
        CG, H_para, H_cont = self.confidence_gain(p_query, p_rag)

        # Conflict detection: compare CG with ε directly. Scaling ε by H_cont
        # (as an earlier revision did) made this branch unreachable for every
        # ε <= -1, because CG < ε * H_cont would require H_para < 0 and
        # entropies are non-negative.
        # NOTE(review): confirm the exact threshold form against Appendix B.
        if CG >= self.epsilon:
            # No conflict — pass through RAG distribution unchanged
            return p_rag, TokenModulation(
                token_id=-1,
                original_prob=0.0,
                modulated_prob=0.0,
                cg=CG,
                H_para=H_para,
                H_cont=H_cont,
                was_modulated=False,
                alpha=0.0,
            )

        # Conflict detected — apply modulation.
        # The 1e-10 terms keep log() finite for zero-probability entries.
        # Eq. 5: Parameter-aware log probability
        q_para = torch.log(p_query + 1e-10)

        # Eq. 6: Context-aware log probability
        q_cont = torch.log((p_rag + 1e-10) / (p_query + 1e-10))

        # Eq. 10: Adaptive alpha, clamped to [0, 1]
        alpha = H_cont / (H_para + H_cont + 1e-10)
        alpha = min(max(alpha, 0.0), 1.0)

        # Build V_head: union of top-k from both score vectors. k is capped at
        # the vocabulary size because torch.topk raises when k exceeds it.
        k = min(self.top_k, q_para.shape[-1])
        topk_para = torch.topk(q_para, k).indices
        topk_cont = torch.topk(q_cont, k).indices
        V_head = torch.unique(torch.cat([topk_para, topk_cont]))

        # Eq. 8: Modulation function F; -inf outside V_head so softmax maps
        # those tokens to probability 0.
        F = torch.full_like(q_para, -float("inf"))
        F[V_head] = alpha * q_para[V_head] + (1.0 - alpha) * q_cont[V_head]

        # Softmax to get modulated distribution
        p_mod = torch.softmax(F, dim=-1)

        # Find most changed token for metadata
        diff = torch.abs(p_rag - p_mod)
        changed_id = int(torch.argmax(diff).item())

        modulation = TokenModulation(
            token_id=changed_id,
            original_prob=float(p_rag[changed_id].item()),
            modulated_prob=float(p_mod[changed_id].item()),
            cg=CG,
            H_para=H_para,
            H_cont=H_cont,
            was_modulated=True,
            alpha=alpha,
        )

        return p_mod, modulation

    def batch_modulate(
        self,
        p_queries: List[torch.Tensor],  # List of (vocab_size,) tensors
        p_rags: List[torch.Tensor],  # Same length
    ) -> List[Tuple[torch.Tensor, TokenModulation]]:
        """Apply CK-PLUG to a batch of token positions.

        Positions are paired with ``zip``, so if the two lists differ in
        length the extra entries of the longer one are ignored.
        """
        return [self.modulate_token(pq, pr) for pq, pr in zip(p_queries, p_rags)]

    def get_grounding_field(self, p_query: torch.Tensor, p_rag: torch.Tensor) -> float:
        """
        Return the scalar μ_ret value for insertion into Landau-Ginzburg functional.
        This is the key bridge between CK-PLUG (empirical) and NEXUS OS physics.

        Delegates to ``compute_chemical_potential`` with its default τ.
        """
        return self.compute_chemical_potential(p_query, p_rag)
171
+
172
+
173
# Model-specific epsilon presets (from CK-PLUG Appendix B)
CKPLUG_PRESETS = {
    "llama2": -2.0,
    "llama3": -1.0,
    "mistral": -1.0,
    "qwen2.5": -3.0,
    "granite": -1.5,  # Estimated from paper patterns
    "gemma": -1.0,  # Estimated
    "deepseek": -2.0,  # Estimated (large MoE, conservative)
    "default": -1.0,
}

# Matches whitespace/hyphen/underscore sitting between a letter and a
# single-digit version number ("llama-2", "qwen_2.5"). The negative lookahead
# leaves size suffixes such as "llama-30b" alone so they are not mistaken for
# a version.
_VERSION_SEPARATOR = re.compile(r"(?<=[a-z])[\s_\-]+(?=\d(?!\d))")


def get_preset_epsilon(model_family: str) -> float:
    """Get recommended epsilon for a model family.

    The name is lower-cased and searched for each preset key as a substring,
    in the order the keys appear in ``CKPLUG_PRESETS``. If nothing matches,
    the search is repeated after dropping separators between a family name
    and its version digit, so spellings like ``"meta-llama/Llama-2-7b-hf"``
    or ``"Qwen-2.5-7B"`` resolve to ``llama2`` / ``qwen2.5``.

    Args:
        model_family: Model or family name, any casing.

    Returns:
        The matching preset, or ``CKPLUG_PRESETS["default"]`` when no preset
        key is found.
    """
    key = model_family.lower()
    # Exact-spelling pass first so existing callers see unchanged results.
    for candidate in (key, _VERSION_SEPARATOR.sub("", key)):
        for family, epsilon in CKPLUG_PRESETS.items():
            # "default" is the fallback entry, not a model family.
            if family != "default" and family in candidate:
                return epsilon
    return CKPLUG_PRESETS["default"]