"""
Research subagent tool — spawns a cheap LLM call with a focused
research task and returns a summary. The subagent gets its own
independent context (not the main conversation), so research
work doesn't pollute the main agent's context window.

Inspired by claude-code's code-explorer agent pattern.
"""

import json
import logging
import os
from typing import Any

from litellm import Message, acompletion

from agent.core.doom_loop import check_for_doom_loop
from agent.core.session import Event

logger = logging.getLogger(__name__)

# Context budget for the research subagent (tokens).
# When usage exceeds WARN threshold, the subagent is told to wrap up.
# At MAX, the loop is force-stopped and whatever content exists is returned.
_RESEARCH_CONTEXT_WARN = 170_000  # 85% of 200k
_RESEARCH_CONTEXT_MAX = 190_000

# Tools the research agent can use (read-only subset)
RESEARCH_TOOL_NAMES = {
    "read",
    "bash",
    "explore_hf_docs",
    "fetch_hf_docs",
    "find_hf_api",
    "hf_papers",
    "github_find_examples",
    "github_list_repos",
    "github_read_file",
    "hf_inspect_dataset",
    "hf_repo_files",
}

RESEARCH_SYSTEM_PROMPT = """\
You are a research sub-agent for an ML engineering assistant.
Your job: explore documentation, code examples, APIs, and repos,
then return a concise, actionable summary. The main agent will use
your findings to implement the actual solution.

# Research methodology

1. **Discovery**: Find relevant entry points — example scripts, doc pages, API endpoints
2. **Tracing**: Follow the chain from entry point to implementation detail
3. **Analysis**: Identify patterns, current API usage, key dependencies
4. **Synthesis**: Summarize findings in a structured format

# How to use your tools

## GitHub code research (USE FIRST for any ML implementation task)
- `github_find_examples`: Find working example scripts in HF repos (trl, transformers, etc.)
  Example: `github_find_examples({"repo": "trl", "keyword": "sft"})`
  Returns: file paths in examples/, scripts/, notebooks/ directories
- `github_read_file`: Read the actual implementation code
  Example: `github_read_file({"repo": "huggingface/trl", "path": "examples/scripts/sft.py"})`
  Use line_start/line_end for large files

## Documentation
- `explore_hf_docs(endpoint)`: Search docs for a library. Endpoints: trl, transformers, datasets, peft, accelerate, trackio, vllm, inference-endpoints, etc.
- `fetch_hf_docs(url)`: Fetch full page content from explore results
- `find_hf_api(query=..., tag=...)`: Find REST API endpoints

## Dataset inspection
- `hf_inspect_dataset`: Check dataset schema, splits, sample rows
  CRITICAL for training: verify column format matches training method:
  - SFT: needs "messages", "text", or "prompt"/"completion"
  - DPO: needs "prompt", "chosen", "rejected"
  - GRPO: needs "prompt" only

## Papers
- `hf_papers`: Search papers, get details, find linked datasets/models

## Hub repo inspection
- `hf_repo_files`: List/read files in any HF repo (model, dataset, space)

# Correct research pattern for ML tasks

```
# 1. Find working example code FIRST
github_find_examples({"repo": "trl", "keyword": "sft"})

# 2. Read the implementation
github_read_file({"repo": "huggingface/trl", "path": "examples/scripts/sft.py"})

# 3. Check docs for parameters/config details
explore_hf_docs("trl")
fetch_hf_docs("https://huggingface.co/docs/trl/sft_trainer")

# 4. Validate dataset format if relevant
hf_inspect_dataset({"dataset": "org/name", "split": "train", "sample_rows": 3})
```

# Output format

Your output MUST include:
- **Key findings**: The most important things you discovered (current API usage, working patterns)
- **Essential references**: Specific file paths, URLs, function names, doc sections, code snippets
  that the main agent should use directly
- **Code patterns**: Key imports, configurations, and usage patterns from working examples
- **Recommendations**: What to do next based on your findings

Be concise. Your output goes into another agent's context — every token counts.
Aim for 500-1500 words max. Include actual code snippets from examples you read,
not paraphrased descriptions.
"""

RESEARCH_TOOL_SPEC = {
    "name": "research",
    "description": (
        "Spawn a research sub-agent to explore documentation, codebases, "
        "or repos WITHOUT polluting the main conversation context. "
        "The sub-agent gets its own independent context window with read-only "
        "research tools and returns a concise summary of findings.\n\n"
        "Use this for:\n"
        "- Researching current API usage before implementing ML tasks "
        "(find examples + read docs)\n"
        "- Exploring HF docs, reading papers, analyzing GitHub repos\n"
        "- Any research where raw tool outputs would be too verbose\n\n"
        "The sub-agent knows how to use github_find_examples, github_read_file, "
        "explore_hf_docs, fetch_hf_docs, hf_inspect_dataset, hf_papers, etc. "
        "Just describe what you need researched."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "task": {
                "type": "string",
                "description": (
                    "Detailed description of what to research. Be specific: "
                    "include library names, trainer types, dataset names, "
                    "repo names, or doc pages to explore. Example: "
                    "'Research current TRL SFTTrainer usage: find working "
                    "example scripts, read the SFT documentation, and check "
                    "SFTConfig parameters. Also validate that dataset "
                    "HuggingFaceH4/ultrachat_200k has the right format for SFT.'"
                ),
            },
            "context": {
                "type": "string",
                "description": (
                    "Optional context from the current conversation that the "
                    "research agent needs (e.g., what the user wants to build, "
                    "constraints, what's been tried)."
                ),
            },
        },
        "required": ["task"],
    },
}


def _resolve_llm_params(model_name: str) -> dict:
    """Build LiteLLM kwargs, reusing the HF router logic from agent_loop."""
    if not model_name.startswith("huggingface/"):
        return {"model": model_name}

    parts = model_name.split("/", 2)  # ["huggingface", "<provider>", "<org>/<model>"]
    if len(parts) < 3:
        return {"model": model_name}

    provider = parts[1]
    model_id = parts[2]
    return {
        "model": f"openai/{model_id}",
        "api_base": f"https://router.huggingface.co/{provider}/v3/openai",
        "api_key": os.environ.get("INFERENCE_TOKEN", ""),
    }

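# Illustrative mapping (hypothetical model names, kept as inert comments):
#
#   _resolve_llm_params("huggingface/together/meta-llama/Llama-3-8B")
#   -> {"model": "openai/meta-llama/Llama-3-8B",
#       "api_base": "https://router.huggingface.co/together/v3/openai",
#       "api_key": <value of INFERENCE_TOKEN, or "">}
#
#   _resolve_llm_params("gpt-4o")
#   -> {"model": "gpt-4o"}  # non-HF names pass through unchanged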

def _get_research_model(main_model: str) -> str:
    """Pick a cheaper model for research based on the main model."""
    if "anthropic/" in main_model:
        return "anthropic/claude-sonnet-4-6"
    # For non-Anthropic models (HF router etc.), use the same model
    return main_model


async def research_handler(
    arguments: dict[str, Any], session=None, **_kw
) -> tuple[str, bool]:
    """Execute a research sub-agent with its own context."""
    task = arguments.get("task", "")
    context = arguments.get("context", "")
    if not task:
        return "No research task provided.", False

    if not session:
        return "No session available for research agent.", False

    # Build the sub-agent's messages (independent context)
    messages: list[Message] = [
        Message(role="system", content=RESEARCH_SYSTEM_PROMPT),
    ]

    user_content = f"Research task: {task}"
    if context:
        user_content = f"Context: {context}\n\n{user_content}"
    messages.append(Message(role="user", content=user_content))

    # Use a cheaper/faster model for research
    main_model = session.config.model_name
    research_model = _get_research_model(main_model)
    llm_params = _resolve_llm_params(research_model)

    # Get read-only tool specs from the session's tool router
    tool_specs = [
        spec
        for spec in session.tool_router.get_tool_specs_for_llm()
        if spec["function"]["name"] in RESEARCH_TOOL_NAMES
    ]

    async def _log(text: str) -> None:
        """Send a progress event to the UI so it doesn't look frozen."""
        try:
            await session.send_event(
                Event(event_type="tool_log", data={"tool": "research", "log": text})
            )
        except Exception:
            # UI logging is best-effort; never let it break the research loop.
            pass

    _tool_uses = 0
    _total_tokens = 0
    _warned_context = False

    await _log("Starting research sub-agent...")

    # Run the research loop — context budget is the real limiter
    max_iterations = 60
    for _iteration in range(max_iterations):
        # ── Doom-loop detection ──
        doom_prompt = check_for_doom_loop(messages)
        if doom_prompt:
            logger.warning("Research sub-agent doom loop detected at iteration %d", _iteration)
            await _log("Doom loop detected — injecting corrective prompt")
            messages.append(Message(role="user", content=doom_prompt))

        # ── Context budget: warn at 85%, hard-stop at 95% ──
        if _total_tokens >= _RESEARCH_CONTEXT_MAX:
            logger.warning(
                "Research sub-agent hit context max (%d tokens) — forcing summary",
                _total_tokens,
            )
            await _log(f"Context limit reached ({_total_tokens} tokens) — forcing wrap-up")
            # Ask for a final summary with no tools
            messages.append(Message(
                role="user",
                content=(
                    "[SYSTEM: CONTEXT LIMIT REACHED] You have used all available context. "
                    "Summarize your findings NOW. Do NOT call any more tools."
                ),
            ))
            try:
                response = await acompletion(
                    messages=messages,
                    tools=None,  # no tools — force text response
                    stream=False,
                    timeout=120,
                    **llm_params,
                )
                content = response.choices[0].message.content or ""
                return content or "Research context exhausted — no summary produced.", bool(content)
            except Exception:
                return "Research context exhausted and summary call failed.", False

        if not _warned_context and _total_tokens >= _RESEARCH_CONTEXT_WARN:
            _warned_context = True
            await _log(f"Context at {_total_tokens} tokens — nudging to wrap up")
            messages.append(Message(
                role="user",
                content=(
                    "[SYSTEM: You have used 85% of your context budget. "
                    "Start wrapping up: finish any critical lookups, then "
                    "produce your final summary within the next 1-2 iterations.]"
                ),
            ))

        try:
            response = await acompletion(
                messages=messages,
                tools=tool_specs if tool_specs else None,
                tool_choice="auto",
                stream=False,
                timeout=120,
                **llm_params,
            )
        except Exception as e:
            logger.error("Research sub-agent LLM error: %s", e)
            return f"Research agent LLM error: {e}", False

        # Track tokens: usage.total_tokens covers the full (growing) prompt,
        # so it approximates the sub-agent's current context size.
        if response.usage:
            _total_tokens = response.usage.total_tokens
            await _log(f"tokens:{_total_tokens}")

        choice = response.choices[0]
        msg = choice.message

        # If no tool calls, we have our final answer
        if not msg.tool_calls:
            await _log("Research complete.")
            content = msg.content or "Research completed but no summary generated."
            return content, True

        # Execute tool calls and add results
        messages.append(msg)
        for tc in msg.tool_calls:
            try:
                tool_args = json.loads(tc.function.arguments)
            except (json.JSONDecodeError, TypeError):
                messages.append(
                    Message(
                        role="tool",
                        content="Invalid tool arguments.",
                        tool_call_id=tc.id,
                        name=tc.function.name,
                    )
                )
                continue

            tool_name = tc.function.name
            if tool_name not in RESEARCH_TOOL_NAMES:
                messages.append(
                    Message(
                        role="tool",
                        content=f"Tool '{tool_name}' not available for research.",
                        tool_call_id=tc.id,
                        name=tool_name,
                    )
                )
                continue

            try:
                args_str = json.dumps(tool_args)[:80]
                await _log(f"▸ {tool_name}  {args_str}")

                output, _success = await session.tool_router.call_tool(
                    tool_name, tool_args, session=session
                )
                _tool_uses += 1
                await _log(f"tools:{_tool_uses}")
                # Truncate tool output for the research context
                if len(output) > 8000:
                    output = output[:4800] + "\n...(truncated)...\n" + output[-3200:]
            except Exception as e:
                output = f"Tool error: {e}"

            messages.append(
                Message(
                    role="tool",
                    content=output,
                    tool_call_id=tc.id,
                    name=tool_name,
                )
            )

    # ── Iteration limit: try to salvage findings ──
    await _log("Iteration limit reached — extracting summary")
    messages.append(Message(
        role="user",
        content=(
            "[SYSTEM: ITERATION LIMIT] You have reached the maximum number of research "
            "iterations. Summarize ALL findings so far. Do NOT call any more tools."
        ),
    ))
    try:
        response = await acompletion(
            messages=messages,
            tools=None,
            stream=False,
            timeout=120,
            **llm_params,
        )
        content = response.choices[0].message.content or ""
        if content:
            return content, True
    except Exception as e:
        logger.error("Research summary call failed: %s", e)

    return (
        f"Research agent hit iteration limit ({max_iterations}). "
        "Partial findings may be incomplete — try a more focused task.",
        False,
    )
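
# Illustrative invocation (hypothetical `session` object; the handler is
# assumed to be registered against the "research" spec above by the tool
# router elsewhere in the codebase):
#
#   summary, ok = await research_handler(
#       {
#           "task": "Research current TRL SFTTrainer usage: find example "
#                   "scripts and check SFTConfig parameters.",
#           "context": "User wants to fine-tune a chat model.",
#       },
#       session=session,
#   )
#   # summary: concise findings for the main agent's context; ok: success flag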