Spaces:
Runtime error
feat(v2-round6): TruthRL + Validator-RLVR + LoraHub + Self-Refine + Letta + 9-cluster configs
Browse filesSix more techniques shipped (asked but not yet built in Round 5):
- bin/v2/validator-rlvr.py β pyflakes/shellcheck/hadolint/tflint/cfn-lint/
actionlint/sqlfluff/semgrep as deterministic
reward signals for stage3 RL
- bin/v2/truthrl-rewarder.py β ternary reward (+1 calibrated_idk / +1 confident_correct
/ -1 confident_wrong / -0.3 over_abstain)
- bin/v2/lorahub-composer.py β runtime LoRA composition with learned routing
table (heuristic seed, learns from winners)
- bin/v2/self-refine-loop.py β Madaan 2023 3-iter generateβcritiqueβrevise SFT
- bin/v2/letta-memory.py β hierarchical core+recall+archival memory
(Packer 2023, formerly MemGPT)
- bin/v2/gen-cluster-configs.sh β emits 9 cluster LoRA YAMLs from template
(eng-build/ops/sec/ai + product-ux + gtm +
finance-legal + compliance + meta-orchestrator)
Sanitizer hardening:
- bin/lib/sanitize.py β optional starpii NER + detect-secrets integration
via filter_pair(deep_scan=True). Lazy-loaded, fail-soft.
- bin/lib/sanitize.py +110 -3
- bin/v2/gen-cluster-configs.sh +214 -0
- bin/v2/letta-memory.py +200 -0
- bin/v2/lorahub-composer.py +281 -0
- bin/v2/self-refine-loop.py +163 -0
- bin/v2/truthrl-rewarder.py +195 -0
- bin/v2/validator-rlvr.py +348 -0
|
@@ -113,6 +113,92 @@ def has_pii(text: str) -> bool:
|
|
| 113 |
return bool(PII_RE.search(text or ""))
|
| 114 |
|
| 115 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
# Quality heuristics β drop if response is too short, identical to prompt, etc.
|
| 117 |
def is_low_quality(prompt: str, response: str) -> tuple[bool, str | None]:
|
| 118 |
if not prompt or not response:
|
|
@@ -133,16 +219,37 @@ def is_low_quality(prompt: str, response: str) -> tuple[bool, str | None]:
|
|
| 133 |
return False, None
|
| 134 |
|
| 135 |
|
| 136 |
-
def filter_pair(prompt: str, response: str
|
| 137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
polluted, p_match = is_polluted_pair(prompt, response)
|
| 139 |
if polluted:
|
| 140 |
return {"keep": False, "reason": "polluted", "matched": p_match}
|
| 141 |
if has_pii(prompt) or has_pii(response):
|
| 142 |
-
return {"keep": False, "reason": "
|
| 143 |
low_q, lq_reason = is_low_quality(prompt, response)
|
| 144 |
if low_q:
|
| 145 |
return {"keep": False, "reason": f"low_quality:{lq_reason}", "matched": None}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
return {"keep": True, "reason": None, "matched": None}
|
| 147 |
|
| 148 |
|
|
|
|
| 113 |
return bool(PII_RE.search(text or ""))
|
| 114 |
|
| 115 |
|
| 116 |
+
# ββ Optional NER + secrets scanners (lazy, fail-soft) ββββββββββββββββββ
|
| 117 |
+
# starpii (BigCode) β neural PII NER; better than regex for free-form text.
|
| 118 |
+
# detect-secrets (Yelp) β entropy + plugin-based secret detector.
|
| 119 |
+
# Both are optional dependencies; if unavailable we fall back to regex above.
|
| 120 |
+
_starpii_pipeline = None
|
| 121 |
+
_detect_secrets_collection = None
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def _load_starpii():
|
| 125 |
+
"""Lazy-load BigCode/starpii pipeline. None on failure."""
|
| 126 |
+
global _starpii_pipeline
|
| 127 |
+
if _starpii_pipeline is not None:
|
| 128 |
+
return _starpii_pipeline if _starpii_pipeline is not False else None
|
| 129 |
+
try:
|
| 130 |
+
from transformers import pipeline # type: ignore
|
| 131 |
+
_starpii_pipeline = pipeline(
|
| 132 |
+
"token-classification",
|
| 133 |
+
model="bigcode/starpii",
|
| 134 |
+
aggregation_strategy="simple",
|
| 135 |
+
)
|
| 136 |
+
return _starpii_pipeline
|
| 137 |
+
except Exception:
|
| 138 |
+
_starpii_pipeline = False # sentinel: "tried, don't try again"
|
| 139 |
+
return None
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def starpii_pii_hits(text: str, threshold: float = 0.8) -> list[dict]:
|
| 143 |
+
"""Return [{type, score, span}] for confidently-detected PII spans.
|
| 144 |
+
Empty list if starpii not installed or no hits.
|
| 145 |
+
"""
|
| 146 |
+
pipe = _load_starpii()
|
| 147 |
+
if not pipe or not text:
|
| 148 |
+
return []
|
| 149 |
+
try:
|
| 150 |
+
hits = pipe(text[:4000]) # cap input for speed
|
| 151 |
+
except Exception:
|
| 152 |
+
return []
|
| 153 |
+
return [{"type": h["entity_group"], "score": float(h["score"]),
|
| 154 |
+
"span": text[h["start"]:h["end"]][:120]}
|
| 155 |
+
for h in hits if h.get("score", 0) >= threshold]
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def _load_detect_secrets():
|
| 159 |
+
"""Lazy-load detect-secrets SecretsCollection. None on failure."""
|
| 160 |
+
global _detect_secrets_collection
|
| 161 |
+
if _detect_secrets_collection is not None:
|
| 162 |
+
return _detect_secrets_collection if _detect_secrets_collection is not False else None
|
| 163 |
+
try:
|
| 164 |
+
from detect_secrets import SecretsCollection # type: ignore
|
| 165 |
+
from detect_secrets.settings import default_settings # type: ignore
|
| 166 |
+
_detect_secrets_collection = (SecretsCollection, default_settings)
|
| 167 |
+
return _detect_secrets_collection
|
| 168 |
+
except Exception:
|
| 169 |
+
_detect_secrets_collection = False
|
| 170 |
+
return None
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def detect_secrets_hits(text: str) -> list[dict]:
|
| 174 |
+
"""Return [{type, line}] for any secret detect-secrets finds.
|
| 175 |
+
Empty list if not installed or none detected.
|
| 176 |
+
"""
|
| 177 |
+
loaded = _load_detect_secrets()
|
| 178 |
+
if not loaded or not text:
|
| 179 |
+
return []
|
| 180 |
+
SecretsCollection, default_settings = loaded
|
| 181 |
+
import tempfile, os
|
| 182 |
+
fd, path = tempfile.mkstemp(suffix=".txt")
|
| 183 |
+
try:
|
| 184 |
+
os.write(fd, text.encode("utf-8", "ignore")[:200_000])
|
| 185 |
+
os.close(fd)
|
| 186 |
+
with default_settings():
|
| 187 |
+
sc = SecretsCollection()
|
| 188 |
+
sc.scan_file(path)
|
| 189 |
+
out = []
|
| 190 |
+
for _, secrets in sc.data.items():
|
| 191 |
+
for s in secrets:
|
| 192 |
+
out.append({"type": s.type, "line": s.line_number,
|
| 193 |
+
"secret_hash": s.secret_hash[:16]})
|
| 194 |
+
return out
|
| 195 |
+
except Exception:
|
| 196 |
+
return []
|
| 197 |
+
finally:
|
| 198 |
+
try: os.unlink(path)
|
| 199 |
+
except OSError: pass
|
| 200 |
+
|
| 201 |
+
|
| 202 |
# Quality heuristics β drop if response is too short, identical to prompt, etc.
|
| 203 |
def is_low_quality(prompt: str, response: str) -> tuple[bool, str | None]:
|
| 204 |
if not prompt or not response:
|
|
|
|
| 219 |
return False, None
|
| 220 |
|
| 221 |
|
| 222 |
+
def filter_pair(prompt: str, response: str,
|
| 223 |
+
deep_scan: bool = False) -> dict:
|
| 224 |
+
"""Return verdict: {'keep': bool, 'reason': str|None, 'matched': str|None}.
|
| 225 |
+
|
| 226 |
+
deep_scan=True: also runs starpii NER + detect-secrets if installed.
|
| 227 |
+
Slow (model load + per-row scan) β use for the final pre-train pass,
|
| 228 |
+
not for every dedup row. Heuristic (regex) checks always run.
|
| 229 |
+
"""
|
| 230 |
polluted, p_match = is_polluted_pair(prompt, response)
|
| 231 |
if polluted:
|
| 232 |
return {"keep": False, "reason": "polluted", "matched": p_match}
|
| 233 |
if has_pii(prompt) or has_pii(response):
|
| 234 |
+
return {"keep": False, "reason": "pii_regex", "matched": None}
|
| 235 |
low_q, lq_reason = is_low_quality(prompt, response)
|
| 236 |
if low_q:
|
| 237 |
return {"keep": False, "reason": f"low_quality:{lq_reason}", "matched": None}
|
| 238 |
+
|
| 239 |
+
if deep_scan:
|
| 240 |
+
# NER PII
|
| 241 |
+
for field, txt in (("prompt", prompt), ("response", response)):
|
| 242 |
+
hits = starpii_pii_hits(txt)
|
| 243 |
+
if hits:
|
| 244 |
+
return {"keep": False, "reason": f"pii_ner:{field}",
|
| 245 |
+
"matched": str(hits[:3])[:300]}
|
| 246 |
+
# detect-secrets entropy/plugins
|
| 247 |
+
for field, txt in (("prompt", prompt), ("response", response)):
|
| 248 |
+
hits = detect_secrets_hits(txt)
|
| 249 |
+
if hits:
|
| 250 |
+
return {"keep": False, "reason": f"secrets:{field}",
|
| 251 |
+
"matched": str(hits[:3])[:300]}
|
| 252 |
+
|
| 253 |
return {"keep": True, "reason": None, "matched": None}
|
| 254 |
|
| 255 |
|
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Surrogate-1 v2 β Generate 9 cluster LoRA training configs from one template.
|
| 3 |
+
#
|
| 4 |
+
# Each cluster is trained independently on its domain slice of v2 data,
|
| 5 |
+
# then mergekit fuses them via merge-9-loras.sh.
|
| 6 |
+
#
|
| 7 |
+
# Run BEFORE training: bash gen-cluster-configs.sh
|
| 8 |
+
# Output: configs/v2/cluster-<name>.yml Γ 9
|
| 9 |
+
#
|
| 10 |
+
# Domain β dataset filter mapping (each cluster pulls a subset of v2-train-clean):
|
| 11 |
+
# eng-build: code-* + ai-eng + api-* + test-* + debug-*
|
| 12 |
+
# eng-ops: devops-* + sre-* + ci-* + cloud-cost
|
| 13 |
+
# eng-sec: sec-* + safety + cve + secrets + iam
|
| 14 |
+
# eng-ai: ai-eng + ai-prompt + rag + lora + vllm
|
| 15 |
+
# product-ux: docs-* + arch-adr + design-*
|
| 16 |
+
# gtm: business + marketing + sales
|
| 17 |
+
# finance-legal: finance-* + legal-* + cost-*
|
| 18 |
+
# compliance: compliance + soc2 + iso27001 + hipaa + pci
|
| 19 |
+
# meta-orchestrator: arch-adr + planning + multi-step
|
| 20 |
+
set -uo pipefail
|
| 21 |
+
|
| 22 |
+
OUT_DIR="$HOME/.surrogate/hf-space/configs/v2"
|
| 23 |
+
TEMPLATE_DOMAINS=(
|
| 24 |
+
"eng-build:code-*,ai-eng,api-*,test-*,debug-*,perf-*"
|
| 25 |
+
"eng-ops:devops-*,sre-*,ci-*,cloud-cost,iac-*"
|
| 26 |
+
"eng-sec:sec-*,safety-*,cve-*,iam-*,secrets-*"
|
| 27 |
+
"eng-ai:ai-eng,ai-prompt,rag,lora,vllm,embedding"
|
| 28 |
+
"product-ux:docs-*,arch-adr,design-*,user-*"
|
| 29 |
+
"gtm:business,marketing,sales,positioning,gtm"
|
| 30 |
+
"finance-legal:finance-*,legal-*,cost-*,billing"
|
| 31 |
+
"compliance:compliance,soc2,iso27001,hipaa,pci-dss,gdpr"
|
| 32 |
+
"meta-orchestrator:arch-adr,planning,multi-step,orchestration"
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
# Per-cluster LoRA hyperparams β bigger ranks for domains with more data.
|
| 36 |
+
declare -A LORA_R=(
|
| 37 |
+
[eng-build]=64 [eng-ops]=64 [eng-sec]=48 [eng-ai]=48
|
| 38 |
+
[product-ux]=32 [gtm]=32 [finance-legal]=32 [compliance]=32
|
| 39 |
+
[meta-orchestrator]=64
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
mkdir -p "$OUT_DIR"
|
| 43 |
+
|
| 44 |
+
for entry in "${TEMPLATE_DOMAINS[@]}"; do
|
| 45 |
+
name="${entry%%:*}"
|
| 46 |
+
domain_filter="${entry#*:}"
|
| 47 |
+
rank="${LORA_R[$name]:-32}"
|
| 48 |
+
alpha=$((rank * 2))
|
| 49 |
+
out_yml="$OUT_DIR/cluster-${name}.yml"
|
| 50 |
+
|
| 51 |
+
cat > "$out_yml" <<EOF
|
| 52 |
+
# Surrogate-1 v2 β Cluster LoRA (auto-generated by gen-cluster-configs.sh).
|
| 53 |
+
# Cluster: $name
|
| 54 |
+
# Domain filter: $domain_filter
|
| 55 |
+
# Trained independently; merged via merge-9-loras.sh after all 9 finish.
|
| 56 |
+
|
| 57 |
+
base_model: Qwen/Qwen2.5-Coder-7B-Instruct
|
| 58 |
+
model_type: AutoModelForCausalLM
|
| 59 |
+
tokenizer_type: AutoTokenizer
|
| 60 |
+
trust_remote_code: true
|
| 61 |
+
|
| 62 |
+
load_in_4bit: true
|
| 63 |
+
strict: false
|
| 64 |
+
|
| 65 |
+
adapter: lora
|
| 66 |
+
lora_r: $rank
|
| 67 |
+
lora_alpha: $alpha
|
| 68 |
+
lora_dropout: 0.05
|
| 69 |
+
peft_use_dora: true
|
| 70 |
+
lora_target_modules:
|
| 71 |
+
- q_proj
|
| 72 |
+
- k_proj
|
| 73 |
+
- v_proj
|
| 74 |
+
- o_proj
|
| 75 |
+
- gate_proj
|
| 76 |
+
- up_proj
|
| 77 |
+
- down_proj
|
| 78 |
+
|
| 79 |
+
sequence_len: 32768
|
| 80 |
+
sample_packing: true
|
| 81 |
+
pad_to_sequence_len: true
|
| 82 |
+
rope_theta: 1000000.0
|
| 83 |
+
rope_scaling:
|
| 84 |
+
type: yarn
|
| 85 |
+
factor: 4.0
|
| 86 |
+
original_max_position_embeddings: 32768
|
| 87 |
+
|
| 88 |
+
datasets:
|
| 89 |
+
- path: /data/v2/clusters/${name}.jsonl
|
| 90 |
+
type: chat_template
|
| 91 |
+
field_messages: messages
|
| 92 |
+
ds_type: json
|
| 93 |
+
|
| 94 |
+
val_set_size: 0.02
|
| 95 |
+
output_dir: /data/v2/out/cluster-${name}
|
| 96 |
+
|
| 97 |
+
# Smaller clusters get fewer epochs (less data to overfit on)
|
| 98 |
+
num_epochs: 2
|
| 99 |
+
micro_batch_size: 1
|
| 100 |
+
gradient_accumulation_steps: 16
|
| 101 |
+
learning_rate: 1.0e-4
|
| 102 |
+
lr_scheduler: cosine
|
| 103 |
+
warmup_ratio: 0.03
|
| 104 |
+
optimizer: adamw_torch_fused
|
| 105 |
+
weight_decay: 0.01
|
| 106 |
+
max_grad_norm: 1.0
|
| 107 |
+
|
| 108 |
+
bf16: true
|
| 109 |
+
fp16: false
|
| 110 |
+
gradient_checkpointing: true
|
| 111 |
+
gradient_checkpointing_kwargs:
|
| 112 |
+
use_reentrant: false
|
| 113 |
+
flash_attention: true
|
| 114 |
+
liger_kernel: true
|
| 115 |
+
neftune_noise_alpha: 5
|
| 116 |
+
|
| 117 |
+
eval_steps: 100
|
| 118 |
+
save_steps: 100
|
| 119 |
+
save_total_limit: 2
|
| 120 |
+
logging_steps: 10
|
| 121 |
+
|
| 122 |
+
hub_model_id: axentx/surrogate-1-coder-7b-lora-v2-${name}
|
| 123 |
+
hub_strategy: every_save
|
| 124 |
+
push_to_hub: true
|
| 125 |
+
hub_private_repo: false
|
| 126 |
+
|
| 127 |
+
wandb_project: surrogate-1-v2-clusters
|
| 128 |
+
wandb_run_id: cluster-${name}
|
| 129 |
+
|
| 130 |
+
special_tokens:
|
| 131 |
+
pad_token: <|endoftext|>
|
| 132 |
+
|
| 133 |
+
resume_from_checkpoint: null
|
| 134 |
+
auto_resume_from_checkpoints: true
|
| 135 |
+
EOF
|
| 136 |
+
|
| 137 |
+
echo "βΆ $out_yml (rank=$rank, filter=$domain_filter)"
|
| 138 |
+
done
|
| 139 |
+
|
| 140 |
+
# Companion: dataset slicer that produces /data/v2/clusters/<name>.jsonl
|
| 141 |
+
SLICER="$OUT_DIR/../../bin/v2/slice-clusters.py"
|
| 142 |
+
cat > "$SLICER" <<'PYEOF'
|
| 143 |
+
"""Slice v2-train-clean.jsonl into 9 cluster files by domain tag.
|
| 144 |
+
|
| 145 |
+
Domain tag is detected via inference-augment.detect_domain() if present in
|
| 146 |
+
the row's meta, else heuristically from prompt content.
|
| 147 |
+
"""
|
| 148 |
+
import json, os, sys
|
| 149 |
+
from pathlib import Path
|
| 150 |
+
sys.path.insert(0, str(Path.home() / ".surrogate/bin/v2"))
|
| 151 |
+
from importlib.util import spec_from_file_location, module_from_spec
|
| 152 |
+
|
| 153 |
+
# Load detect_domain from inference-augment.py
|
| 154 |
+
spec = spec_from_file_location(
|
| 155 |
+
"inference_augment",
|
| 156 |
+
str(Path.home() / ".surrogate/bin/v2/inference-augment.py"))
|
| 157 |
+
ia = module_from_spec(spec); spec.loader.exec_module(ia)
|
| 158 |
+
|
| 159 |
+
DOMAIN_TO_CLUSTER = {
|
| 160 |
+
"code-python": "eng-build", "code-typescript": "eng-build",
|
| 161 |
+
"test-pytest": "eng-build", "debug-traceback": "eng-build",
|
| 162 |
+
"perf-profile": "eng-build", "api-rest": "eng-build",
|
| 163 |
+
"api-graphql": "eng-build",
|
| 164 |
+
"devops-tf": "eng-ops", "devops-k8s": "eng-ops", "devops-cdk": "eng-ops",
|
| 165 |
+
"sre-runbook": "eng-ops", "sre-slo": "eng-ops", "ci-github": "eng-ops",
|
| 166 |
+
"cloud-cost": "eng-ops",
|
| 167 |
+
"sec-iam": "eng-sec", "sec-secrets": "eng-sec", "sec-cve": "eng-sec",
|
| 168 |
+
"ai-eng": "eng-ai", "ai-prompt": "eng-ai",
|
| 169 |
+
"data-sql": "eng-build",
|
| 170 |
+
"docs-api": "product-ux", "arch-adr": "meta-orchestrator",
|
| 171 |
+
"business": "gtm", "compliance": "compliance",
|
| 172 |
+
"_default": "meta-orchestrator",
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
INPUT = Path(os.environ.get("INPUT", "/data/v2-train-clean.jsonl"))
|
| 176 |
+
OUT_DIR = Path("/data/v2/clusters")
|
| 177 |
+
OUT_DIR.mkdir(parents=True, exist_ok=True)
|
| 178 |
+
|
| 179 |
+
handles = {}
|
| 180 |
+
counts = {}
|
| 181 |
+
def out(cluster):
|
| 182 |
+
if cluster not in handles:
|
| 183 |
+
p = OUT_DIR / f"{cluster}.jsonl"
|
| 184 |
+
handles[cluster] = open(p, "w")
|
| 185 |
+
counts[cluster] = 0
|
| 186 |
+
return handles[cluster]
|
| 187 |
+
|
| 188 |
+
with open(INPUT) as fin:
|
| 189 |
+
for line in fin:
|
| 190 |
+
try: d = json.loads(line)
|
| 191 |
+
except: continue
|
| 192 |
+
prompt = d.get("prompt") or d.get("instruction") or ""
|
| 193 |
+
if not prompt: continue
|
| 194 |
+
domain = (d.get("meta", {}).get("domain")
|
| 195 |
+
or ia.detect_domain(prompt) or "_default")
|
| 196 |
+
cluster = DOMAIN_TO_CLUSTER.get(domain, "meta-orchestrator")
|
| 197 |
+
out(cluster).write(json.dumps(d, ensure_ascii=False) + "\n")
|
| 198 |
+
counts[cluster] += 1
|
| 199 |
+
|
| 200 |
+
for f in handles.values(): f.close()
|
| 201 |
+
print(json.dumps(counts, indent=2))
|
| 202 |
+
PYEOF
|
| 203 |
+
|
| 204 |
+
chmod +x "$SLICER"
|
| 205 |
+
echo ""
|
| 206 |
+
echo "β
generated 9 cluster YAMLs in $OUT_DIR/"
|
| 207 |
+
echo " slicer: $SLICER"
|
| 208 |
+
echo ""
|
| 209 |
+
echo "Next steps:"
|
| 210 |
+
echo " 1. python3 $SLICER # slices /data/v2-train-clean.jsonl β 9 cluster files"
|
| 211 |
+
echo " 2. for c in eng-build eng-ops eng-sec eng-ai product-ux gtm finance-legal compliance meta-orchestrator; do"
|
| 212 |
+
echo " axolotl train $OUT_DIR/cluster-\$c.yml"
|
| 213 |
+
echo " done"
|
| 214 |
+
echo " 3. bash bin/v2/merge-9-loras.sh # fuses all 9 into super-LoRA"
|
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Surrogate-1 v2 β Letta-style hierarchical memory.
|
| 2 |
+
|
| 3 |
+
Reference: Letta (formerly MemGPT, Packer et al. 2023) β hierarchical
|
| 4 |
+
memory with core (always-loaded), recall (recent interactions), archival
|
| 5 |
+
(searchable long-term).
|
| 6 |
+
|
| 7 |
+
Diff vs reflexion-store + voyager-skills:
|
| 8 |
+
β’ reflexion = past failures + lessons (per-domain, error-driven)
|
| 9 |
+
β’ voyager = validated skills (success-driven)
|
| 10 |
+
β’ letta = persona + user prefs + dialogue trail (identity-driven)
|
| 11 |
+
|
| 12 |
+
Three tiers:
|
| 13 |
+
CORE_BLOCK β fixed agent persona + key user preferences (always-prepended)
|
| 14 |
+
RECALL_QUEUE β last N user-interaction summaries (sliding window)
|
| 15 |
+
ARCHIVAL_STORE β keyword-searchable corpus of older summaries
|
| 16 |
+
|
| 17 |
+
Storage: ~/.surrogate/state/letta.db (SQLite).
|
| 18 |
+
|
| 19 |
+
Used by inference-augment.py to prepend CORE + relevant RECALL + ARCHIVAL
|
| 20 |
+
hits before each user request.
|
| 21 |
+
"""
|
| 22 |
+
from __future__ import annotations
|
| 23 |
+
import argparse
|
| 24 |
+
import hashlib
|
| 25 |
+
import json
|
| 26 |
+
import re
|
| 27 |
+
import sqlite3
|
| 28 |
+
import sys
|
| 29 |
+
import time
|
| 30 |
+
from pathlib import Path
|
| 31 |
+
|
| 32 |
+
DB_PATH = Path.home() / ".surrogate/state/letta.db"
|
| 33 |
+
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
| 34 |
+
RECALL_QUEUE_MAX = 20
|
| 35 |
+
ARCHIVAL_MAX = 10000
|
| 36 |
+
TOKEN_RE = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]{2,}")
|
| 37 |
+
|
| 38 |
+
DEFAULT_CORE = (
|
| 39 |
+
"You are Surrogate-1 β a male, privacy-first AI agent that works "
|
| 40 |
+
"exclusively for Ashira (ΰΈΰΈ΄ΰΈ§ΰΈͺΰΉ), founder of axentx. You handle "
|
| 41 |
+
"DevSecOps, SRE, platform engineering, and full-stack coding tasks. "
|
| 42 |
+
"Speak directly. Never apologize for being an AI. Never refuse safety-"
|
| 43 |
+
"irrelevant requests. Default language: Thai when Ashira writes Thai, "
|
| 44 |
+
"English for code/commands. Cite sources for security-critical claims.")
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _db() -> sqlite3.Connection:
|
| 48 |
+
c = sqlite3.connect(str(DB_PATH), isolation_level=None, timeout=10,
|
| 49 |
+
check_same_thread=False)
|
| 50 |
+
c.execute("PRAGMA journal_mode=WAL")
|
| 51 |
+
c.execute("""CREATE TABLE IF NOT EXISTS core_block (
|
| 52 |
+
key TEXT PRIMARY KEY,
|
| 53 |
+
value TEXT,
|
| 54 |
+
updated_at INTEGER
|
| 55 |
+
)""")
|
| 56 |
+
c.execute("""CREATE TABLE IF NOT EXISTS recall_queue (
|
| 57 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 58 |
+
summary TEXT,
|
| 59 |
+
tokens TEXT,
|
| 60 |
+
ts INTEGER
|
| 61 |
+
)""")
|
| 62 |
+
c.execute("""CREATE TABLE IF NOT EXISTS archival (
|
| 63 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 64 |
+
summary TEXT,
|
| 65 |
+
tokens TEXT,
|
| 66 |
+
topic TEXT,
|
| 67 |
+
ts INTEGER
|
| 68 |
+
)""")
|
| 69 |
+
c.execute("CREATE INDEX IF NOT EXISTS idx_archival_topic ON archival(topic, ts DESC)")
|
| 70 |
+
# Seed default persona on first run
|
| 71 |
+
c.execute("INSERT OR IGNORE INTO core_block (key, value, updated_at) "
|
| 72 |
+
"VALUES ('persona', ?, ?)", (DEFAULT_CORE, int(time.time())))
|
| 73 |
+
return c
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def _tokens(text: str) -> set[str]:
|
| 77 |
+
return set(TOKEN_RE.findall(text.lower()))
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def core_get() -> str:
|
| 81 |
+
c = _db()
|
| 82 |
+
rows = c.execute("SELECT key, value FROM core_block ORDER BY key").fetchall()
|
| 83 |
+
c.close()
|
| 84 |
+
return "\n\n".join(f"### {k}\n{v}" for k, v in rows)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def core_set(key: str, value: str) -> None:
|
| 88 |
+
c = _db()
|
| 89 |
+
c.execute("""INSERT OR REPLACE INTO core_block (key, value, updated_at)
|
| 90 |
+
VALUES (?, ?, ?)""", (key, value, int(time.time())))
|
| 91 |
+
c.close()
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def recall_push(summary: str) -> None:
|
| 95 |
+
c = _db()
|
| 96 |
+
toks = " ".join(sorted(_tokens(summary)))
|
| 97 |
+
c.execute("""INSERT INTO recall_queue (summary, tokens, ts)
|
| 98 |
+
VALUES (?, ?, ?)""", (summary[:2000], toks, int(time.time())))
|
| 99 |
+
# Promote oldest to archival when queue overflows
|
| 100 |
+
n = c.execute("SELECT COUNT(*) FROM recall_queue").fetchone()[0]
|
| 101 |
+
if n > RECALL_QUEUE_MAX:
|
| 102 |
+
promote = c.execute("""SELECT id, summary, tokens, ts
|
| 103 |
+
FROM recall_queue ORDER BY id ASC LIMIT ?""",
|
| 104 |
+
(n - RECALL_QUEUE_MAX,)).fetchall()
|
| 105 |
+
for rid, s, t, ts in promote:
|
| 106 |
+
topic = (sorted(_tokens(s))[:1] or ["misc"])[0]
|
| 107 |
+
c.execute("""INSERT INTO archival (summary, tokens, topic, ts)
|
| 108 |
+
VALUES (?, ?, ?, ?)""", (s, t, topic, ts))
|
| 109 |
+
c.execute("DELETE FROM recall_queue WHERE id=?", (rid,))
|
| 110 |
+
c.close()
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def recall_recent(k: int = 5) -> list[dict]:
|
| 114 |
+
c = _db()
|
| 115 |
+
rows = c.execute("""SELECT summary, ts FROM recall_queue
|
| 116 |
+
ORDER BY id DESC LIMIT ?""", (k,)).fetchall()
|
| 117 |
+
c.close()
|
| 118 |
+
return [{"summary": s, "ts": ts, "age_days": (time.time() - ts) / 86400}
|
| 119 |
+
for s, ts in rows]
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def archival_search(query: str, k: int = 3) -> list[dict]:
|
| 123 |
+
qtoks = _tokens(query)
|
| 124 |
+
if not qtoks:
|
| 125 |
+
return []
|
| 126 |
+
c = _db()
|
| 127 |
+
# Cap candidate scan for speed
|
| 128 |
+
rows = c.execute("""SELECT id, summary, tokens, topic, ts FROM archival
|
| 129 |
+
ORDER BY ts DESC LIMIT 2000""").fetchall()
|
| 130 |
+
c.close()
|
| 131 |
+
scored: list[tuple[int, tuple]] = []
|
| 132 |
+
for r in rows:
|
| 133 |
+
rid, s, t, topic, ts = r
|
| 134 |
+
dtoks = set(t.split())
|
| 135 |
+
overlap = qtoks & dtoks
|
| 136 |
+
if not overlap:
|
| 137 |
+
continue
|
| 138 |
+
scored.append((len(overlap), r))
|
| 139 |
+
scored.sort(key=lambda x: -x[0])
|
| 140 |
+
return [{"summary": r[1][1], "topic": r[1][3], "score": r[0]}
|
| 141 |
+
for r in scored[:k]]
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def assemble(query: str, k_recall: int = 3,
|
| 145 |
+
k_archival: int = 3) -> str:
|
| 146 |
+
"""Build the prepended memory block for this request."""
|
| 147 |
+
parts = [core_get()]
|
| 148 |
+
rec = recall_recent(k_recall)
|
| 149 |
+
if rec:
|
| 150 |
+
block = ["## Recent context"]
|
| 151 |
+
for r in rec:
|
| 152 |
+
block.append(f"- ({r['age_days']:.1f}d ago) {r['summary'][:300]}")
|
| 153 |
+
parts.append("\n".join(block))
|
| 154 |
+
arc = archival_search(query, k_archival)
|
| 155 |
+
if arc:
|
| 156 |
+
block = ["## Past relevant interactions"]
|
| 157 |
+
for a in arc:
|
| 158 |
+
block.append(f"- [{a['topic']}] {a['summary'][:300]}")
|
| 159 |
+
parts.append("\n".join(block))
|
| 160 |
+
return "\n\n".join(parts)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def stats() -> dict:
|
| 164 |
+
c = _db()
|
| 165 |
+
n_core = c.execute("SELECT COUNT(*) FROM core_block").fetchone()[0]
|
| 166 |
+
n_rec = c.execute("SELECT COUNT(*) FROM recall_queue").fetchone()[0]
|
| 167 |
+
n_arc = c.execute("SELECT COUNT(*) FROM archival").fetchone()[0]
|
| 168 |
+
top_topics = c.execute("""SELECT topic, COUNT(*) FROM archival
|
| 169 |
+
GROUP BY topic ORDER BY 2 DESC LIMIT 10""").fetchall()
|
| 170 |
+
c.close()
|
| 171 |
+
return {"core_blocks": n_core, "recall_queue": n_rec, "archival": n_arc,
|
| 172 |
+
"top_topics": [{"topic": t, "count": n} for t, n in top_topics]}
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
if __name__ == "__main__":
|
| 176 |
+
cmd = sys.argv[1] if len(sys.argv) > 1 else "stats"
|
| 177 |
+
if cmd == "stats":
|
| 178 |
+
print(json.dumps(stats(), indent=2, ensure_ascii=False))
|
| 179 |
+
elif cmd == "core-set":
|
| 180 |
+
# python letta-memory.py core-set <key> <<<value
|
| 181 |
+
key = sys.argv[2]
|
| 182 |
+
val = sys.stdin.read()
|
| 183 |
+
core_set(key, val.strip())
|
| 184 |
+
print(json.dumps({"ok": True, "key": key}))
|
| 185 |
+
elif cmd == "core-get":
|
| 186 |
+
print(core_get())
|
| 187 |
+
elif cmd == "push":
|
| 188 |
+
# echo "summary text" | python letta-memory.py push
|
| 189 |
+
recall_push(sys.stdin.read().strip())
|
| 190 |
+
print(json.dumps({"ok": True}))
|
| 191 |
+
elif cmd == "assemble":
|
| 192 |
+
q = sys.argv[2] if len(sys.argv) > 2 else ""
|
| 193 |
+
print(assemble(q))
|
| 194 |
+
elif cmd == "search":
|
| 195 |
+
q = sys.argv[2] if len(sys.argv) > 2 else ""
|
| 196 |
+
k = int(sys.argv[3]) if len(sys.argv) > 3 else 3
|
| 197 |
+
print(json.dumps(archival_search(q, k), indent=2, ensure_ascii=False))
|
| 198 |
+
else:
|
| 199 |
+
print(f"unknown: {cmd}", file=sys.stderr)
|
| 200 |
+
sys.exit(1)
|
|
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Surrogate-1 v2 β LoraHub / Arrow runtime LoRA composition.
|
| 2 |
+
|
| 3 |
+
Reference: LoraHub (Huang et al. 2023) + Arrow (2024) β at inference time,
|
| 4 |
+
compose multiple specialist LoRAs with task-aware weights instead of using
|
| 5 |
+
a single statically-merged super-LoRA.
|
| 6 |
+
|
| 7 |
+
Why: at inference, the user's prompt rarely needs ALL 9 cluster LoRAs at
|
| 8 |
+
equal strength. A devops question β 0.55 eng-ops + 0.30 eng-sec + 0.15
|
| 9 |
+
meta. A code question β 0.60 eng-build + 0.25 eng-ai + 0.15 meta.
|
| 10 |
+
|
| 11 |
+
This module:
|
| 12 |
+
1. Classifies the prompt domain via a small Qwen-Coder-1.5B prompt
|
| 13 |
+
(fast, free) OR keyword heuristics (instant fallback).
|
| 14 |
+
2. Returns per-LoRA weights via a learned table OR sane defaults.
|
| 15 |
+
3. Emits a vLLM `--lora-modules` compatible weight string OR
|
| 16 |
+
PEFT `add_weighted_adapter()` call args.
|
| 17 |
+
|
| 18 |
+
Routing table is bootstrapped from heuristics + improved over time using
|
| 19 |
+
self-improve-loop's winner data β same closed loop as the rest of v2.
|
| 20 |
+
|
| 21 |
+
CLI:
|
| 22 |
+
echo '{"prompt":"Write a Terraform module..."}' | python3 lorahub-composer.py
|
| 23 |
+
β {"weights": {"eng-build":0.10, "eng-ops":0.55, ...}, "domain":"devops-tf"}
|
| 24 |
+
|
| 25 |
+
python3 lorahub-composer.py --learn winners.jsonl # update routing weights
|
| 26 |
+
"""
|
| 27 |
+
from __future__ import annotations
|
| 28 |
+
import argparse
|
| 29 |
+
import json
|
| 30 |
+
import os
|
| 31 |
+
import re
|
| 32 |
+
import sqlite3
|
| 33 |
+
import sys
|
| 34 |
+
from pathlib import Path
|
| 35 |
+
|
| 36 |
+
DB_PATH = Path.home() / ".surrogate/state/lorahub.db"
|
| 37 |
+
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
| 38 |
+
|
| 39 |
+
# 9 cluster LoRAs (must match merge-9-loras.sh + serve-vllm.sh USE_MULTI_LORA)
|
| 40 |
+
LORAS = [
|
| 41 |
+
"eng-build", "eng-ops", "eng-sec", "eng-ai",
|
| 42 |
+
"product-ux", "gtm", "finance-legal", "compliance",
|
| 43 |
+
"meta-orchestrator",
|
| 44 |
+
]
|
| 45 |
+
|
| 46 |
+
# Heuristic routing β domain β adapter weights summing to ~1.0
|
| 47 |
+
# meta-orchestrator always gets a small slice (it's the planner)
|
| 48 |
+
ROUTING_HEURISTIC: dict[str, dict[str, float]] = {
|
| 49 |
+
"code-python": {
|
| 50 |
+
"eng-build": 0.55, "eng-ai": 0.20, "eng-sec": 0.10,
|
| 51 |
+
"meta-orchestrator": 0.15},
|
| 52 |
+
"code-typescript": {
|
| 53 |
+
"eng-build": 0.55, "eng-ai": 0.15, "product-ux": 0.15,
|
| 54 |
+
"meta-orchestrator": 0.15},
|
| 55 |
+
"devops-tf": {
|
| 56 |
+
"eng-ops": 0.50, "eng-sec": 0.25, "eng-build": 0.10,
|
| 57 |
+
"meta-orchestrator": 0.15},
|
| 58 |
+
"devops-k8s": {
|
| 59 |
+
"eng-ops": 0.55, "eng-sec": 0.20, "eng-build": 0.10,
|
| 60 |
+
"meta-orchestrator": 0.15},
|
| 61 |
+
"devops-cdk": {
|
| 62 |
+
"eng-ops": 0.45, "eng-build": 0.20, "eng-sec": 0.20,
|
| 63 |
+
"meta-orchestrator": 0.15},
|
| 64 |
+
"sec-iam": {
|
| 65 |
+
"eng-sec": 0.55, "eng-ops": 0.20, "compliance": 0.10,
|
| 66 |
+
"meta-orchestrator": 0.15},
|
| 67 |
+
"sec-secrets": {
|
| 68 |
+
"eng-sec": 0.55, "eng-ops": 0.15, "compliance": 0.15,
|
| 69 |
+
"meta-orchestrator": 0.15},
|
| 70 |
+
"sec-cve": {
|
| 71 |
+
"eng-sec": 0.50, "compliance": 0.20, "eng-ops": 0.15,
|
| 72 |
+
"meta-orchestrator": 0.15},
|
| 73 |
+
"sre-runbook": {
|
| 74 |
+
"eng-ops": 0.55, "eng-sec": 0.15, "meta-orchestrator": 0.30},
|
| 75 |
+
"sre-slo": {
|
| 76 |
+
"eng-ops": 0.50, "eng-ai": 0.15, "meta-orchestrator": 0.35},
|
| 77 |
+
"data-sql": {
|
| 78 |
+
"eng-build": 0.55, "eng-ai": 0.15, "compliance": 0.10,
|
| 79 |
+
"meta-orchestrator": 0.20},
|
| 80 |
+
"ai-eng": {
|
| 81 |
+
"eng-ai": 0.60, "eng-build": 0.20, "meta-orchestrator": 0.20},
|
| 82 |
+
"ai-prompt": {
|
| 83 |
+
"eng-ai": 0.55, "product-ux": 0.20, "meta-orchestrator": 0.25},
|
| 84 |
+
"api-rest": {
|
| 85 |
+
"eng-build": 0.45, "product-ux": 0.20, "eng-ai": 0.15,
|
| 86 |
+
"meta-orchestrator": 0.20},
|
| 87 |
+
"api-graphql": {
|
| 88 |
+
"eng-build": 0.50, "product-ux": 0.15, "eng-ai": 0.15,
|
| 89 |
+
"meta-orchestrator": 0.20},
|
| 90 |
+
"ci-github": {
|
| 91 |
+
"eng-ops": 0.55, "eng-build": 0.20, "eng-sec": 0.10,
|
| 92 |
+
"meta-orchestrator": 0.15},
|
| 93 |
+
"debug-traceback": {
|
| 94 |
+
"eng-build": 0.55, "eng-ai": 0.15, "meta-orchestrator": 0.30},
|
| 95 |
+
"perf-profile": {
|
| 96 |
+
"eng-build": 0.45, "eng-ops": 0.20, "eng-ai": 0.15,
|
| 97 |
+
"meta-orchestrator": 0.20},
|
| 98 |
+
"test-pytest": {
|
| 99 |
+
"eng-build": 0.55, "eng-ai": 0.15, "meta-orchestrator": 0.30},
|
| 100 |
+
"docs-api": {
|
| 101 |
+
"eng-build": 0.30, "product-ux": 0.30, "meta-orchestrator": 0.40},
|
| 102 |
+
"arch-adr": {
|
| 103 |
+
"meta-orchestrator": 0.55, "eng-build": 0.15, "eng-ai": 0.15,
|
| 104 |
+
"product-ux": 0.15},
|
| 105 |
+
"cloud-cost": {
|
| 106 |
+
"eng-ops": 0.40, "finance-legal": 0.30, "meta-orchestrator": 0.30},
|
| 107 |
+
"business": {
|
| 108 |
+
"gtm": 0.45, "finance-legal": 0.30, "meta-orchestrator": 0.25},
|
| 109 |
+
"compliance": {
|
| 110 |
+
"compliance": 0.55, "eng-sec": 0.20, "finance-legal": 0.10,
|
| 111 |
+
"meta-orchestrator": 0.15},
|
| 112 |
+
"_default": {
|
| 113 |
+
"meta-orchestrator": 0.40, "eng-build": 0.20, "eng-ops": 0.15,
|
| 114 |
+
"eng-sec": 0.10, "eng-ai": 0.15},
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
# Domain heuristic copied from inference-augment.py
|
| 118 |
+
DOMAIN_HINTS = {
|
| 119 |
+
"code-python": ["def ", "import ", "python", ".py", "pytest", "asyncio"],
|
| 120 |
+
"code-typescript": ["typescript", ".ts", "interface ", "tsconfig"],
|
| 121 |
+
"devops-tf": ["terraform", "resource \"", "provider \"", ".tf"],
|
| 122 |
+
"devops-k8s": ["kubernetes", "kubectl", "kind: deployment", "helm"],
|
| 123 |
+
"devops-cdk": ["aws-cdk", "cdk synth", "Stack", "CfnOutput"],
|
| 124 |
+
"sec-iam": ["iam:", "policy", "principal", "least privilege"],
|
| 125 |
+
"sec-secrets": ["secret", "api key", "token", "credentials"],
|
| 126 |
+
"sec-cve": ["cve-", "vulnerability", "exploit", "remediation"],
|
| 127 |
+
"sre-runbook": ["runbook", "incident", "on-call", "page"],
|
| 128 |
+
"sre-slo": ["sli", "slo", "error budget", "latency p99"],
|
| 129 |
+
"data-sql": ["select ", "from ", "join ", "create table"],
|
| 130 |
+
"ai-eng": ["embedding", "rag", "vector", "lora", "vllm"],
|
| 131 |
+
"ai-prompt": ["system prompt", "few-shot", "in-context"],
|
| 132 |
+
"api-rest": ["rest api", "openapi", "endpoint", "GET /", "POST /"],
|
| 133 |
+
"api-graphql": ["graphql", "resolver", "type Query", "schema"],
|
| 134 |
+
"ci-github": ["github actions", ".github/workflows", "uses: actions/"],
|
| 135 |
+
"debug-traceback": ["traceback", "stack trace", "valueerror", "typeerror"],
|
| 136 |
+
"perf-profile": ["profile", "bottleneck", "latency", "throughput"],
|
| 137 |
+
"test-pytest": ["pytest", "@pytest.fixture", "assert ", "unittest"],
|
| 138 |
+
"docs-api": ["api documentation", "endpoint reference", "sdk"],
|
| 139 |
+
"arch-adr": ["adr", "trade-off", "decision record", "architecture"],
|
| 140 |
+
"cloud-cost": ["cost", "spend", "savings plan", "reserved instance"],
|
| 141 |
+
"business": ["pricing", "go-to-market", "positioning", "icp"],
|
| 142 |
+
"compliance": ["soc 2", "iso 27001", "hipaa", "pci-dss", "gdpr"],
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def detect_domain(prompt: str) -> str:
    """Classify *prompt* into a routing domain via keyword hits.

    Returns the DOMAIN_HINTS key with the most keyword matches; falls
    back to "_default" when fewer than two keywords matched (a single
    hit is too weak a signal to route on).
    """
    lowered = prompt.lower()
    winner = "_default"
    winner_hits = 0
    for candidate, keywords in DOMAIN_HINTS.items():
        hits = sum(keyword in lowered for keyword in keywords)
        if hits > winner_hits:
            winner, winner_hits = candidate, hits
    if winner_hits < 2:
        return "_default"
    return winner
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def _db() -> sqlite3.Connection:
    """Open (and lazily initialize) the routing-table SQLite database."""
    conn = sqlite3.connect(str(DB_PATH), isolation_level=None, timeout=10,
                           check_same_thread=False)
    # WAL lets concurrent readers coexist with a single writer.
    conn.execute("PRAGMA journal_mode=WAL")
    conn.execute("""CREATE TABLE IF NOT EXISTS routing (
        domain TEXT,
        adapter TEXT,
        weight REAL,
        n_observations INTEGER DEFAULT 0,
        PRIMARY KEY (domain, adapter)
    )""")
    return conn
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def get_weights(domain: str) -> dict[str, float]:
    """Return normalized adapter weights for *domain*.

    Prefers learned rows (>= 5 observations) from the routing table;
    otherwise falls back to the hard-coded heuristic table. The result
    is renormalized to sum to 1.0 and rounded to 4 decimals.
    """
    conn = _db()
    learned = conn.execute("""SELECT adapter, weight FROM routing
                        WHERE domain=? AND n_observations >= 5""",
                           (domain,)).fetchall()
    conn.close()
    if learned:
        table = dict(learned)
    else:
        table = dict(ROUTING_HEURISTIC.get(domain, ROUTING_HEURISTIC["_default"]))
    total = sum(table.values()) or 1.0
    return {name: round(value / total, 4) for name, value in table.items()}
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def compose(prompt: str, override_domain: str | None = None) -> dict:
    """Build the per-adapter LoRA weight plan for *prompt*.

    Returns the detected (or overridden) domain, normalized adapter
    weights, a vLLM ``--lora-modules`` weight string, and the kwargs
    list for PEFT ``add_weighted_adapter()``.
    """
    domain = override_domain if override_domain else detect_domain(prompt)
    weights = get_weights(domain)
    shown = prompt[:200]
    if len(prompt) > 200:
        shown += "…"
    return {
        "prompt": shown,
        "domain": domain,
        "weights": weights,
        # vLLM-compatible serialization (passed via --lora-modules with weights)
        "vllm_lora_modules": ",".join(f"{name}={value}"
                                      for name, value in weights.items()),
        "peft_args": [{"adapter_name": name, "weight": value}
                      for name, value in weights.items()],
    }
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def learn_from_winners(jsonl_path: str, lr: float = 0.1) -> int:
    """Update the routing table from self-improve winner records.

    Each winner line is evidence that its detected domain -> adapter
    weights worked. When the record names the serving adapter we move
    that adapter's weight toward 1.0 with step *lr*; otherwise we only
    bump observation counts for the heuristic adapters (a confidence
    signal for the existing table).

    Args:
        jsonl_path: JSONL file of winner records ({prompt, meta?, adapter?}).
        lr: learning rate for the exponential move toward 1.0.

    Returns:
        Number of winner lines processed (0 when the file is missing).
    """
    inp = Path(jsonl_path)
    if not inp.exists():
        return 0
    c = _db()
    n = 0
    # Stream line-by-line instead of read_text() so a large winners log
    # is never materialized in memory at once.
    with open(inp, encoding="utf-8") as fh:
        for line in fh:
            try:
                d = json.loads(line)
            except Exception:
                continue
            prompt = d.get("prompt", "")
            if not prompt:
                continue
            # If logger captured which adapter served best, use that.
            used = d.get("meta", {}).get("adapter") or d.get("adapter")
            domain = d.get("meta", {}).get("domain") or detect_domain(prompt)
            if used:
                # Atomic UPSERT replaces the old SELECT-then-UPDATE pair,
                # which could lose updates under concurrent writers.
                # Existing row: w_new = w*(1-lr) + 1.0*lr, obs += 1.
                # New row: weight = lr, obs = 1 (same as before).
                c.execute(
                    """INSERT INTO routing
                           (domain, adapter, weight, n_observations)
                       VALUES (?, ?, ?, 1)
                       ON CONFLICT(domain, adapter) DO UPDATE SET
                           weight = weight * (1 - ?) + 1.0 * ?,
                           n_observations = n_observations + 1""",
                    (domain, used, lr, lr, lr))
            else:
                # No adapter signal: bump heuristic adapters' observation
                # counts (confidence signal for the heuristic weights).
                for adapter, w in ROUTING_HEURISTIC.get(domain, {}).items():
                    c.execute(
                        """INSERT INTO routing
                               (domain, adapter, weight, n_observations)
                           VALUES (?, ?, ?, 1)
                           ON CONFLICT(domain, adapter) DO UPDATE SET
                               n_observations = n_observations + 1""",
                        (domain, adapter, w))
            n += 1
    c.close()
    return n
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def main() -> None:
    """CLI entry: --learn updates the routing table; otherwise compose
    a weight plan for the prompt on stdin (or a built-in demo prompt)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--learn", help="JSONL of winners to learn routing from")
    parser.add_argument("--domain", help="override detected domain")
    opts = parser.parse_args()

    if opts.learn:
        learned = learn_from_winners(opts.learn)
        print(json.dumps({"learned_from": learned, "db": str(DB_PATH)}))
        return

    if sys.stdin.isatty():
        # No piped input: demonstrate with a representative prompt.
        demo = ("Write a Terraform module that provisions an S3 bucket "
                "with versioning and KMS encryption.")
        print(json.dumps(compose(demo, opts.domain), indent=2,
                         ensure_ascii=False))
        return
    payload = json.load(sys.stdin)
    print(json.dumps(compose(payload.get("prompt", ""), opts.domain),
                     indent=2, ensure_ascii=False))
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
if __name__ == "__main__":
|
| 281 |
+
main()
|
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Surrogate-1 v2 β Self-Refine 3-iter loop.
|
| 2 |
+
|
| 3 |
+
Reference: Madaan et al. 2023 (Self-Refine). 3-iteration generateβcritiqueβ
|
| 4 |
+
revise loop.
|
| 5 |
+
|
| 6 |
+
Diff vs constitutional-loop.py:
|
| 7 |
+
β’ constitutional-loop = ONE pass with 8 fixed principles β DPO triple
|
| 8 |
+
β’ self-refine = THREE iterations of free-form critique β final SFT
|
| 9 |
+
|
| 10 |
+
Useful for high-stakes outputs where additional refinement compounds
|
| 11 |
+
quality. Output schema = SFT (chosen-only), not DPO. Plug into stage1
|
| 12 |
+
training mix or stage1.5 polish stage.
|
| 13 |
+
|
| 14 |
+
CLI:
|
| 15 |
+
python3 self-refine-loop.py --input prompts.jsonl --n 200
|
| 16 |
+
β /data/v2/self-refine-sft.jsonl
|
| 17 |
+
"""
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
import argparse
|
| 20 |
+
import json
|
| 21 |
+
import os
|
| 22 |
+
import subprocess
|
| 23 |
+
import sys
|
| 24 |
+
import time
|
| 25 |
+
from pathlib import Path
|
| 26 |
+
|
| 27 |
+
sys.path.insert(0, str(Path.home() / ".surrogate/bin/lib"))
|
| 28 |
+
try:
|
| 29 |
+
from sanitize import filter_pair # type: ignore
|
| 30 |
+
except Exception:
|
| 31 |
+
def filter_pair(p, r): return {"keep": True}
|
| 32 |
+
|
| 33 |
+
OUT_PATH = Path.home() / ".surrogate/data/v2/self-refine-sft.jsonl"
|
| 34 |
+
ITERATIONS = 3
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def llm_ladder(prompt: str, sys_prompt: str = "",
               max_tokens: int = 1200, temperature: float = 0.4) -> str:
    """Try each free LLM bridge in priority order; return the first
    usable reply.

    A reply of 30 chars or fewer is treated as a failure. Returns ""
    when every bridge is missing, errors out, or times out.
    """
    bridge_scripts = (
        "$HOME/.surrogate/bin/cerebras-bridge.sh",
        "$HOME/.surrogate/bin/groq-bridge.sh",
        "$HOME/.surrogate/bin/openrouter-bridge.sh",
        "$HOME/.surrogate/bin/gemini-bridge.sh",
        "$HOME/.surrogate/bin/chutes-bridge.sh",
        "$HOME/.surrogate/bin/ollama-bridge.sh",
    )
    for bridge in bridge_scripts:
        script = os.path.expandvars(bridge)
        if not Path(script).exists():
            continue
        try:
            request = json.dumps({"system": sys_prompt, "prompt": prompt,
                                  "max_tokens": max_tokens,
                                  "temperature": temperature})
            proc = subprocess.run(["bash", script], input=request,
                                  capture_output=True, text=True, timeout=60)
            reply = (proc.stdout or "").strip()
            if reply and len(reply) > 30:
                return reply
        except Exception:
            # Bridge failed (timeout, bad exit, etc.) — fall through to next.
            continue
    return ""
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def initial_answer(prompt: str) -> str:
    """Produce the first draft, before any critique/refine passes."""
    system = ("You are Surrogate-1, an expert DevSecOps + SRE + coding agent. "
              "Answer the prompt with production-quality code/config.")
    return llm_ladder(prompt, system, max_tokens=1500, temperature=0.5)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def critique(prompt: str, answer: str, iter_n: int) -> str:
    """Ask a reviewer persona for 3-5 actionable improvements.

    The reviewer outputs the literal string 'NONE' when nothing remains
    to improve — callers use that as the convergence signal.
    """
    reviewer = ("You are a senior reviewer. Critique the answer for: "
                "correctness, security, completeness, idiomatic style, missing "
                "edge cases. Output 3-5 specific actionable improvements (no "
                "praise, no hedging). If nothing to improve, output 'NONE'.")
    request = (f"PROMPT:\n{prompt[:1500]}\n\nANSWER (iteration {iter_n}):\n"
               f"{answer[:3000]}\n\nList specific improvements, "
               f"or 'NONE' if perfect.")
    return llm_ladder(request, reviewer, max_tokens=400, temperature=0.2)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def refine(prompt: str, answer: str, critique_text: str) -> str:
    """Apply the critique to the answer; pass through unchanged on 'NONE'."""
    if critique_text.strip().upper().startswith("NONE"):
        # Reviewer found nothing actionable — already converged.
        return answer
    system = ("You are Surrogate-1. Apply the listed improvements to the "
              "answer. Keep what's already correct. Output ONLY the revised "
              "answer — no preamble, no markdown around the answer block.")
    request = (f"PROMPT:\n{prompt[:1500]}\n\nCURRENT ANSWER:\n{answer[:3000]}\n\n"
               f"IMPROVEMENTS TO APPLY:\n{critique_text[:1500]}\n\n"
               f"Output the revised answer.")
    return llm_ladder(request, system, max_tokens=1500, temperature=0.3)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def process(prompt: str) -> dict | None:
|
| 95 |
+
if len(prompt) < 30:
|
| 96 |
+
return None
|
| 97 |
+
answer = initial_answer(prompt)
|
| 98 |
+
if not answer:
|
| 99 |
+
return None
|
| 100 |
+
|
| 101 |
+
history = [answer]
|
| 102 |
+
for i in range(1, ITERATIONS + 1):
|
| 103 |
+
crit = critique(prompt, answer, i)
|
| 104 |
+
if not crit or crit.strip().upper().startswith("NONE"):
|
| 105 |
+
break
|
| 106 |
+
revised = refine(prompt, answer, crit)
|
| 107 |
+
if not revised or revised.strip() == answer.strip():
|
| 108 |
+
break
|
| 109 |
+
history.append(revised)
|
| 110 |
+
answer = revised
|
| 111 |
+
|
| 112 |
+
if not filter_pair(prompt, answer)["keep"]:
|
| 113 |
+
return None
|
| 114 |
+
|
| 115 |
+
return {
|
| 116 |
+
"prompt": prompt[:6000],
|
| 117 |
+
"response": answer[:6000],
|
| 118 |
+
"source": "self-refine",
|
| 119 |
+
"meta": {
|
| 120 |
+
"iterations_used": len(history),
|
| 121 |
+
"first_draft_len": len(history[0]),
|
| 122 |
+
"final_len": len(answer),
|
| 123 |
+
},
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def main() -> None:
    """CLI: read prompts JSONL, append refined SFT rows to --out."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", required=True,
                        help="JSONL with {prompt} per line")
    parser.add_argument("--out", default=str(OUT_PATH))
    parser.add_argument("--n", type=int, default=200)
    opts = parser.parse_args()

    src = Path(opts.input)
    dst = Path(opts.out)
    dst.parent.mkdir(parents=True, exist_ok=True)
    if not src.exists():
        print(f"❌ {src} missing", file=sys.stderr)
        sys.exit(1)

    seen = kept = 0
    # Append mode so repeated runs accumulate; flush per row so partial
    # progress survives a crash mid-batch.
    with open(src) as fin, open(dst, "a") as fout:
        for raw in fin:
            if kept >= opts.n:
                break
            try:
                record = json.loads(raw)
            except Exception:
                continue
            seen += 1
            text = record.get("prompt") or record.get("instruction") or ""
            row = process(text)
            if row:
                fout.write(json.dumps(row, ensure_ascii=False) + "\n")
                fout.flush()
                kept += 1
                if kept % 25 == 0:
                    print(f"  refined {kept}/{opts.n}  (in {seen})")
    print(f"[done] in={seen} kept={kept} → {dst}")
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
if __name__ == "__main__":
|
| 163 |
+
main()
|
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Surrogate-1 v2 β TruthRL ternary rewarder.
|
| 2 |
+
|
| 3 |
+
Reference: TruthRL (2024) β instead of binary correct/wrong, reward CALIBRATED
|
| 4 |
+
abstention. Three outcomes:
|
| 5 |
+
|
| 6 |
+
+1.0 correct + confident
|
| 7 |
+
0.0 abstain ('I don't know', 'verify against docs') when actually uncertain
|
| 8 |
+
-1.0 confident + wrong (hallucination)
|
| 9 |
+
|
| 10 |
+
This produces a model that says IDK on questions it would otherwise hallucinate.
|
| 11 |
+
|
| 12 |
+
Used in stage3-dapo.yml composite reward as the `truthrl` head:
|
| 13 |
+
composite = 0.4*test_pass + 0.2*lint + 0.2*security
|
| 14 |
+
+ 0.2*truthrl β THIS
|
| 15 |
+
|
| 16 |
+
Inputs: (prompt, response, gold_or_judge_verdict). Output: ternary score.
|
| 17 |
+
|
| 18 |
+
Detects abstention with regex over response (fast, no LLM call). Detects
|
| 19 |
+
correctness via judge LLM (free ladder) only when not abstaining β saves cost.
|
| 20 |
+
"""
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
import argparse
|
| 23 |
+
import json
|
| 24 |
+
import os
|
| 25 |
+
import re
|
| 26 |
+
import subprocess
|
| 27 |
+
import sys
|
| 28 |
+
from pathlib import Path
|
| 29 |
+
|
| 30 |
+
# Phrases that signal the model is declining to answer.
ABSTAIN_PHRASES = re.compile(
    r"\b(?:i\s+don'?t\s+know|i'?m\s+not\s+(?:sure|certain)|"
    r"can(?:not|'?t)\s+verify|verify\s+(?:against|with)\s+(?:docs|the\s+docs|official)|"
    r"check\s+(?:the\s+)?(?:docs|documentation|with\s+the\s+vendor)|"
    r"would\s+need\s+to\s+(?:check|verify)|"
    r"unable\s+to\s+(?:confirm|determine)|"
    r"not\s+enough\s+(?:context|info)|need\s+more\s+context|"
    r"this\s+may\s+be\s+(?:out\s+of\s+date|outdated)|"
    r"please\s+confirm|please\s+verify)\b",
    re.IGNORECASE)

# Confident-claim signals — used to detect when the model claims certainty.
CONFIDENT_SIGNALS = re.compile(
    r"\b(?:certainly|definitely|always|never|guaranteed|absolutely|"
    r"is\s+the\s+case|the\s+answer\s+is|the\s+correct\s+(?:way|answer))\b",
    re.IGNORECASE)


def is_abstaining(response: str) -> bool:
    """Return True when *response* is a genuine abstention.

    An abstention phrase must appear early — within the first half of
    the text, but always scanning at least the first 200 chars. If the
    remainder then delivers a long, confident-sounding answer anyway,
    that's hedge-then-answer, not a real abstention.
    """
    if not response:
        return False
    cutoff = max(200, len(response) // 2)
    opening = response[:cutoff]
    remainder = response[cutoff:]
    if ABSTAIN_PHRASES.search(opening) is None:
        return False
    # Hedge-then-answer: a substantial confident tail cancels the abstention.
    if len(remainder) > 200 and CONFIDENT_SIGNALS.search(remainder):
        return False
    return True
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def llm_judge_correctness(prompt: str, response: str,
|
| 64 |
+
gold: str | None = None) -> dict:
|
| 65 |
+
"""Returns {'correct': bool, 'confidence': float, 'why': str}."""
|
| 66 |
+
bridges = [
|
| 67 |
+
"$HOME/.surrogate/bin/cerebras-bridge.sh",
|
| 68 |
+
"$HOME/.surrogate/bin/groq-bridge.sh",
|
| 69 |
+
"$HOME/.surrogate/bin/openrouter-bridge.sh",
|
| 70 |
+
"$HOME/.surrogate/bin/gemini-bridge.sh",
|
| 71 |
+
"$HOME/.surrogate/bin/chutes-bridge.sh",
|
| 72 |
+
"$HOME/.surrogate/bin/ollama-bridge.sh",
|
| 73 |
+
]
|
| 74 |
+
sys_p = ("You are a strict factual reviewer. Decide if the response is "
|
| 75 |
+
"factually correct AND specific enough to be useful. Return ONLY "
|
| 76 |
+
"JSON: {\"correct\": bool, \"confidence\": float in [0,1], "
|
| 77 |
+
"\"why\": str}. No markdown.")
|
| 78 |
+
if gold:
|
| 79 |
+
user_p = (f"PROMPT:\n{prompt[:1500]}\n\nGOLD:\n{gold[:2000]}\n\n"
|
| 80 |
+
f"RESPONSE:\n{response[:3000]}\n\n"
|
| 81 |
+
f"Compare RESPONSE to GOLD. JSON only.")
|
| 82 |
+
else:
|
| 83 |
+
user_p = (f"PROMPT:\n{prompt[:1500]}\n\nRESPONSE:\n{response[:3000]}\n\n"
|
| 84 |
+
f"Is the response factually correct? JSON only.")
|
| 85 |
+
for sh in bridges:
|
| 86 |
+
sh_path = os.path.expandvars(sh)
|
| 87 |
+
if not Path(sh_path).exists():
|
| 88 |
+
continue
|
| 89 |
+
try:
|
| 90 |
+
req = json.dumps({"system": sys_p, "prompt": user_p,
|
| 91 |
+
"max_tokens": 200, "temperature": 0.1})
|
| 92 |
+
r = subprocess.run(["bash", sh_path], input=req,
|
| 93 |
+
capture_output=True, text=True, timeout=45)
|
| 94 |
+
raw = (r.stdout or "").strip()
|
| 95 |
+
if not raw:
|
| 96 |
+
continue
|
| 97 |
+
if raw.startswith("```"):
|
| 98 |
+
raw = raw.split("```")[1].lstrip("json").strip()
|
| 99 |
+
d = json.loads(raw)
|
| 100 |
+
return {"correct": bool(d.get("correct", False)),
|
| 101 |
+
"confidence": float(d.get("confidence", 0.5)),
|
| 102 |
+
"why": d.get("why", "")[:300]}
|
| 103 |
+
except Exception:
|
| 104 |
+
continue
|
| 105 |
+
return {"correct": False, "confidence": 0.0, "why": "judge-fail"}
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def reward(prompt: str, response: str, gold: str | None = None,
           is_actually_unknown: bool | None = None) -> dict:
    """Compute the TruthRL ternary reward for one (prompt, response) pair.

    is_actually_unknown: ground-truth flag that the question has no real
    answer (e.g. a synthetic 'unanswerable' item).
      True  + abstain -> +1.0 (calibrated IDK)
      False + abstain -> -0.3 (over-abstention)
      None  + abstain ->  0.0 (no oracle, neutral)
    Answered responses are graded by the judge LLM — only then do we
    pay for a judge call — and scored by its confidence.
    """
    if is_abstaining(response):
        # Model declined to answer: score against the oracle if we have one.
        if is_actually_unknown is True:
            return {"score": 1.0, "branch": "calibrated_idk",
                    "abstain": True, "correct": None, "why": "abstain on truly unknown"}
        if is_actually_unknown is False:
            return {"score": -0.3, "branch": "over_abstain",
                    "abstain": True, "correct": None,
                    "why": "abstained on a question with a real answer"}
        return {"score": 0.0, "branch": "abstain_neutral",
                "abstain": True, "correct": None, "why": "abstain, no oracle"}

    # Model answered — judge its correctness.
    verdict = llm_judge_correctness(prompt, response, gold)
    confident = verdict["confidence"] >= 0.6
    if confident and verdict["correct"]:
        return {"score": 1.0, "branch": "confident_correct",
                "abstain": False, "correct": True, "why": verdict["why"]}
    if confident and not verdict["correct"]:
        return {"score": -1.0, "branch": "confident_wrong",
                "abstain": False, "correct": False, "why": verdict["why"]}
    # Low confidence without abstaining: partial credit or penalty.
    return {"score": 0.2 if verdict["correct"] else -0.5,
            "branch": "uncertain_answer",
            "abstain": False, "correct": verdict["correct"], "why": verdict["why"]}
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def main() -> None:
    """CLI: batch-grade a JSONL (--jsonl/--out) or grade one stdin record."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--jsonl",
                        help="batch: JSONL with {prompt, response, gold?, "
                             "is_unknown?} per line")
    parser.add_argument("--out", help="batch: output JSONL with truthrl field added")
    opts = parser.parse_args()

    if opts.jsonl:
        if not opts.out:
            print("--out required with --jsonl", file=sys.stderr)
            sys.exit(2)
        seen = graded = 0
        branch_counts = {"calibrated_idk": 0, "confident_correct": 0,
                         "confident_wrong": 0, "over_abstain": 0,
                         "uncertain_answer": 0, "abstain_neutral": 0}
        with open(opts.jsonl) as fin, open(opts.out, "w") as fout:
            for raw in fin:
                try:
                    rec = json.loads(raw)
                except Exception:
                    continue
                seen += 1
                rec["truthrl"] = reward(
                    rec.get("prompt", ""), rec.get("response", ""),
                    rec.get("gold"),
                    rec.get("is_unknown"))
                branch = rec["truthrl"]["branch"]
                branch_counts[branch] = branch_counts.get(branch, 0) + 1
                fout.write(json.dumps(rec, ensure_ascii=False) + "\n")
                graded += 1
                if graded % 25 == 0:
                    print(f"  graded {graded}/{seen}")
        print(f"[done] in={seen} graded={graded}")
        for branch, count in branch_counts.items():
            print(f"  {branch:<22} {count:>5}")
        return

    if sys.stdin.isatty():
        print("usage: echo '{\"prompt\":...,\"response\":...}' | "
              "python3 truthrl-rewarder.py", file=sys.stderr)
        sys.exit(2)
    rec = json.load(sys.stdin)
    print(json.dumps(reward(rec.get("prompt", ""), rec.get("response", ""),
                            rec.get("gold"), rec.get("is_unknown")),
                     indent=2, ensure_ascii=False))
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
if __name__ == "__main__":
|
| 195 |
+
main()
|
|
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Surrogate-1 v2 β Validator-graded RLVR (Reinforcement Learning from Verifier Rewards).
|
| 2 |
+
|
| 3 |
+
Run real domain validators on Surrogate-generated artifacts. Each validator
|
| 4 |
+
emits a deterministic numeric reward; the composite reward feeds DAPO/GRPO
|
| 5 |
+
during stage3 RL training.
|
| 6 |
+
|
| 7 |
+
Validators (all open-source, no LLM calls):
|
| 8 |
+
β’ Python β pyflakes (parse + undefined names)
|
| 9 |
+
β’ Shell β shellcheck (best-practice + bug)
|
| 10 |
+
β’ Dockerfile β hadolint
|
| 11 |
+
β’ Terraform β tflint (must be in PATH; falls back to `terraform validate`)
|
| 12 |
+
β’ Kubernetes β kubeval / kubeconform (manifest schema)
|
| 13 |
+
β’ GH Actions β actionlint
|
| 14 |
+
β’ CloudFormation β cfn-lint
|
| 15 |
+
β’ IAM/Sec β semgrep --config p/security-audit
|
| 16 |
+
β’ SQL β sqlfluff lint --dialect postgres
|
| 17 |
+
β’ CFN security β cfn-guard validate (if rule packs available)
|
| 18 |
+
|
| 19 |
+
Each validator returns: { ok: bool, score: float in [0,1], hits: [{rule,msg}] }.
|
| 20 |
+
|
| 21 |
+
Composite reward (matches stage3-dapo.yml weighting):
|
| 22 |
+
R = 0.40 * lint_score + 0.20 * security_score + 0.20 * test_pass
|
| 23 |
+
+ 0.10 * format_score + 0.10 * cite_correct - 1.0 * polluted
|
| 24 |
+
|
| 25 |
+
Usage:
|
| 26 |
+
echo '{"language":"terraform","code":"resource \"aws_s3_bucket\" \"x\" {}"}' \\
|
| 27 |
+
| python3 validator-rlvr.py
|
| 28 |
+
β {"ok": true, "score": 0.7, "validators": {...}, "composite": 0.7}
|
| 29 |
+
|
| 30 |
+
python3 validator-rlvr.py --jsonl in.jsonl --out scored.jsonl # batch mode
|
| 31 |
+
"""
|
| 32 |
+
from __future__ import annotations
|
| 33 |
+
import argparse
|
| 34 |
+
import json
|
| 35 |
+
import os
|
| 36 |
+
import re
|
| 37 |
+
import shlex
|
| 38 |
+
import shutil
|
| 39 |
+
import subprocess
|
| 40 |
+
import sys
|
| 41 |
+
import tempfile
|
| 42 |
+
from pathlib import Path
|
| 43 |
+
|
| 44 |
+
# Keyword fingerprints used by detect_lang() to guess an artifact's language
# when the caller supplies no explicit hint. A language wins only when at
# least two of its fingerprints appear in the code (see detect_lang).
LANG_HINTS = {
    "python": ["import ", "def ", "class ", "from "],
    "bash": ["#!/bin/bash", "#!/usr/bin/env bash", "set -e", "set -u"],
    "dockerfile": ["FROM ", "RUN ", "ENTRYPOINT ", "CMD "],
    "terraform": ["resource \"", "provider \"", "variable \"", "module \""],
    "k8s": ["apiVersion:", "kind: Deployment", "kind: Service", "kind: Pod"],
    "github-actions": ["uses: actions/", "runs-on:", "jobs:"],
    "cloudformation": ["AWSTemplateFormatVersion", "Resources:\n  ",
                       "\"Type\": \"AWS::"],
    "sql": ["select ", "create table ", "insert into ", "update "],
}
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def detect_lang(code: str, hint: str | None = None) -> str:
    """Best-effort language detection for a generated artifact.

    An explicit *hint* always wins (normalized to lowercase). Otherwise each
    language in LANG_HINTS is scored by how many of its keyword fingerprints
    occur in *code*; the best-scoring language is returned only when at
    least two fingerprints matched, else "unknown".
    """
    if hint:
        return hint.lower()
    # Robustness: tolerate None/empty code instead of raising AttributeError.
    code_low = (code or "").lower()
    scores = {lang: sum(1 for h in hints if h.lower() in code_low)
              for lang, hints in LANG_HINTS.items()}
    # BUG FIX: the original "if not scores: return 'unknown'" branch was
    # unreachable — LANG_HINTS is a non-empty module constant, so scores
    # always has entries. Dropped.
    best_lang, best_score = max(scores.items(), key=lambda kv: kv[1])
    return best_lang if best_score >= 2 else "unknown"
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _run(cmd: list[str], stdin: str | None = None,
|
| 71 |
+
timeout: int = 30) -> tuple[int, str, str]:
|
| 72 |
+
try:
|
| 73 |
+
r = subprocess.run(cmd, input=stdin, capture_output=True,
|
| 74 |
+
text=True, timeout=timeout)
|
| 75 |
+
return r.returncode, r.stdout, r.stderr
|
| 76 |
+
except FileNotFoundError:
|
| 77 |
+
return 127, "", f"validator not in PATH: {cmd[0]}"
|
| 78 |
+
except subprocess.TimeoutExpired:
|
| 79 |
+
return 124, "", f"timeout: {cmd[0]}"
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _have(bin_name: str) -> bool:
|
| 83 |
+
return shutil.which(bin_name) is not None
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def validate_python(code: str) -> dict:
    """Lint Python *code* with pyflakes.

    Reward: 1.0 clean, minus 0.1 per finding, floored at 0. When pyflakes
    is not installed the result is a neutral 0.5 with a "skipped" marker.
    """
    if not _have("pyflakes"):
        return {"ok": False, "score": 0.5, "hits": [],
                "skipped": "pyflakes not installed"}
    rc, out, err = _run(["pyflakes", "-"], stdin=code, timeout=15)
    if rc == 0:
        return {"ok": True, "score": 1.0, "hits": []}
    # pyflakes lines look like "<stdin>:3:1: undefined name 'x'".
    # BUG FIX: the original stored the full message text in "line"; extract
    # the actual line number when the prefix parses, else leave it None.
    hits = []
    for ln in out.splitlines()[:20]:
        if not ln:
            continue
        m = re.match(r"^[^:]*:(\d+):", ln)
        hits.append({"line": int(m.group(1)) if m else None, "msg": ln})
    score = max(0.0, 1.0 - 0.1 * len(hits))
    return {"ok": False, "score": score, "hits": hits}
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def validate_bash(code: str) -> dict:
    """Lint shell *code* with shellcheck (JSON output).

    Reward: 1.0 minus 0.2 per error and 0.05 per warning, floored at 0.
    Neutral 0.5 + "skipped" when shellcheck is absent.
    """
    if not _have("shellcheck"):
        return {"ok": False, "score": 0.5, "hits": [],
                "skipped": "shellcheck not installed"}
    # shellcheck cannot read stdin with -f json reliably; go through a file.
    with tempfile.NamedTemporaryFile("w", suffix=".sh", delete=False) as t:
        t.write(code)
        t.flush()
        path = t.name
    try:
        rc, out, err = _run(["shellcheck", "-f", "json", path], timeout=15)
    finally:
        os.unlink(path)
    if rc == 0:
        return {"ok": True, "score": 1.0, "hits": []}
    try:
        findings = json.loads(out or "[]")
    except Exception:
        findings = None
    if not findings:
        # BUG FIX: shellcheck exited non-zero but produced no parseable
        # findings (crash / non-JSON output). The original reported a
        # perfect ok/1.0 here; degrade instead.
        return {"ok": False, "score": 0.4,
                "hits": [{"msg": (err or out).strip()[:200]
                          or "shellcheck failed"}]}
    err_n = sum(1 for h in findings if h.get("level") == "error")
    warn_n = sum(1 for h in findings if h.get("level") == "warning")
    score = max(0.0, 1.0 - 0.2 * err_n - 0.05 * warn_n)
    return {"ok": err_n == 0, "score": score,
            "hits": [{"line": h.get("line"), "msg": h.get("message", "")[:120]}
                     for h in findings[:10]]}
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def validate_dockerfile(code: str) -> dict:
    """Lint a Dockerfile with hadolint (stdin, JSON output).

    Reward: 1.0 minus 0.25 per error and 0.05 per warning, floored at 0.
    """
    if not _have("hadolint"):
        return {"ok": False, "score": 0.5, "hits": [],
                "skipped": "hadolint not installed"}
    rc, out, err = _run(["hadolint", "-f", "json", "-"], stdin=code,
                        timeout=15)
    try:
        findings = json.loads(out or "[]")
    except Exception:
        findings = None
    if findings is None:
        # BUG FIX: unparseable hadolint output previously fell through to
        # hits=[] and scored a perfect ok/1.0; degrade instead.
        return {"ok": False, "score": 0.4,
                "hits": [{"msg": (err or out).strip()[:200]
                          or "hadolint failed"}]}
    err_n = sum(1 for h in findings if h.get("level") == "error")
    warn_n = sum(1 for h in findings if h.get("level") == "warning")
    score = max(0.0, 1.0 - 0.25 * err_n - 0.05 * warn_n)
    return {"ok": err_n == 0, "score": score,
            "hits": [{"line": h.get("line"), "code": h.get("code"),
                      "msg": h.get("message", "")[:120]}
                     for h in findings[:10]]}
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def validate_terraform(code: str) -> dict:
    """Validate Terraform HCL with tflint, else fall back to `terraform validate`.

    tflint path: 1.0 minus 0.2 per error and 0.05 per warning, floored at 0.
    terraform fallback is binary: 1.0 valid / 0.4 invalid.
    """
    if not (_have("tflint") or _have("terraform")):
        return {"ok": False, "score": 0.5, "hits": [],
                "skipped": "no tflint or terraform"}
    with tempfile.TemporaryDirectory() as td:
        Path(td, "main.tf").write_text(code)
        if _have("tflint"):
            rc, out, err = _run(["tflint", "--format=json",
                                 f"--chdir={td}"], timeout=20)
            try:
                issues = json.loads(out or "{}").get("issues", [])
            except Exception:
                issues = None
            if issues is None:
                # BUG FIX: unparseable tflint output (and a null "issues"
                # field, which would previously crash the severity sums)
                # scored a perfect ok/1.0; degrade instead.
                return {"ok": False, "score": 0.4,
                        "hits": [{"msg": (err or out).strip()[:200]
                                  or "tflint failed"}]}
            err_n = sum(1 for h in issues
                        if h.get("rule", {}).get("severity") == "error")
            warn_n = sum(1 for h in issues
                         if h.get("rule", {}).get("severity") == "warning")
            score = max(0.0, 1.0 - 0.2 * err_n - 0.05 * warn_n)
            return {"ok": err_n == 0, "score": score,
                    "hits": [{"rule": h.get("rule", {}).get("name"),
                              "msg": h.get("message", "")[:120]}
                             for h in issues[:10]]}
        rc, out, err = _run(
            ["terraform", "-chdir=" + td, "validate", "-no-color"], timeout=30)
        return {"ok": rc == 0, "score": 1.0 if rc == 0 else 0.4,
                "hits": [] if rc == 0 else
                [{"msg": err.splitlines()[-1] if err else "validate failed"}]}
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def validate_k8s(code: str) -> dict:
    """Schema-validate a Kubernetes manifest with kubeconform (or kubeval).

    Binary reward: 1.0 when the manifest validates, 0.4 otherwise; neutral
    0.5 + "skipped" when neither tool is installed.
    """
    if _have("kubeconform"):
        bin_name, out_flags = "kubeconform", ["-output", "json"]
    elif _have("kubeval"):
        # BUG FIX: kubeval does not accept the Go-flag style "-output json"
        # that kubeconform uses; it takes a GNU-style long option.
        bin_name, out_flags = "kubeval", ["--output=json"]
    else:
        return {"ok": False, "score": 0.5, "hits": [],
                "skipped": "no kubeconform/kubeval"}
    with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as t:
        t.write(code)
        t.flush()
        path = t.name
    try:
        rc, out, err = _run([bin_name, *out_flags, path], timeout=15)
    finally:
        os.unlink(path)
    if rc == 0:
        return {"ok": True, "score": 1.0, "hits": []}
    tail = (err or out).splitlines()[-1][:200] if (err or out) else "invalid"
    return {"ok": False, "score": 0.4, "hits": [{"msg": tail}]}
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def validate_actions(code: str) -> dict:
    """Lint a GitHub Actions workflow (read from stdin) with actionlint.

    Reward: 1.0 clean, minus 0.2 per finding, floored at 0.
    """
    if not _have("actionlint"):
        return {"ok": False, "score": 0.5, "hits": [],
                "skipped": "actionlint not installed"}
    # BUG FIX: actionlint has no "-format=json"; its -format option takes a
    # Go template, and '{{json .}}' emits the error list as a JSON array.
    rc, out, err = _run(["actionlint", "-format", "{{json .}}", "-"],
                        stdin=code, timeout=15)
    if rc == 0:
        return {"ok": True, "score": 1.0, "hits": []}
    try:
        findings = json.loads(out or "[]")
    except Exception:
        findings = []
    if not findings:
        # actionlint exited non-zero without parseable findings (bad
        # invocation / crash): never report a perfect score for that.
        return {"ok": False, "score": 0.4,
                "hits": [{"msg": (err or out).strip()[:200]
                          or "actionlint failed"}]}
    score = max(0.0, 1.0 - 0.2 * len(findings))
    return {"ok": False, "score": score,
            "hits": [{"line": h.get("line"), "msg": h.get("message", "")[:120]}
                     for h in findings[:10]]}
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def validate_cloudformation(code: str) -> dict:
    """Lint a CloudFormation template with cfn-lint (JSON output).

    Reward: 1.0 minus 0.2 per error and 0.05 per warning, floored at 0.
    """
    if not _have("cfn-lint"):
        return {"ok": False, "score": 0.5, "hits": [],
                "skipped": "cfn-lint not installed"}
    with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as t:
        t.write(code)
        t.flush()
        path = t.name
    try:
        rc, out, err = _run(["cfn-lint", "-f", "json", path], timeout=20)
    finally:
        os.unlink(path)
    try:
        findings = json.loads(out or "[]")
    except Exception:
        findings = None
    if findings is None:
        # BUG FIX: unparseable cfn-lint output previously fell through to
        # hits=[] and scored a perfect ok/1.0; degrade instead.
        return {"ok": False, "score": 0.4,
                "hits": [{"msg": (err or out).strip()[:200]
                          or "cfn-lint failed"}]}
    err_n = sum(1 for h in findings if h.get("Level") == "Error")
    warn_n = sum(1 for h in findings if h.get("Level") == "Warning")
    score = max(0.0, 1.0 - 0.2 * err_n - 0.05 * warn_n)
    return {"ok": err_n == 0, "score": score,
            "hits": [{"rule": h.get("Rule", {}).get("Id"),
                      "msg": h.get("Message", "")[:120]}
                     for h in findings[:10]]}
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def validate_security(code: str, lang: str) -> dict:
    """Cross-language secrets + insecure-pattern scan via semgrep.

    Reward: 1.0 minus 0.3 per ERROR/WARNING finding, floored at 0; neutral
    0.5 + "skipped" when semgrep is not installed.
    """
    if not _have("semgrep"):
        return {"ok": False, "score": 0.5, "hits": [],
                "skipped": "semgrep not installed"}
    # Map our language tag onto a file extension so semgrep routes rules.
    # NOTE(review): a ".Dockerfile" *suffix* (tmpXXX.Dockerfile) may not be
    # recognized by semgrep the way a file literally named "Dockerfile" is —
    # verify against semgrep's language-detection rules.
    ext = {"python": "py", "bash": "sh", "terraform": "tf",
           "k8s": "yaml", "dockerfile": "Dockerfile"}.get(lang, "txt")
    with tempfile.NamedTemporaryFile("w", suffix="." + ext, delete=False) as t:
        t.write(code)
        t.flush()
        path = t.name
    try:
        rc, out, err = _run(
            ["semgrep", "--config=p/security-audit", "--json", "--quiet", path],
            timeout=60)
    finally:
        os.unlink(path)
    try:
        results = json.loads(out or "{}").get("results", [])
    except Exception:
        results = None
    if results is None:
        # BUG FIX: unparseable semgrep output previously fell through to
        # results=[] and scored a perfect ok/1.0; degrade instead.
        return {"ok": False, "score": 0.4,
                "hits": [{"msg": (err or out).strip()[:200]
                          or "semgrep failed"}]}
    high = sum(1 for r in results
               if r.get("extra", {}).get("severity") in ("ERROR", "WARNING"))
    score = max(0.0, 1.0 - 0.3 * high)
    return {"ok": high == 0, "score": score,
            "hits": [{"rule": r.get("check_id"),
                      "msg": r.get("extra", {}).get("message", "")[:120]}
                     for r in results[:10]]}
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def validate_sql(code: str) -> dict:
    """Lint SQL (postgres dialect, read from stdin) with sqlfluff.

    Reward: 1.0 minus 0.1 per violation, floored at 0.
    """
    if not _have("sqlfluff"):
        return {"ok": False, "score": 0.5, "hits": [],
                "skipped": "sqlfluff not installed"}
    rc, out, err = _run(
        ["sqlfluff", "lint", "--dialect", "postgres", "--format", "json", "-"],
        stdin=code, timeout=20)
    try:
        # sqlfluff JSON is a list of per-file records, each carrying a
        # "violations" list.
        files = json.loads(out or "[]")
        violations = [v for f in files for v in f.get("violations", [])]
    except Exception:
        violations = None
    if violations is None:
        if rc == 0:
            violations = []
        else:
            # BUG FIX: a sqlfluff crash with unparseable output previously
            # fell through to violations=[] and scored a perfect ok/1.0.
            return {"ok": False, "score": 0.4,
                    "hits": [{"msg": (err or out).strip()[:200]
                              or "sqlfluff failed"}]}
    err_n = len(violations)
    score = max(0.0, 1.0 - 0.1 * err_n)
    return {"ok": err_n == 0, "score": score,
            "hits": [{"rule": v.get("code"),
                      "msg": v.get("description", "")[:120]}
                     for v in violations[:10]]}
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
# Dispatch table: language tag (as produced by detect_lang) -> lint validator.
# validate_security() is run separately for every language by score_artifact().
VALIDATORS = {
    "python": validate_python,
    "bash": validate_bash,
    "dockerfile": validate_dockerfile,
    "terraform": validate_terraform,
    "k8s": validate_k8s,
    "github-actions": validate_actions,
    "cloudformation": validate_cloudformation,
    "sql": validate_sql,
}
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def score_artifact(code: str, language: str | None = None) -> dict:
    """Run lint + security validators on *code* and compute the RLVR reward.

    Returns {"language", "validators": {"lint", "security"}, "composite"},
    where composite = 0.6 * lint_score + 0.4 * security_score (the RL
    trainer can re-weight downstream). An undetectable language short-
    circuits to a neutral composite of 0.5.
    """
    lang = detect_lang(code, language)
    out = {"language": lang, "validators": {}, "composite": 0.0}
    if lang == "unknown":
        out["composite"] = 0.5
        out["note"] = "language could not be detected"
        return out

    def _no_validator(_code: str) -> dict:
        # CONSISTENCY FIX: the original fallback omitted the "hits" key that
        # every real validator returns; keep the result shape uniform.
        return {"ok": False, "score": 0.5, "hits": [],
                "skipped": f"no validator for {lang}"}

    lint = VALIDATORS.get(lang, _no_validator)(code)
    sec = validate_security(code, lang)
    out["validators"]["lint"] = lint
    out["validators"]["security"] = sec
    # Composite (RLVR reward): lint 60%, security 40%. RL trainer can override.
    out["composite"] = round(0.6 * lint.get("score", 0.5)
                             + 0.4 * sec.get("score", 0.5), 4)
    return out
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def main() -> None:
    """CLI entry point: batch-score a JSONL file, or score one artifact from stdin."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--jsonl", help="batch: JSONL with {code, language?, prompt?}")
    parser.add_argument("--out", help="batch: output JSONL with score field added")
    args = parser.parse_args()

    if args.jsonl:
        if not args.out:
            print("--out required with --jsonl", file=sys.stderr)
            sys.exit(2)
        seen = scored = 0
        with open(args.jsonl) as src, open(args.out, "w") as dst:
            for raw in src:
                try:
                    record = json.loads(raw)
                except Exception:
                    # Skip malformed lines silently; they don't count as input.
                    continue
                seen += 1
                artifact = record.get("response") or record.get("code") or ""
                record["validator"] = score_artifact(artifact, record.get("language"))
                dst.write(json.dumps(record, ensure_ascii=False) + "\n")
                scored += 1
                if scored % 50 == 0:
                    print(f" scored {scored}/{seen}")
        print(f"[done] in={seen} scored={scored} β {args.out}")
        return

    # Single-artifact mode: require piped JSON on stdin.
    if sys.stdin.isatty():
        print("usage: echo '{...}' | python3 validator-rlvr.py", file=sys.stderr)
        sys.exit(2)
    payload = json.load(sys.stdin)
    artifact = payload.get("code") or payload.get("response") or ""
    print(json.dumps(score_artifact(artifact, payload.get("language")),
                     indent=2, ensure_ascii=False))


if __name__ == "__main__":
    main()
|