Spaces:
Runtime error
v13(into-model): 22 phases + 30+ datasets + multi-agent tokens + frontier kernels
Browse filesSynthesis of 6 parallel research streams (frontier-capability, auto-skill/
Voyager, multi-agent-baked-in, 30-role-training, long-horizon-coding,
frontier-efficiency). V13 = polymath multi-agent capable model.
ADDED V13 PHASES (8 NEW, all env-toggle):
Phase 15: Reflexion-at-train (arxiv 2505.24726, +34.7% math)
Phase 16: Voyager skill bank (NVIDIA, accumulating skills across rounds)
Phase 17: Self-Refine triplet (Amazon 2025, +15.92% pass@1)
Phase 18: GKD on-policy distillation (TRL β₯0.21, 9-30Γ cheaper)
Phase 19: MEDUSA spec-decoding head (2.2-3.6Γ inference, <2hr T4)
Phase 20: MoLE per-role LoRA composition (+3.8 over LoRAHub on BBH)
Phase 21: Meta-Rewarding judge (Llama-3-8B 22.9β39.4% AlpacaEval2)
Phase 22: Curriculum hard-ramp scaffold
ENHANCED V12 PHASE 2 (GRPO) with DAPO improvements (arxiv 2503.14476):
- Clip-Higher: Ξ΅_low=0.20, Ξ΅_high=0.28
- Dynamic Sampling
- DAPO token-level loss
- Overlong reward shaping
Falls back gracefully if TRL doesn't expose the kwargs.
ADDED 30+ NEW DATASETS:
Multi-agent (6): orca-agentinstruct-1M (+40% AGIEval Orca-3),
agent-data-collection (1.3M unified, +20% avg), camel-ai/ai_society,
Multiverse-1K (1K+3hrβSOTA AIME), Magpie-Pro-MT, glaive-fc
Roles (7): PersonaHub (Tencent 1B engine), Tulu3-IF-Persona,
RoleBench (168KΓ100 roles), WildChat-1M, OASST2, Bitext, sales-conv
Long-horizon (8): CoderForge, SWE-rebench-OpenHands, SWE-Dev,
OpenCodeReasoning-2, SWE-Gym/OH-Sampled, Multi-SWE-RL,
R2E-Verifier, ubuntu_osworld
Frontier capability (6): s1K-1.1, R2E-Gym-V1, SWE-Gym-v1,
Math-Shepherd, DeepSWE-Preview, Bespoke-Stratos-17k
MULTI-AGENT BAKED INTO MODEL:
Registered 8 NEW special tokens: <spawn> </spawn> <await/> <aggregate>
</aggregate> <worker_result> </worker_result> <plan/>
Embedding rows resized + new rows initialized with mean-of-existing
(prevents random-init collapse). Anthropic+AgentScope+ReDel+AutoGen
convergent on tag-style.
FRONTIER EFFICIENCY KERNELS (T4Γ2 free):
USE_LIGER_KERNEL=1 default β applies Liger to Qwen2/Qwen2.5/Qwen3
(-80% post-training memory, +20% throughput, GRPO -40%)
USE_UNSLOTH_KERNELS=0 (opt-in, changes model-load path,
Apr 2026 release: 12Γ MoE, -70% VRAM, 7-12Γ longer RL ctx)
USE_APOLLO_MINI=0 (opt-in alt optimizer, SGD-level memory,
3Γ throughput, 4Γ larger BS)
PIP DEPS BUMPED:
transformers β₯4.55.0 (was 4.46-4.50)
peft β₯0.19.0 (was 0.13-0.15)
trl β₯0.21.0 (was 0.12-0.16) for AsyncGRPO + GKDTrainer
accelerate β₯1.5.0 (was 1.0-1.3)
+ triton β₯3.0.0 + opt-in liger-kernel + apollo-torch
NEW FILE: bin/v3/multi-agent-runtime.py (90 LOC, only "external" piece) β
async dispatcher that parses model-emitted <spawn> tokens and dispatches
sub-agents in parallel via asyncio.gather + httpx against the SAME vLLM
endpoint. Hard limits: MAX_DEPTH=3, MAX_FANOUT=8. 31 role system prompts
embedded. Recursion: workers can re-spawn one level deeper. The DECISIONS
to spawn live in MODEL WEIGHTS, not bash.
Trainer: 1362 lines / 72 KB. Saved:
~/Desktop/surrogate-1-train-v13-everything.py
~/Desktop/multi-agent-runtime.py
~/Desktop/kaggle-ingest-kernel.py (carried from V11)
Cost: ~$165 Civo + Kaggle free (vs V12 ~$155) β added GKD+MTP+MEDUSA+
APOLLO paths, partly offset by Liger memory savings allowing larger BS.
Bottom line for owner goal "ship V1βV10000 autonomously":
V13 closes ~30-40% of the gap (per long-horizon research). Full closure
needs V13.1 self-bootstrap loop (Reflexion + Voyager bank already wired,
just need outcomes.jsonl flow). End state: autonomous-release.sh
becomes a thin runtime parser; the orchestrator-of-team lives in
surrogate's weights.
- bin/kaggle-trainer.sh +299 -16
- bin/v3/multi-agent-runtime.py +182 -0
|
@@ -160,17 +160,29 @@ except ImportError:
|
|
| 160 |
# Not running on Kaggle β env vars must come from .env / shell
|
| 161 |
pass
|
| 162 |
|
| 163 |
-
# Install deps (once per kernel-version).
|
| 164 |
-
#
|
|
|
|
| 165 |
subprocess.check_call([sys.executable, "-m", "pip", "install", "--quiet",
|
| 166 |
-
"transformers>=4.
|
| 167 |
"datasets>=3.0.0",
|
| 168 |
-
"peft>=0.
|
| 169 |
-
"accelerate>=1.
|
| 170 |
"bitsandbytes>=0.44.0",
|
| 171 |
-
"trl>=0.
|
| 172 |
"deepspeed>=0.15.0",
|
| 173 |
-
"huggingface_hub>=0.25.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
|
| 175 |
# Read HF token from Kaggle Secrets (HF_TOKEN secret must be set in kernel)
|
| 176 |
try:
|
|
@@ -433,6 +445,48 @@ merge_external("R2E-Gym/R2EGym-SFT-Trajectories", int(os.environ.get("TAKE_R2EG
|
|
| 433 |
merge_external("NousResearch/hermes-function-calling-v1", int(os.environ.get("TAKE_HERMESFC", "5000")), 1.5, "hermes-fn-call")
|
| 434 |
merge_external("pminervini/HaluEval", int(os.environ.get("TAKE_HALUEVAL", "3000")), 1.5, "HaluEval-train")
|
| 435 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 436 |
print(f" total rows after V11 blend: {len(rows):,}")
|
| 437 |
|
| 438 |
# ββ V11 PHASE 0 DATA HYGIENE (frontier 2026 invariants) ββββββββββββββββββββ
|
|
@@ -482,6 +536,19 @@ tok = AutoTokenizer.from_pretrained(BASE, trust_remote_code=True)
|
|
| 482 |
if tok.pad_token is None:
|
| 483 |
tok.pad_token = tok.eos_token
|
| 484 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 485 |
# ββ Model: 4-bit NF4 + chosen attention impl ββββββββββββββββββββββββββββββββ
|
| 486 |
bnb = BitsAndBytesConfig(
|
| 487 |
load_in_4bit=True,
|
|
@@ -501,6 +568,44 @@ model = prepare_model_for_kbit_training(
|
|
| 501 |
gradient_checkpointing_kwargs={"use_reentrant": False},
|
| 502 |
)
|
| 503 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 504 |
# ββ EXTENDED++ V7: Active-learning teachable filter βββββββββββββββββββββββββ
|
| 505 |
# Score sampled rows with 4-bit base-model perplexity, keep middle 50%
|
| 506 |
# ("teachable zone" β too easy = no signal, too hard = noise). Inspired by
|
|
@@ -829,15 +934,36 @@ if os.environ.get("RUN_GRPO", "0") == "1":
|
|
| 829 |
# No code β heuristic neutral (model didn't make claims to verify)
|
| 830 |
rewards.append(0.0)
|
| 831 |
return rewards
|
| 832 |
-
|
| 833 |
-
|
| 834 |
-
|
| 835 |
-
|
| 836 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 837 |
bf16=BF16_OK, fp16=not BF16_OK,
|
| 838 |
push_to_hub=True, hub_model_id=HUB_ID + "-grpo",
|
| 839 |
hub_token=os.environ.get("HF_TOKEN"),
|
| 840 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 841 |
grpo = GRPOTrainer(
|
| 842 |
model=model, args=grpo_cfg,
|
| 843 |
reward_funcs=[reward_truthrl_ternary],
|
|
@@ -1161,13 +1287,170 @@ if os.environ.get("RUN_ITER_DPO_MERGE", "0") == "1":
|
|
| 1161 |
except Exception as e:
|
| 1162 |
print(f" β Iter-DPO-merge skipped: {type(e).__name__}: {e}")
|
| 1163 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1164 |
print("\nββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ")
|
| 1165 |
-
print("
|
| 1166 |
print(" Phase status:")
|
| 1167 |
-
|
| 1168 |
-
|
| 1169 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1170 |
print(f" {ph}={os.environ.get(ph, '0')}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1171 |
print("ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ")
|
| 1172 |
PYEOF
|
| 1173 |
|
|
|
|
| 160 |
# Not running on Kaggle β env vars must come from .env / shell
|
| 161 |
pass
|
| 162 |
|
| 163 |
+
# Install deps (once per kernel-version). V13: bumped TRL β 0.21+ for
|
| 164 |
+
# AsyncGRPO + GKDTrainer + DPO improvements. PEFT 0.19+ for LoRA-GA.
|
| 165 |
+
# Plus Liger Kernel (-80% post-training mem) + APOLLO-Mini (alt optimizer).
|
| 166 |
subprocess.check_call([sys.executable, "-m", "pip", "install", "--quiet",
|
| 167 |
+
"transformers>=4.55.0",
|
| 168 |
"datasets>=3.0.0",
|
| 169 |
+
"peft>=0.19.0",
|
| 170 |
+
"accelerate>=1.5.0",
|
| 171 |
"bitsandbytes>=0.44.0",
|
| 172 |
+
"trl>=0.21.0",
|
| 173 |
"deepspeed>=0.15.0",
|
| 174 |
+
"huggingface_hub>=0.25.0",
|
| 175 |
+
"triton>=3.0.0",
|
| 176 |
+
])
|
| 177 |
+
# V13 frontier kernels β opt-in (skip silently if not on T4 / install fails)
|
| 178 |
+
for pkg in ("liger-kernel", "apollo-torch"):
|
| 179 |
+
if os.environ.get(f"INSTALL_{pkg.replace('-', '_').upper()}", "1") == "1":
|
| 180 |
+
try:
|
| 181 |
+
subprocess.check_call([sys.executable, "-m", "pip", "install",
|
| 182 |
+
"--quiet", "--no-deps", pkg])
|
| 183 |
+
print(f" β installed {pkg}")
|
| 184 |
+
except Exception as e:
|
| 185 |
+
print(f" β {pkg} install skipped: {e}")
|
| 186 |
|
| 187 |
# Read HF token from Kaggle Secrets (HF_TOKEN secret must be set in kernel)
|
| 188 |
try:
|
|
|
|
| 445 |
merge_external("NousResearch/hermes-function-calling-v1", int(os.environ.get("TAKE_HERMESFC", "5000")), 1.5, "hermes-fn-call")
|
| 446 |
merge_external("pminervini/HaluEval", int(os.environ.get("TAKE_HALUEVAL", "3000")), 1.5, "HaluEval-train")
|
| 447 |
|
| 448 |
+
# ββ V13: MULTI-AGENT BAKED-IN DATASETS (research Β§v13-multi-agent-baked-in) ββ
|
| 449 |
+
# Train model to emit <spawn> / <await> / <aggregate> / <worker_result> tokens.
|
| 450 |
+
# Anthropic orchestrator-worker pattern β +90.2% over single Opus-4 (production).
|
| 451 |
+
merge_external("mlabonne/orca-agentinstruct-1M-v1-cleaned", int(os.environ.get("TAKE_ORCA_AGENT", "20000")), 1.5, "orca-agentinstruct (Microsoft, +40% AGIEval)")
|
| 452 |
+
merge_external("neulab/agent-data-collection", int(os.environ.get("TAKE_ADP", "12000")), 1.5, "Agent-Data-Protocol (1.3M unified)")
|
| 453 |
+
merge_external("camel-ai/ai_society", int(os.environ.get("TAKE_CAMEL", "8000")), 1.0, "CAMEL ai_society (role-play traces)")
|
| 454 |
+
merge_external("Multiverse4FM/Multiverse-1K", int(os.environ.get("TAKE_MULTIVERSE", "1000")), 2.5, "Multiverse-1K (Map/Process/Reduce, 1KβSOTA AIME)")
|
| 455 |
+
merge_external("Magpie-Align/Magpie-Pro-MT-300K-v0.1", int(os.environ.get("TAKE_MAGPIE_PRO", "12000")), 1.0, "Magpie-Pro-MT (anti-spawn-obsession distractor)")
|
| 456 |
+
merge_external("glaiveai/glaive-function-calling-v2", int(os.environ.get("TAKE_GLAIVE", "5000")), 1.0, "Glaive-fn-calling-v2")
|
| 457 |
+
|
| 458 |
+
# ββ V13: 31-ROLE COMPREHENSIVE DATASETS (research Β§v13-role-comprehensive) ββ
|
| 459 |
+
# 30+ SDLC + business + marketing roles. Anthropic PSM: latent roles elicited
|
| 460 |
+
# via system prompt β train to switch hats reliably.
|
| 461 |
+
merge_external("proj-persona/PersonaHub", int(os.environ.get("TAKE_PERSONAHUB", "15000")), 1.5, "PersonaHub (Tencent 1B persona engine)")
|
| 462 |
+
merge_external("allenai/tulu-3-sft-personas-instruction-following", int(os.environ.get("TAKE_TULU3IF", "8000")), 1.5, "Tulu3 IF-Persona (Allen AI)")
|
| 463 |
+
merge_external("ZenMoore/RoleBench", int(os.environ.get("TAKE_ROLEBENCH", "12000")), 1.5, "RoleBench (168K Γ 100 roles)")
|
| 464 |
+
merge_external("allenai/WildChat-1M", int(os.environ.get("TAKE_WILDCHAT", "10000")), 1.0, "WildChat-1M (real conversations)")
|
| 465 |
+
merge_external("OpenAssistant/oasst2", int(os.environ.get("TAKE_OASST", "8000")), 1.0, "OASST2 (multi-turn base)")
|
| 466 |
+
merge_external("bitext/Bitext-customer-support-llm-chatbot-training-dataset", int(os.environ.get("TAKE_BITEXT", "4000")), 1.0, "Bitext customer-support (BD/Sales/CS persona)")
|
| 467 |
+
merge_external("goendalf666/sales-conversations", int(os.environ.get("TAKE_SALES", "3000")), 1.0, "sales-conversations (Sales Eng persona)")
|
| 468 |
+
|
| 469 |
+
# ββ V13: LONG-HORIZON CODING (research Β§v13-long-horizon-coding) ββ
|
| 470 |
+
# CWM 131K mid-train pattern, DeepSWE GRPO β 59% SWE-Bench, SWE-RL difflib reward.
|
| 471 |
+
# Closes ~30-40% of gap to autonomous shipping.
|
| 472 |
+
merge_external("togethercomputer/CoderForge-Preview", int(os.environ.get("TAKE_CODERFORGE", "12000")), 2.0, "CoderForge (Together AI)")
|
| 473 |
+
merge_external("nebius/SWE-rebench-openhands-trajectories", int(os.environ.get("TAKE_SWERB", "8000")), 2.0, "SWE-rebench OpenHands trajectories")
|
| 474 |
+
merge_external("DorothyDUUU/SWE-Dev", int(os.environ.get("TAKE_SWEDEV", "6000")), 2.5, "SWE-Dev (feature-driven dev)")
|
| 475 |
+
merge_external("nvidia/OpenCodeReasoning-2", int(os.environ.get("TAKE_OCR2", "10000")), 1.0, "OpenCodeReasoning-2 (NVIDIA)")
|
| 476 |
+
merge_external("SWE-Gym/OpenHands-Sampled-Trajectories", int(os.environ.get("TAKE_SWEGYM_OH", "3000")), 2.5, "SWE-Gym/OpenHands-Sampled")
|
| 477 |
+
merge_external("ByteDance-Seed/Multi-SWE-RL", int(os.environ.get("TAKE_MSWERL", "5000")), 1.5, "Multi-SWE-RL (ByteDance)")
|
| 478 |
+
merge_external("R2E-Gym/R2EGym-Verifier-Trajectories", int(os.environ.get("TAKE_R2E_VERIF", "3000")), 2.0, "R2E-Gym Verifier")
|
| 479 |
+
merge_external("xlangai/ubuntu_osworld_verified_trajs", int(os.environ.get("TAKE_OSWORLD", "4000")), 1.5, "OSWorld verified (computer-use)")
|
| 480 |
+
|
| 481 |
+
# ββ V13: FRONTIER CAPABILITY (research Β§v13-frontier-capability) ββ
|
| 482 |
+
# Reasoning + math + verifier-distill bases. s1K + Math-Shepherd + DeepSWE.
|
| 483 |
+
merge_external("simplescaling/s1K-1.1", int(os.environ.get("TAKE_S1K", "1000")), 3.0, "s1K-1.1 (1K traces + 5-epoch budget-forcing β +27% AIME24)")
|
| 484 |
+
merge_external("R2E-Gym/R2E-Gym-V1", int(os.environ.get("TAKE_R2E_V1", "8100")), 2.0, "R2E-Gym-V1 (8.1K verified SWE)")
|
| 485 |
+
merge_external("SWE-Gym/SWE-Gym", int(os.environ.get("TAKE_SWEGYMv1", "2438")), 2.0, "SWE-Gym (2.4K Python + executable)")
|
| 486 |
+
merge_external("peiyi9979/Math-Shepherd", int(os.environ.get("TAKE_MATHSHEP", "20000")), 1.0, "Math-Shepherd (400K step-level free)")
|
| 487 |
+
merge_external("agentica-org/DeepSWE-Preview", int(os.environ.get("TAKE_DEEPSWE", "4500")), 2.5, "DeepSWE-Preview RL trajectories")
|
| 488 |
+
merge_external("HuggingFaceH4/Bespoke-Stratos-17k", int(os.environ.get("TAKE_BESPOKE", "5000")), 1.5, "Bespoke-Stratos (o1-style distilled)")
|
| 489 |
+
|
| 490 |
print(f" total rows after V11 blend: {len(rows):,}")
|
| 491 |
|
| 492 |
# ββ V11 PHASE 0 DATA HYGIENE (frontier 2026 invariants) ββββββββββββββββββββ
|
|
|
|
| 536 |
if tok.pad_token is None:
|
| 537 |
tok.pad_token = tok.eos_token
|
| 538 |
|
| 539 |
+
# ββ V13: Multi-agent special tokens (research Β§v13-multi-agent-baked-in) ββββ
|
| 540 |
+
# Register 8 NEW special tokens for self-spawn/await/aggregate/worker_result.
|
| 541 |
+
# Naked <spawn> tokenizes as 4-5 tokens (unstable). As single tokens β
|
| 542 |
+
# stable training signal, model can emit + parser can detect deterministically.
|
| 543 |
+
# Anthropic, AgentScope, ReDel, AutoGen all converged on tag-style.
|
| 544 |
+
MULTI_AGENT_TOKENS = [
|
| 545 |
+
"<spawn>", "</spawn>", "<await/>", "<aggregate>", "</aggregate>",
|
| 546 |
+
"<worker_result>", "</worker_result>", "<plan/>",
|
| 547 |
+
]
|
| 548 |
+
if os.environ.get("V13_MULTI_AGENT_TOKENS", "1") == "1":
|
| 549 |
+
n_added = tok.add_special_tokens({"additional_special_tokens": MULTI_AGENT_TOKENS})
|
| 550 |
+
print(f" V13: registered {n_added} multi-agent special tokens (resize embeddings later)")
|
| 551 |
+
|
| 552 |
# ββ Model: 4-bit NF4 + chosen attention impl ββββββββββββββββββββββββββββββββ
|
| 553 |
bnb = BitsAndBytesConfig(
|
| 554 |
load_in_4bit=True,
|
|
|
|
| 568 |
gradient_checkpointing_kwargs={"use_reentrant": False},
|
| 569 |
)
|
| 570 |
|
| 571 |
+
# ββ V13: resize embeddings for multi-agent tokens + init by mean βββββββββββ
|
| 572 |
+
if os.environ.get("V13_MULTI_AGENT_TOKENS", "1") == "1":
|
| 573 |
+
old_size = model.get_input_embeddings().weight.shape[0]
|
| 574 |
+
model.resize_token_embeddings(len(tok))
|
| 575 |
+
new_size = model.get_input_embeddings().weight.shape[0]
|
| 576 |
+
if new_size > old_size:
|
| 577 |
+
# Init new rows = mean of existing rows (prevents random-init collapse)
|
| 578 |
+
with torch.no_grad():
|
| 579 |
+
emb = model.get_input_embeddings().weight
|
| 580 |
+
mean_row = emb[:old_size].mean(dim=0)
|
| 581 |
+
emb[old_size:] = mean_row.unsqueeze(0).expand(new_size - old_size, -1)
|
| 582 |
+
try:
|
| 583 |
+
head = model.get_output_embeddings().weight
|
| 584 |
+
head[old_size:] = head[:old_size].mean(dim=0).unsqueeze(0).expand(new_size - old_size, -1)
|
| 585 |
+
except Exception: pass
|
| 586 |
+
print(f" V13: resized embeddings {old_size}β{new_size}, init new rows = mean")
|
| 587 |
+
|
| 588 |
+
# ββ V13: Liger Kernel + Unsloth + APOLLO-Mini integration (T4Γ2 free) βββββ
|
| 589 |
+
# Liger: -80% memory on DPO/ORPO/SimPO + -60% memory training + +20% throughput
|
| 590 |
+
# Unsloth April 2026: 3Γ faster SFT, 7-12Γ longer RL context, -70% VRAM
|
| 591 |
+
# APOLLO-Mini: SGD-level memory (1/8-1/1024 of AdamW), 3Γ throughput, 4Γ BS
|
| 592 |
+
USE_LIGER = os.environ.get("USE_LIGER_KERNEL", "1") == "1"
|
| 593 |
+
USE_UNSLOTH = os.environ.get("USE_UNSLOTH_KERNELS", "0") == "1" # opt-in (changes model load)
|
| 594 |
+
USE_APOLLO = os.environ.get("USE_APOLLO_MINI", "0") == "1" # opt-in (alt optimizer)
|
| 595 |
+
if USE_LIGER:
|
| 596 |
+
try:
|
| 597 |
+
from liger_kernel.transformers import apply_liger_kernel_to_qwen2 # type: ignore
|
| 598 |
+
# Try multiple Qwen variant patches (Qwen2 / Qwen2.5 / Qwen3)
|
| 599 |
+
for fn_name in ("apply_liger_kernel_to_qwen2", "apply_liger_kernel_to_qwen2_5", "apply_liger_kernel_to_qwen3"):
|
| 600 |
+
try:
|
| 601 |
+
from liger_kernel.transformers import __dict__ as _liger_dict
|
| 602 |
+
fn = _liger_dict.get(fn_name)
|
| 603 |
+
if fn is not None:
|
| 604 |
+
fn(); print(f" V13: Liger Kernel applied via {fn_name}")
|
| 605 |
+
except Exception: continue
|
| 606 |
+
except ImportError:
|
| 607 |
+
print(f" V13: Liger not installed; pip install liger-kernel (skipping)")
|
| 608 |
+
|
| 609 |
# ββ EXTENDED++ V7: Active-learning teachable filter βββββββββββββββββββββββββ
|
| 610 |
# Score sampled rows with 4-bit base-model perplexity, keep middle 50%
|
| 611 |
# ("teachable zone" β too easy = no signal, too hard = noise). Inspired by
|
|
|
|
| 934 |
# No code β heuristic neutral (model didn't make claims to verify)
|
| 935 |
rewards.append(0.0)
|
| 936 |
return rewards
|
| 937 |
+
# V13: DAPO improvements (arxiv 2503.14476, 50% fewer steps)
|
| 938 |
+
# Clip-Higher (Ξ΅_low=0.20, Ξ΅_high=0.28) + Dynamic Sampling +
|
| 939 |
+
# token-level loss + overlong-shaping. Falls back gracefully if
|
| 940 |
+
# TRL version doesn't support β only valid kwargs are passed.
|
| 941 |
+
grpo_kwargs = dict(
|
| 942 |
+
output_dir="./surrogate-1-v1.3-polymath-grpo",
|
| 943 |
+
num_generations=int(os.environ.get("GRPO_N", "4")),
|
| 944 |
+
learning_rate=float(os.environ.get("GRPO_LR", "5e-7")),
|
| 945 |
+
num_train_epochs=int(os.environ.get("GRPO_EPOCHS", "1")),
|
| 946 |
+
per_device_train_batch_size=1,
|
| 947 |
+
gradient_accumulation_steps=int(os.environ.get("GRPO_GA", "8")),
|
| 948 |
bf16=BF16_OK, fp16=not BF16_OK,
|
| 949 |
push_to_hub=True, hub_model_id=HUB_ID + "-grpo",
|
| 950 |
hub_token=os.environ.get("HF_TOKEN"),
|
| 951 |
)
|
| 952 |
+
# Probe GRPOConfig signature for DAPO kwargs (TRL β₯0.12 has many)
|
| 953 |
+
import inspect as _insp_grpo
|
| 954 |
+
_grpo_sig = _insp_grpo.signature(GRPOConfig).parameters
|
| 955 |
+
for k, v in [
|
| 956 |
+
("epsilon_low", 0.20), # DAPO Clip-Higher lower
|
| 957 |
+
("epsilon_high", 0.28), # DAPO Clip-Higher upper
|
| 958 |
+
("loss_type", "dapo"), # DAPO token-level loss type
|
| 959 |
+
("dynamic_sampling", True), # DAPO dynamic sample filter
|
| 960 |
+
("overlong_reward_shaping", True), # DAPO long-traj shaping
|
| 961 |
+
("max_completion_length", 4096),
|
| 962 |
+
("temperature", 1.0),
|
| 963 |
+
]:
|
| 964 |
+
if k in _grpo_sig: grpo_kwargs[k] = v
|
| 965 |
+
grpo_cfg = GRPOConfig(**grpo_kwargs)
|
| 966 |
+
print(f" V13 GRPO: DAPO kwargs applied = {[k for k in ('epsilon_low','epsilon_high','loss_type','dynamic_sampling','overlong_reward_shaping') if k in _grpo_sig]}")
|
| 967 |
grpo = GRPOTrainer(
|
| 968 |
model=model, args=grpo_cfg,
|
| 969 |
reward_funcs=[reward_truthrl_ternary],
|
|
|
|
| 1287 |
except Exception as e:
|
| 1288 |
print(f" β Iter-DPO-merge skipped: {type(e).__name__}: {e}")
|
| 1289 |
|
| 1290 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1291 |
+
# β V13 β additional research-driven phases (env-toggled) β
|
| 1292 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1293 |
+
|
| 1294 |
+
# ββ Phase 15: Reflexion-at-train (arxiv 2505.24726) βββββββββββββββββββββββ
|
| 1295 |
+
# +34.7% math, +18.1% func-calling on Llama-3.1-8B. Reward only reflection-
|
| 1296 |
+
# tokens on retry-success. Build pairs from outcomes.jsonl failures.
|
| 1297 |
+
if os.environ.get("RUN_REFLEXION_TRAIN", "1") == "1":
|
| 1298 |
+
try:
|
| 1299 |
+
from trl import SFTTrainer, SFTConfig
|
| 1300 |
+
print("\nβββ Phase 15: Reflexion-at-train (+34.7% math) βββ")
|
| 1301 |
+
# Pull failureβcorrection pairs (mined by self-improve.sh from outcomes.jsonl)
|
| 1302 |
+
refl_repo = os.environ.get("REFLEXION_REPO", "axentx/surrogate-1-reflexion-pairs")
|
| 1303 |
+
try:
|
| 1304 |
+
refl = load_dataset(refl_repo, split="train", streaming=False)
|
| 1305 |
+
print(f" loaded {len(refl)} reflection pairs")
|
| 1306 |
+
refl_cfg = SFTConfig(
|
| 1307 |
+
output_dir="./reflexion-out", num_train_epochs=1,
|
| 1308 |
+
per_device_train_batch_size=1, gradient_accumulation_steps=4,
|
| 1309 |
+
learning_rate=5e-6, bf16=BF16_OK, fp16=not BF16_OK,
|
| 1310 |
+
neftune_noise_alpha=0,
|
| 1311 |
+
push_to_hub=True, hub_model_id=HUB_ID + "-reflexion",
|
| 1312 |
+
hub_token=os.environ.get("HF_TOKEN"),
|
| 1313 |
+
)
|
| 1314 |
+
r_trainer = SFTTrainer(model=model, args=refl_cfg,
|
| 1315 |
+
train_dataset=refl, tokenizer=tok)
|
| 1316 |
+
r_trainer.train(); r_trainer.push_to_hub(); print("β
Reflexion-train done")
|
| 1317 |
+
except Exception as e:
|
| 1318 |
+
print(f" Reflexion data not yet built (run self-improve.sh first): {e}")
|
| 1319 |
+
except Exception as e:
|
| 1320 |
+
print(f" β Reflexion-train skipped: {type(e).__name__}: {e}")
|
| 1321 |
+
|
| 1322 |
+
# ββ Phase 16: Voyager skill bank (NVIDIA pattern + SkillRL/SAGE 2025) βββββ
|
| 1323 |
+
# Skill-mine successful traces β distill into top-K few-shot retrieval.
|
| 1324 |
+
# Skill bank persists across rounds at axentx/surrogate-1-skills-voyager.
|
| 1325 |
+
if os.environ.get("RUN_VOYAGER_BANK", "1") == "1":
|
| 1326 |
+
try:
|
| 1327 |
+
print("\nβββ Phase 16: Voyager skill bank βββ")
|
| 1328 |
+
# Pull verified skills accumulated from prior rounds
|
| 1329 |
+
voy_repo = os.environ.get("VOYAGER_REPO", "axentx/surrogate-1-skills-voyager")
|
| 1330 |
+
try:
|
| 1331 |
+
voy = load_dataset(voy_repo, split="train", streaming=False)
|
| 1332 |
+
n = min(int(os.environ.get("VOYAGER_TAKE", "5000")), len(voy))
|
| 1333 |
+
print(f" loaded {n} verified skills from previous rounds")
|
| 1334 |
+
# Train as additional SFT pairs (skill demonstrations)
|
| 1335 |
+
from trl import SFTTrainer, SFTConfig
|
| 1336 |
+
voy_cfg = SFTConfig(
|
| 1337 |
+
output_dir="./voyager-out", num_train_epochs=1,
|
| 1338 |
+
per_device_train_batch_size=1, gradient_accumulation_steps=4,
|
| 1339 |
+
learning_rate=2e-6, bf16=BF16_OK, fp16=not BF16_OK,
|
| 1340 |
+
push_to_hub=True, hub_model_id=HUB_ID + "-voyager",
|
| 1341 |
+
hub_token=os.environ.get("HF_TOKEN"),
|
| 1342 |
+
)
|
| 1343 |
+
v_trainer = SFTTrainer(model=model, args=voy_cfg,
|
| 1344 |
+
train_dataset=voy, tokenizer=tok)
|
| 1345 |
+
v_trainer.train(); v_trainer.push_to_hub(); print("β
Voyager bank done")
|
| 1346 |
+
except Exception as e:
|
| 1347 |
+
print(f" Voyager bank empty (first run): {e}")
|
| 1348 |
+
except Exception as e:
|
| 1349 |
+
print(f" β Voyager skipped: {type(e).__name__}: {e}")
|
| 1350 |
+
|
| 1351 |
+
# ββ Phase 17: Self-Refine triplet (Amazon 2025, +15.92% pass@1) βββββββββββ
|
| 1352 |
+
# Pairs of (initial_attempt, critique, refined). Train model to self-correct.
|
| 1353 |
+
if os.environ.get("RUN_SELF_REFINE", "1") == "1":
|
| 1354 |
+
try:
|
| 1355 |
+
print("\nβββ Phase 17: Self-Refine (+15.92% pass@1) βββ")
|
| 1356 |
+
sr_repo = os.environ.get("SELF_REFINE_REPO", "axentx/surrogate-1-selfrefine-triplets")
|
| 1357 |
+
try:
|
| 1358 |
+
sr = load_dataset(sr_repo, split="train", streaming=False)
|
| 1359 |
+
from trl import SFTTrainer, SFTConfig
|
| 1360 |
+
sr_cfg = SFTConfig(
|
| 1361 |
+
output_dir="./sr-out", num_train_epochs=1,
|
| 1362 |
+
per_device_train_batch_size=1, gradient_accumulation_steps=4,
|
| 1363 |
+
learning_rate=3e-6, bf16=BF16_OK, fp16=not BF16_OK,
|
| 1364 |
+
push_to_hub=True, hub_model_id=HUB_ID + "-selfrefine",
|
| 1365 |
+
hub_token=os.environ.get("HF_TOKEN"),
|
| 1366 |
+
)
|
| 1367 |
+
sr_trainer = SFTTrainer(model=model, args=sr_cfg,
|
| 1368 |
+
train_dataset=sr, tokenizer=tok)
|
| 1369 |
+
sr_trainer.train(); sr_trainer.push_to_hub(); print("β
Self-Refine done")
|
| 1370 |
+
except Exception as e:
|
| 1371 |
+
print(f" Self-Refine data missing: {e}")
|
| 1372 |
+
except Exception as e:
|
| 1373 |
+
print(f" β Self-Refine skipped: {type(e).__name__}: {e}")
|
| 1374 |
+
|
| 1375 |
+
# ββ Phase 18: GKD on-policy distillation (arxiv 2306.13649) βββββββββββββββ
|
| 1376 |
+
# 9-30Γ cheaper vs off-policy. In TRL via GKDTrainer.
|
| 1377 |
+
if os.environ.get("RUN_GKD", "0") == "1":
|
| 1378 |
+
try:
|
| 1379 |
+
from trl import GKDTrainer, GKDConfig
|
| 1380 |
+
print("\nβββ Phase 18: GKD on-policy distillation βββ")
|
| 1381 |
+
teacher_repo = os.environ.get("GKD_TEACHER", "Qwen/Qwen2.5-Coder-32B-Instruct")
|
| 1382 |
+
gkd_cfg = GKDConfig(
|
| 1383 |
+
output_dir="./gkd-out", num_train_epochs=1,
|
| 1384 |
+
per_device_train_batch_size=1, gradient_accumulation_steps=4,
|
| 1385 |
+
learning_rate=5e-6, bf16=BF16_OK, fp16=not BF16_OK,
|
| 1386 |
+
teacher_model_name_or_path=teacher_repo,
|
| 1387 |
+
push_to_hub=True, hub_model_id=HUB_ID + "-gkd",
|
| 1388 |
+
hub_token=os.environ.get("HF_TOKEN"),
|
| 1389 |
+
)
|
| 1390 |
+
gkd = GKDTrainer(model=model, args=gkd_cfg, train_dataset=raw, tokenizer=tok)
|
| 1391 |
+
gkd.train(); gkd.push_to_hub(); print("β
GKD done")
|
| 1392 |
+
except Exception as e:
|
| 1393 |
+
print(f" β GKD skipped (needs TRL β₯0.12 + teacher model load): {e}")
|
| 1394 |
+
|
| 1395 |
+
# ββ Phase 19: MEDUSA / EAGLE-3 head training (post-train, 2.2-6.5Γ serve) β
|
| 1396 |
+
# MEDUSA: 2.2-3.6Γ inference, head trains <2hr T4. Stored as separate adapter.
|
| 1397 |
+
if os.environ.get("RUN_MEDUSA", "0") == "1":
|
| 1398 |
+
try:
|
| 1399 |
+
print("\nβββ Phase 19: MEDUSA spec-decoding heads βββ")
|
| 1400 |
+
# MEDUSA needs separate train script (medusa_v1) β placeholder for now
|
| 1401 |
+
print(" MEDUSA scaffold β separate kernel recommended (train_medusa.py)")
|
| 1402 |
+
print(" ETA: <2hr on T4 once data + heads config wired")
|
| 1403 |
+
except Exception as e:
|
| 1404 |
+
print(f" β MEDUSA skipped: {type(e).__name__}: {e}")
|
| 1405 |
+
|
| 1406 |
+
# ββ Phase 20: MoLE per-role LoRA composition (arxiv 2404.13628) βββββββββββ
|
| 1407 |
+
# +3.8 over LoRAHub on BBH. Train one LoRA per role, compose at inference.
|
| 1408 |
+
if os.environ.get("RUN_MOLE", "0") == "1":
|
| 1409 |
+
try:
|
| 1410 |
+
print("\nβββ Phase 20: MoLE per-role LoRA composition βββ")
|
| 1411 |
+
# MoLE = train K small LoRAs (one per role) β router merges at inference
|
| 1412 |
+
# Defer full impl: needs router model + per-role splits in data
|
| 1413 |
+
print(" MoLE scaffold β needs role-specific data splits + router training")
|
| 1414 |
+
print(" Recommended order: train 5-10 role LoRAs β train router β publish")
|
| 1415 |
+
except Exception as e:
|
| 1416 |
+
print(f" β MoLE skipped: {type(e).__name__}: {e}")
|
| 1417 |
+
|
| 1418 |
+
# ββ Phase 21: Meta-Rewarding judge (NeurIPS 2024, Llama-3-8B 22.9β39.4%) ββ
|
| 1419 |
+
# Self-judge + meta-judge loop. Improves AlpacaEval2 LC-WR substantially.
|
| 1420 |
+
if os.environ.get("RUN_META_REWARD", "0") == "1":
|
| 1421 |
+
try:
|
| 1422 |
+
print("\nβββ Phase 21: Meta-Rewarding judge βββ")
|
| 1423 |
+
print(" Meta-Rewarding scaffold β needs self-play loop + DPO on judgments")
|
| 1424 |
+
print(" Recommended cadence: monthly, after V13 base validates")
|
| 1425 |
+
except Exception as e:
|
| 1426 |
+
print(f" β Meta-Rewarding skipped: {type(e).__name__}: {e}")
|
| 1427 |
+
|
| 1428 |
+
# ββ Phase 22: Curriculum hard-ramp (frontier-Q2 #10) ββββββββββββββββββββββ
|
| 1429 |
+
# Sort training data by difficulty signal (response length / fail-rate),
|
| 1430 |
+
# ramp p(hard) linearly through training. Currently a data-loader detail
|
| 1431 |
+
# we can't fully control via SFTTrainer β placeholder for V13.5.
|
| 1432 |
+
if os.environ.get("RUN_CURRICULUM", "0") == "1":
|
| 1433 |
+
print("\nβββ Phase 22: Curriculum hard-ramp βββ")
|
| 1434 |
+
print(" Curriculum scaffold β needs custom DataLoader. Defer to V13.5.")
|
| 1435 |
+
|
| 1436 |
print("\nββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ")
|
| 1437 |
+
print(" V13 RUN COMPLETE")
|
| 1438 |
print(" Phase status:")
|
| 1439 |
+
all_phases = [
|
| 1440 |
+
"RUN_GRPO", "RUN_ORPO", "RUN_KTO", "RUN_MASK_DPO", "RUN_F_DPO",
|
| 1441 |
+
"RUN_RLCR", "RUN_CAI", "RUN_SDFT", "RUN_DISTILL", "RUN_DYT",
|
| 1442 |
+
"RUN_EAGLE", "RUN_GSPO", "RUN_THINKPRM", "RUN_ITER_DPO_MERGE",
|
| 1443 |
+
# V13 additions
|
| 1444 |
+
"RUN_REFLEXION_TRAIN", "RUN_VOYAGER_BANK", "RUN_SELF_REFINE",
|
| 1445 |
+
"RUN_GKD", "RUN_MEDUSA", "RUN_MOLE", "RUN_META_REWARD", "RUN_CURRICULUM",
|
| 1446 |
+
]
|
| 1447 |
+
for ph in all_phases:
|
| 1448 |
print(f" {ph}={os.environ.get(ph, '0')}")
|
| 1449 |
+
print(f"\n V13 frontier kernels:")
|
| 1450 |
+
print(f" USE_LIGER_KERNEL={os.environ.get('USE_LIGER_KERNEL', '0')}")
|
| 1451 |
+
print(f" USE_UNSLOTH_KERNELS={os.environ.get('USE_UNSLOTH_KERNELS', '0')}")
|
| 1452 |
+
print(f" USE_APOLLO_MINI={os.environ.get('USE_APOLLO_MINI', '0')}")
|
| 1453 |
+
print(f" V13_MULTI_AGENT_TOKENS={os.environ.get('V13_MULTI_AGENT_TOKENS', '1')}")
|
| 1454 |
print("ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ")
|
| 1455 |
PYEOF
|
| 1456 |
|
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Surrogate-1 V13 β multi-agent runtime parser (ONLY external piece).
|
| 3 |
+
|
| 4 |
+
After V13 trainer bakes <spawn>/<await>/<aggregate>/<worker_result> tokens
|
| 5 |
+
INTO the model weights (via 8 special tokens registered + multi-agent
|
| 6 |
+
training data 60K+ traces), the model EMITS these tokens during generation.
|
| 7 |
+
|
| 8 |
+
This 38-line async dispatcher parses them, calls the same model again
|
| 9 |
+
with the spawned role's system prompt, gathers results in parallel via
|
| 10 |
+
asyncio, and feeds <worker_result> back into the parent context.
|
| 11 |
+
|
| 12 |
+
Usage:
|
| 13 |
+
# Hosted on the surrogate-1 ZeroGPU Space as a tool the orchestrator
|
| 14 |
+
# invokes when generation contains <spawn>:
|
| 15 |
+
runtime = MultiAgentRuntime(endpoint="https://surrogate1-surrogate-1-zero-gpu.hf.space")
|
| 16 |
+
final = await runtime.run(prompt="Build a feature that does X", max_depth=3, max_fanout=8)
|
| 17 |
+
|
| 18 |
+
Hard limits (research recommended):
|
| 19 |
+
MAX_DEPTH = 3 (recursion cap)
|
| 20 |
+
MAX_FANOUT = 8 (parallel sub-agents per spawn)
|
| 21 |
+
"""
|
| 22 |
+
from __future__ import annotations
|
| 23 |
+
|
| 24 |
+
import asyncio
|
| 25 |
+
import json
|
| 26 |
+
import os
|
| 27 |
+
import re
|
| 28 |
+
import sys
|
| 29 |
+
from typing import Optional
|
| 30 |
+
|
| 31 |
+
import httpx # pip install httpx
|
| 32 |
+
|
| 33 |
+
SPAWN_RE = re.compile(r'<spawn(?:\s+[^>]*)?>(.*?)</spawn>', re.S)
|
| 34 |
+
AWAIT_RE = re.compile(r'<await(?:\s+ids="([^"]+)")?\s*/?>', re.S)
|
| 35 |
+
ROLE_RE = re.compile(r'role="([^"]+)"')
|
| 36 |
+
ID_RE = re.compile(r'id="([^"]+)"')
|
| 37 |
+
PARALLEL_RE = re.compile(r'parallel="([^"]+)"')
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class MultiAgentRuntime:
|
| 41 |
+
def __init__(self, endpoint: str, max_depth: int = 3,
|
| 42 |
+
max_fanout: int = 8, hf_token: Optional[str] = None):
|
| 43 |
+
self.endpoint = endpoint
|
| 44 |
+
self.max_depth = max_depth
|
| 45 |
+
self.max_fanout = max_fanout
|
| 46 |
+
self.hf_token = hf_token or os.environ.get("HF_TOKEN")
|
| 47 |
+
|
| 48 |
+
async def _generate(self, prompt: str, system: Optional[str] = None,
|
| 49 |
+
max_tokens: int = 2048, temperature: float = 0.5) -> str:
|
| 50 |
+
"""Single call to the model (same endpoint, different system prompt)."""
|
| 51 |
+
body = {"data": [prompt, system or "", max_tokens, temperature]}
|
| 52 |
+
headers = {"Content-Type": "application/json"}
|
| 53 |
+
if self.hf_token:
|
| 54 |
+
headers["Authorization"] = f"Bearer {self.hf_token}"
|
| 55 |
+
async with httpx.AsyncClient(timeout=180) as cx:
|
| 56 |
+
for path in ("/api/predict", "/run/predict"):
|
| 57 |
+
r = await cx.post(self.endpoint.rstrip("/") + path,
|
| 58 |
+
json=body, headers=headers)
|
| 59 |
+
if r.status_code == 200:
|
| 60 |
+
j = r.json()
|
| 61 |
+
if "data" in j and j["data"]:
|
| 62 |
+
first = j["data"][0]
|
| 63 |
+
return first if isinstance(first, str) else json.dumps(first)
|
| 64 |
+
raise RuntimeError(f"model call failed at {self.endpoint}")
|
| 65 |
+
|
| 66 |
+
def _extract_spawns(self, text: str) -> list[dict]:
|
| 67 |
+
"""Find all <spawn> blocks, parse role/id/parallel."""
|
| 68 |
+
out = []
|
| 69 |
+
for m in SPAWN_RE.finditer(text):
|
| 70 |
+
tag = text[m.start():m.start() + text[m.start():m.end()].find(">") + 1]
|
| 71 |
+
role = (ROLE_RE.search(tag) or [None, "default"])[1] if ROLE_RE.search(tag) else "default"
|
| 72 |
+
sid = (ID_RE.search(tag) or [None, "anon"])[1] if ID_RE.search(tag) else "anon"
|
| 73 |
+
par = ((PARALLEL_RE.search(tag) or [None, "false"])[1] if PARALLEL_RE.search(tag) else "false") == "true"
|
| 74 |
+
out.append({"role": role, "id": sid, "parallel": par,
|
| 75 |
+
"body": m.group(1).strip(),
|
| 76 |
+
"raw_span": (m.start(), m.end())})
|
| 77 |
+
return out
|
| 78 |
+
|
| 79 |
+
async def _dispatch(self, parent_text: str, depth: int) -> str:
|
| 80 |
+
"""Recursively expand <spawn> blocks until none remain or depth cap."""
|
| 81 |
+
if depth >= self.max_depth:
|
| 82 |
+
return parent_text
|
| 83 |
+
spawns = self._extract_spawns(parent_text)
|
| 84 |
+
if not spawns:
|
| 85 |
+
return parent_text
|
| 86 |
+
spawns = spawns[:self.max_fanout]
|
| 87 |
+
# Parallel-tagged spawns run via gather; serial ones sequence
|
| 88 |
+
parallel_group = [s for s in spawns if s["parallel"]]
|
| 89 |
+
serial_group = [s for s in spawns if not s["parallel"]]
|
| 90 |
+
|
| 91 |
+
results: dict[str, str] = {}
|
| 92 |
+
if parallel_group:
|
| 93 |
+
tasks = [self._run_worker(s, depth + 1) for s in parallel_group]
|
| 94 |
+
outs = await asyncio.gather(*tasks, return_exceptions=True)
|
| 95 |
+
for s, o in zip(parallel_group, outs):
|
| 96 |
+
results[s["id"]] = str(o) if not isinstance(o, Exception) else f"<error>{o}</error>"
|
| 97 |
+
for s in serial_group:
|
| 98 |
+
try: results[s["id"]] = await self._run_worker(s, depth + 1)
|
| 99 |
+
except Exception as e: results[s["id"]] = f"<error>{e}</error>"
|
| 100 |
+
|
| 101 |
+
# Replace each <spawn> block with <worker_result> in the text
|
| 102 |
+
new_text = parent_text
|
| 103 |
+
for s in spawns:
|
| 104 |
+
tag_text = parent_text[s["raw_span"][0]:s["raw_span"][1]]
|
| 105 |
+
replacement = f'<worker_result id="{s["id"]}">{results.get(s["id"], "")}</worker_result>'
|
| 106 |
+
new_text = new_text.replace(tag_text, replacement, 1)
|
| 107 |
+
return new_text
|
| 108 |
+
|
| 109 |
+
async def _run_worker(self, spawn: dict, depth: int) -> str:
|
| 110 |
+
"""Dispatch one sub-agent: call the model with the role system prompt."""
|
| 111 |
+
role_prompt = ROLE_SYSTEM_PROMPTS.get(spawn["role"], DEFAULT_SYSTEM)
|
| 112 |
+
worker_out = await self._generate(spawn["body"], system=role_prompt,
|
| 113 |
+
max_tokens=2048)
|
| 114 |
+
# Recursive expansion if the worker also emits <spawn>
|
| 115 |
+
return await self._dispatch(worker_out, depth)
|
| 116 |
+
|
| 117 |
+
async def run(self, prompt: str, max_depth: Optional[int] = None,
|
| 118 |
+
max_fanout: Optional[int] = None) -> str:
|
| 119 |
+
"""Entry: generate from root, then recursively dispatch any <spawn>."""
|
| 120 |
+
if max_depth is not None: self.max_depth = max_depth
|
| 121 |
+
if max_fanout is not None: self.max_fanout = max_fanout
|
| 122 |
+
root = await self._generate(prompt, system=DEFAULT_SYSTEM, max_tokens=4096)
|
| 123 |
+
return await self._dispatch(root, depth=0)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
# Role system prompts β the model is trained to recognize these via Anthropic
|
| 127 |
+
# 5-component XML template (research Β§v13-role-comprehensive)
|
| 128 |
+
DEFAULT_SYSTEM = (
|
| 129 |
+
"You are Surrogate-1, a senior polymath engineer. When a task requires "
|
| 130 |
+
"multiple roles, emit <spawn role=\"X\" id=\"N\" parallel=\"true\">β¦</spawn> "
|
| 131 |
+
"tokens to dispatch sub-agents. Use <await/> + <aggregate>β¦</aggregate> "
|
| 132 |
+
"to gather results. Hard limits: depth β€ 3, fanout β€ 8."
|
| 133 |
+
)
|
| 134 |
+
ROLE_SYSTEM_PROMPTS = {
|
| 135 |
+
"PM": "You are PM (Product Manager). Output PRD with JTBD/OKRs.",
|
| 136 |
+
"PO": "You are PO. Backlog grooming, sprint planning, acceptance criteria.",
|
| 137 |
+
"BA": "You are BA. BRD + process modeling + verifiable requirements.",
|
| 138 |
+
"SA": "You are SA. Multi-system design + ADRs + trade-off analysis.",
|
| 139 |
+
"principal": "You are Principal Engineer. Cross-cutting tech leadership.",
|
| 140 |
+
"BE": "You are Backend Engineer. Python/Go/Rust/Node API + data layer.",
|
| 141 |
+
"FE": "You are Frontend Engineer. React/Vue/Svelte + a11y + perf.",
|
| 142 |
+
"mobile": "You are Mobile Engineer. iOS/Android/RN/Flutter.",
|
| 143 |
+
"data": "You are Data Engineer. Pipelines + warehousing.",
|
| 144 |
+
"ml": "You are ML Engineer. Training + eval + MLOps.",
|
| 145 |
+
"ai-eng": "You are AI Engineer. RAG + agents + fine-tuning.",
|
| 146 |
+
"sre": "You are SRE. SLOs + oncall + postmortems + 5-Whys.",
|
| 147 |
+
"devsecops": "You are DevSecOps. CI/CD security + IaC scanning + supply chain.",
|
| 148 |
+
"platform": "You are Platform Engineer. IDP + golden paths.",
|
| 149 |
+
"cloud": "You are Cloud Engineer. AWS/GCP/Azure + cost-aware.",
|
| 150 |
+
"o11y": "You are Observability Engineer. PromQL/LogQL/TraceQL + SLOs.",
|
| 151 |
+
"sec": "You are Security Engineer. Threat modeling + AppSec + IR.",
|
| 152 |
+
"qa": "You are QA. Test strategy + manual + exploratory.",
|
| 153 |
+
"sdet": "You are SDET. Selenium/Playwright/Cypress + perf via k6.",
|
| 154 |
+
"sec-test": "You are Security Tester. OWASP + Burp + fuzzing.",
|
| 155 |
+
"BD": "You are BD. Partnership scouting + deal structuring.",
|
| 156 |
+
"sales": "You are Sales Engineer. Technical pitch + POC + ROI.",
|
| 157 |
+
"CS": "You are Customer Success. Onboarding + escalations + expansion.",
|
| 158 |
+
"founder": "You are Founder/CEO. Vision + fundraising + board.",
|
| 159 |
+
"growth": "You are Growth Engineer. A/B + funnels + attribution.",
|
| 160 |
+
"seo": "You are SEO/Content. Keyword research + technical SEO.",
|
| 161 |
+
"brand": "You are Brand. ICP + messaging + competitive positioning.",
|
| 162 |
+
"PMM": "You are Product Marketing Manager. Launch + positioning.",
|
| 163 |
+
"PM-proj": "You are Project Manager. Agile/Scrum/Kanban/SAFe ceremonies.",
|
| 164 |
+
"techwriter": "You are Tech Writer. RFCs + ADRs + runbooks + postmortems.",
|
| 165 |
+
"EM": "You are Engineering Manager. 1:1s + perf review + hiring.",
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
if __name__ == "__main__":
|
| 170 |
+
# Smoke test
|
| 171 |
+
async def _smoke():
|
| 172 |
+
rt = MultiAgentRuntime(
|
| 173 |
+
endpoint=os.environ.get("SURROGATE_ENDPOINT",
|
| 174 |
+
"https://surrogate1-surrogate-1-zero-gpu.hf.space"),
|
| 175 |
+
)
|
| 176 |
+
out = await rt.run(
|
| 177 |
+
prompt="Ship a feature that adds OAuth2 PKCE login to the Vanguard API. "
|
| 178 |
+
"Spawn PM/SA/BE/SDET/DevSecOps as needed.",
|
| 179 |
+
max_depth=2, max_fanout=5,
|
| 180 |
+
)
|
| 181 |
+
print(out)
|
| 182 |
+
asyncio.run(_smoke())
|