#!/usr/bin/env python3
"""Surrogate-1 V13 β€” multi-agent runtime parser (ONLY external piece).

After V13 trainer bakes <spawn>/<await>/<aggregate>/<worker_result> tokens
INTO the model weights (via 8 special tokens registered + multi-agent
training data 60K+ traces), the model EMITS these tokens during generation.

This 38-line async dispatcher parses them, calls the same model again
with the spawned role's system prompt, gathers results in parallel via
asyncio, and feeds <worker_result> back into the parent context.

Usage:
    # Hosted on the surrogate-1 ZeroGPU Space as a tool the orchestrator
    # invokes when generation contains <spawn>:
    runtime = MultiAgentRuntime(endpoint="https://surrogate1-surrogate-1-zero-gpu.hf.space")
    final = await runtime.run(prompt="Build a feature that does X", max_depth=3, max_fanout=8)

Hard limits (research recommended):
    MAX_DEPTH  = 3   (recursion cap)
    MAX_FANOUT = 8   (parallel sub-agents per spawn)
"""
from __future__ import annotations

import asyncio
import json
import os
import re
import sys
from typing import Optional

import httpx  # pip install httpx

# Matches a whole <spawn ...>body</spawn> block; re.S lets the body span lines.
# Group 1 is the body text handed to the sub-agent.
SPAWN_RE = re.compile(r'<spawn(?:\s+[^>]*)?>(.*?)</spawn>', re.S)
# Matches <await/> or <await ids="a,b"/>; group 1 (optional) is the id list.
AWAIT_RE = re.compile(r'<await(?:\s+ids="([^"]+)")?\s*/?>', re.S)
# Attribute extractors, applied to the opening <spawn ...> tag text only.
ROLE_RE  = re.compile(r'role="([^"]+)"')
ID_RE    = re.compile(r'id="([^"]+)"')
PARALLEL_RE = re.compile(r'parallel="([^"]+)"')


class MultiAgentRuntime:
    """Async dispatcher for model-emitted <spawn> tokens.

    Parses <spawn role=... id=... parallel=...> blocks out of generated
    text, calls the same endpoint again with the spawned role's system
    prompt, and splices each result back as a <worker_result> block.
    Recursion is capped by ``max_depth``; spawns expanded per pass are
    capped by ``max_fanout``.
    """

    def __init__(self, endpoint: str, max_depth: int = 3,
                 max_fanout: int = 8, hf_token: Optional[str] = None):
        self.endpoint = endpoint
        self.max_depth = max_depth
        self.max_fanout = max_fanout
        # Fall back to the ambient HF_TOKEN so hosted Spaces auth works.
        self.hf_token = hf_token or os.environ.get("HF_TOKEN")

    async def _generate(self, prompt: str, system: Optional[str] = None,
                        max_tokens: int = 2048, temperature: float = 0.5) -> str:
        """Single call to the model (same endpoint, different system prompt).

        Tries both Gradio API paths in order; returns the first string
        payload found. Raises RuntimeError — now including the HTTP status
        seen on each path, so failures are diagnosable — if neither path
        yields usable data.
        """
        body = {"data": [prompt, system or "", max_tokens, temperature]}
        headers = {"Content-Type": "application/json"}
        if self.hf_token:
            headers["Authorization"] = f"Bearer {self.hf_token}"
        statuses: list[str] = []
        async with httpx.AsyncClient(timeout=180) as cx:
            for path in ("/api/predict", "/run/predict"):
                r = await cx.post(self.endpoint.rstrip("/") + path,
                                  json=body, headers=headers)
                statuses.append(f"{path}={r.status_code}")
                if r.status_code == 200:
                    j = r.json()
                    if "data" in j and j["data"]:
                        first = j["data"][0]
                        # Non-string payloads are preserved as JSON text.
                        return first if isinstance(first, str) else json.dumps(first)
        raise RuntimeError(
            f"model call failed at {self.endpoint} ({', '.join(statuses)})")

    def _extract_spawns(self, text: str) -> list[dict]:
        """Find all <spawn> blocks, parse role/id/parallel.

        Returns one dict per block with keys role, id, parallel (bool),
        body (stripped inner text), and raw_span — the (start, end)
        character offsets used later to splice results back in.
        """
        def attr(pattern: re.Pattern, tag: str, default: str) -> str:
            # One search per attribute (the original scanned each twice).
            found = pattern.search(tag)
            return found.group(1) if found else default

        out = []
        for m in SPAWN_RE.finditer(text):
            raw = m.group(0)
            # Opening tag only, so body text can't leak into attr parsing.
            tag = raw[:raw.find(">") + 1]
            out.append({
                "role": attr(ROLE_RE, tag, "default"),
                "id": attr(ID_RE, tag, "anon"),
                "parallel": attr(PARALLEL_RE, tag, "false") == "true",
                "body": m.group(1).strip(),
                "raw_span": (m.start(), m.end()),
            })
        return out

    async def _dispatch(self, parent_text: str, depth: int) -> str:
        """Recursively expand <spawn> blocks until none remain or depth cap."""
        if depth >= self.max_depth:
            return parent_text
        spawns = self._extract_spawns(parent_text)
        if not spawns:
            return parent_text
        spawns = spawns[:self.max_fanout]
        # Parallel-tagged spawns run via gather; serial ones run in order.
        parallel_group = [s for s in spawns if s["parallel"]]
        serial_group   = [s for s in spawns if not s["parallel"]]

        results: dict[str, str] = {}
        if parallel_group:
            tasks = [self._run_worker(s, depth + 1) for s in parallel_group]
            outs = await asyncio.gather(*tasks, return_exceptions=True)
            for s, o in zip(parallel_group, outs):
                # A failed worker becomes an inline <error> marker rather
                # than aborting its siblings.
                results[s["id"]] = str(o) if not isinstance(o, Exception) else f"<error>{o}</error>"
        for s in serial_group:
            try:
                results[s["id"]] = await self._run_worker(s, depth + 1)
            except Exception as e:
                results[s["id"]] = f"<error>{e}</error>"

        # Replace each <spawn> block with <worker_result> in the text.
        new_text = parent_text
        for s in spawns:
            tag_text = parent_text[s["raw_span"][0]:s["raw_span"][1]]
            replacement = f'<worker_result id="{s["id"]}">{results.get(s["id"], "")}</worker_result>'
            # count=1 keeps duplicate spawn bodies mapped in document order.
            new_text = new_text.replace(tag_text, replacement, 1)
        return new_text

    async def _run_worker(self, spawn: dict, depth: int) -> str:
        """Dispatch one sub-agent: call the model with the role system prompt."""
        role_prompt = ROLE_SYSTEM_PROMPTS.get(spawn["role"], DEFAULT_SYSTEM)
        worker_out = await self._generate(spawn["body"], system=role_prompt,
                                           max_tokens=2048)
        # Recursive expansion if the worker itself emits <spawn>; depth was
        # already incremented by the caller.
        return await self._dispatch(worker_out, depth)

    async def run(self, prompt: str, max_depth: Optional[str] = None,
                  max_fanout: Optional[int] = None) -> str:
        """Entry: generate from root, then recursively dispatch any <spawn>.

        Passing max_depth/max_fanout overrides the instance limits for this
        and subsequent runs (matches the original's mutate-on-run behavior).
        """
        if max_depth is not None:
            self.max_depth = max_depth
        if max_fanout is not None:
            self.max_fanout = max_fanout
        root = await self._generate(prompt, system=DEFAULT_SYSTEM, max_tokens=4096)
        return await self._dispatch(root, depth=0)


# Role system prompts β€” the model is trained to recognize these via Anthropic
# 5-component XML template (research Β§v13-role-comprehensive)
DEFAULT_SYSTEM = (
    "You are Surrogate-1, a senior polymath engineer. When a task requires "
    "multiple roles, emit <spawn role=\"X\" id=\"N\" parallel=\"true\">…</spawn> "
    "tokens to dispatch sub-agents. Use <await/> + <aggregate>…</aggregate> "
    "to gather results. Hard limits: depth ≀ 3, fanout ≀ 8."
)
ROLE_SYSTEM_PROMPTS = {
    "PM":         "You are PM (Product Manager). Output PRD with JTBD/OKRs.",
    "PO":         "You are PO. Backlog grooming, sprint planning, acceptance criteria.",
    "BA":         "You are BA. BRD + process modeling + verifiable requirements.",
    "SA":         "You are SA. Multi-system design + ADRs + trade-off analysis.",
    "principal":  "You are Principal Engineer. Cross-cutting tech leadership.",
    "BE":         "You are Backend Engineer. Python/Go/Rust/Node API + data layer.",
    "FE":         "You are Frontend Engineer. React/Vue/Svelte + a11y + perf.",
    "mobile":     "You are Mobile Engineer. iOS/Android/RN/Flutter.",
    "data":       "You are Data Engineer. Pipelines + warehousing.",
    "ml":         "You are ML Engineer. Training + eval + MLOps.",
    "ai-eng":     "You are AI Engineer. RAG + agents + fine-tuning.",
    "sre":        "You are SRE. SLOs + oncall + postmortems + 5-Whys.",
    "devsecops":  "You are DevSecOps. CI/CD security + IaC scanning + supply chain.",
    "platform":   "You are Platform Engineer. IDP + golden paths.",
    "cloud":      "You are Cloud Engineer. AWS/GCP/Azure + cost-aware.",
    "o11y":       "You are Observability Engineer. PromQL/LogQL/TraceQL + SLOs.",
    "sec":        "You are Security Engineer. Threat modeling + AppSec + IR.",
    "qa":         "You are QA. Test strategy + manual + exploratory.",
    "sdet":       "You are SDET. Selenium/Playwright/Cypress + perf via k6.",
    "sec-test":   "You are Security Tester. OWASP + Burp + fuzzing.",
    "BD":         "You are BD. Partnership scouting + deal structuring.",
    "sales":      "You are Sales Engineer. Technical pitch + POC + ROI.",
    "CS":         "You are Customer Success. Onboarding + escalations + expansion.",
    "founder":    "You are Founder/CEO. Vision + fundraising + board.",
    "growth":     "You are Growth Engineer. A/B + funnels + attribution.",
    "seo":        "You are SEO/Content. Keyword research + technical SEO.",
    "brand":      "You are Brand. ICP + messaging + competitive positioning.",
    "PMM":        "You are Product Marketing Manager. Launch + positioning.",
    "PM-proj":    "You are Project Manager. Agile/Scrum/Kanban/SAFe ceremonies.",
    "techwriter": "You are Tech Writer. RFCs + ADRs + runbooks + postmortems.",
    "EM":         "You are Engineering Manager. 1:1s + perf review + hiring.",
}


if __name__ == "__main__":
    async def _main() -> None:
        """Manual smoke test: one end-to-end run against the live Space."""
        endpoint = os.environ.get(
            "SURROGATE_ENDPOINT",
            "https://surrogate1-surrogate-1-zero-gpu.hf.space",
        )
        runtime = MultiAgentRuntime(endpoint=endpoint)
        result = await runtime.run(
            prompt=(
                "Ship a feature that adds OAuth2 PKCE login to the Vanguard API. "
                "Spawn PM/SA/BE/SDET/DevSecOps as needed."
            ),
            max_depth=2,
            max_fanout=5,
        )
        print(result)

    asyncio.run(_main())