diff --git a/.gitattributes b/.gitattributes index 5c1fa543a2dcf0e292a5151a6d696f7f59a1556b..24a8e87939aa53cdd833f6be7610cb4972e063ad 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,2 +1 @@ *.png filter=lfs diff=lfs merge=lfs -text -README.md merge=ours diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml deleted file mode 100644 index 5d79742fe97daa25e23740b7904a69439fd38368..0000000000000000000000000000000000000000 --- a/.github/workflows/ci.yml +++ /dev/null @@ -1,63 +0,0 @@ -name: CI - -on: - pull_request: - push: - branches: [main] - -permissions: - contents: read - -concurrency: - group: ci-${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -jobs: - ruff: - name: Ruff - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - - name: Install uv - uses: astral-sh/setup-uv@v5 - with: - enable-cache: true - cache-dependency-glob: uv.lock - - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: "3.12" - - - name: Install dependencies - run: uv sync --locked --extra dev - - - name: Run Ruff - run: uv run ruff check . - - - name: Check formatting - run: uv run ruff format --check . - - tests: - name: Tests - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - - name: Install uv - uses: astral-sh/setup-uv@v5 - with: - enable-cache: true - cache-dependency-glob: uv.lock - - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: "3.12" - - - name: Install dependencies - run: uv sync --locked --extra dev - - - name: Run tests - run: uv run pytest diff --git a/.github/workflows/claude-review.yml b/.github/workflows/claude-review.yml deleted file mode 100644 index 1304cfb9cf5efb059ae02ed071ef2030390802bf..0000000000000000000000000000000000000000 --- a/.github/workflows/claude-review.yml +++ /dev/null @@ -1,78 +0,0 @@ -name: Claude PR Review - -on: - pull_request_target: - types: [opened, synchronize, ready_for_review, reopened] - -permissions: - contents: read - pull-requests: write - issues: read - id-token: write - -concurrency: - group: claude-review-${{ github.event.pull_request.number }} - cancel-in-progress: true - -jobs: - review: - if: github.event.pull_request.draft == false - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - with: - fetch-depth: 0 - # On pull_request_target, keep checkout on the trusted base-repo ref. - # The Claude action can review the PR via GitHub context/API without - # executing untrusted fork code with repository secrets. - persist-credentials: false - - - name: Compose review prompt - id: compose - run: | - { - printf 'prompt<> "$GITHUB_OUTPUT" - - - name: Prepare Claude Code bin directory - run: mkdir -p "$HOME/.local/bin" - - - uses: anthropics/claude-code-action@v1 - with: - anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} - # Bypass the OIDC -> Claude GitHub App token exchange. That exchange - # rejects OIDC tokens minted for pull_request_target events with - # "401 Invalid OIDC token", which broke every review after the switch - # away from pull_request. Using the workflow's GITHUB_TOKEN works for - # both same-repo and fork PRs; comments post as github-actions[bot] - # instead of claude[bot], which is the documented trade-off. 
- github_token: ${{ secrets.GITHUB_TOKEN }} - track_progress: true - prompt: ${{ steps.compose.outputs.prompt }} diff --git a/.github/workflows/claude.yml b/.github/workflows/claude.yml deleted file mode 100644 index d3036a23259e41c48a7efe2aa72a2d7c3c77bebf..0000000000000000000000000000000000000000 --- a/.github/workflows/claude.yml +++ /dev/null @@ -1,35 +0,0 @@ -name: Claude on Mention - -on: - issue_comment: - types: [created] - pull_request_review_comment: - types: [created] - pull_request_review: - types: [submitted] - issues: - types: [opened, assigned] - -permissions: - contents: write - pull-requests: write - issues: write - id-token: write - -jobs: - claude: - if: | - (github.event_name == 'issue_comment' && contains(github.event.comment.body, '@claude')) || - (github.event_name == 'pull_request_review_comment' && contains(github.event.comment.body, '@claude')) || - (github.event_name == 'pull_request_review' && contains(github.event.review.body, '@claude')) || - (github.event_name == 'issues' && (contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude'))) - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - uses: anthropics/claude-code-action@v1 - with: - anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} - track_progress: true diff --git a/.gitignore b/.gitignore index c10ab3552f66fe78ba9248272546e98a050789be..71fc3082173c89ba24599c3429094701bd0987c2 100644 --- a/.gitignore +++ b/.gitignore @@ -52,11 +52,7 @@ frontend/yarn-error.log* # Docker .docker/ -# Eval (stale) -eval/ - # Project-specific -scratch/ session_logs/ /logs hf-agent-leaderboard/ diff --git a/AGENTS.md b/AGENTS.md deleted file mode 100644 index 03f3bd9d98fee110961befde5d4ac148421f4b28..0000000000000000000000000000000000000000 --- a/AGENTS.md +++ /dev/null @@ -1,47 +0,0 @@ -# Agent Notes - -## Local Dev Servers - -- Frontend: from `frontend/`, run `npm ci` if dependencies are missing, then `npm run dev`. -- Backend: from `backend/`, run `uv run uvicorn main:app --host ::1 --port 7860`. -- Frontend URL: http://localhost:5173/ -- Backend health check: `curl -g http://[::1]:7860/api` -- Frontend proxy health check: `curl http://localhost:5173/api` - -Notes: - -- Vite proxies `/api` and `/auth` to `http://localhost:7860`. -- If `127.0.0.1:7860` is already owned by another local process, binding the backend to `::1` lets the Vite proxy resolve `localhost` cleanly. -- Prefer `npm ci` over `npm install` for setup, since `npm install` may rewrite `frontend/package-lock.json` metadata depending on npm version. -- Production defaults to the Bedrock Claude model. For local development with a personal Anthropic key, set `ANTHROPIC_API_KEY` and `ML_INTERN_CLAUDE_MODEL_ID=anthropic/claude-opus-4-6` before starting the backend. Other models are selected through the app's model switcher. - -## Development Checks - -- Before every commit, run `uv run ruff check .` and `uv run ruff format --check .`. -- If formatting fails, run `uv run ruff format .`, then re-run the Ruff checks before committing. - -## GitHub CLI - -- For multiline PR descriptions, prefer `gh pr edit --body-file ` over inline `--body` so shell quoting, `$` env-var names, backticks, and newlines are preserved correctly. - -## GitHub PRs - -- Open code changes as GitHub PRs first. Do not push code changes directly to the Hugging Face Space deployment branch or Space remote before the PR has been opened, reviewed, and merged, unless the user explicitly asks to bypass the PR flow. 
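To make the GitHub CLI guidance above concrete, here is a minimal sketch of the body-file flow it recommends; the branch name, PR title, and `/tmp/pr-body.md` path are illustrative, while `gh pr create --body-file` and `gh pr edit --body-file` are standard GitHub CLI flags:

```bash
# Write the multiline description to a file first; a quoted heredoc
# delimiter ('EOF') keeps $VARS, backticks, and newlines intact.
cat > /tmp/pr-body.md <<'EOF'
## Summary
Fixes the `$HF_TOKEN` prompt on first launch.
EOF

# Open the PR from a feature branch (names here are illustrative).
git switch -c fix/hf-token-prompt
git push -u origin fix/hf-token-prompt
gh pr create --title "Fix HF_TOKEN prompt" --body-file /tmp/pr-body.md

# Revise the description later without re-escaping anything.
gh pr edit --body-file /tmp/pr-body.md
```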
- -## Hugging Face Space Deploys - -- The Space remote is `space` and points to `https://huggingface.co/spaces/smolagents/ml-intern`. -- Deploy GitHub `main` to the Space from the local `space-main` branch by merging `origin/main` into `space-main` with a single merge commit, then pushing `space-main:main` to the `space` remote. -- Keep the Space-only README frontmatter on `space-main`; `.gitattributes` should contain `README.md merge=ours` and the local repo config should include `merge.ours.driver=true`. -- Local dev commonly uses a personal `HF_TOKEN`, but the deployed Space uses HF OAuth tokens. When adding Hub features, make sure the Space README `hf_oauth_scopes` frontmatter and the backend OAuth request in `backend/routes/auth.py` include the scopes required by the Hub APIs being called. A feature can work locally with a broad PAT and still fail in production with 403s if OAuth scopes are missing; after changing scopes, users may need to log out and log in again to receive a fresh token. -- Recommended deploy flow: - -```bash -git pull --ff-only origin main -git switch space-main -git config merge.ours.driver true -git merge --no-ff origin/main -m "Deploy $(date +%Y-%m-%d)" \ - -m "Co-authored-by: OpenAI Codex " -git push space space-main:main -git switch main -``` diff --git a/Dockerfile b/Dockerfile index 264dd3d9f97d6d2353e96611a6af7c90680f17b1..4faee687f7796beee985dc9fcdd3853199fffac7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -28,7 +28,7 @@ COPY pyproject.toml uv.lock ./ # Install dependencies into /app/.venv # Use --frozen to ensure exact versions from uv.lock -RUN uv sync --no-dev --frozen +RUN uv sync --extra agent --no-dev --frozen # Copy application code COPY agent/ ./agent/ diff --git a/LICENSE b/LICENSE deleted file mode 100644 index 261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64..0000000000000000000000000000000000000000 --- a/LICENSE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. 
- - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. 
You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. 
In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/README.md b/README.md index 26b65da33edb1b3d882cdf9a6896a03030213195..fed2a689ba76144abbf89ce49d0b74a973325062 100644 --- a/README.md +++ b/README.md @@ -1,164 +1,57 @@ --- -title: ML Intern +title: HF Agent emoji: πŸ€– -colorFrom: yellow -colorTo: blue +colorFrom: blue +colorTo: purple sdk: docker app_port: 7860 hf_oauth: true -hf_oauth_expiration_minutes: 43200 hf_oauth_scopes: - read-repos - write-repos - contribute-repos - manage-repos - - write-collections - inference-api - jobs - write-discussions --- -
- smolagents logo
+# HF Agent -# ML Intern +An MLE agent CLI with MCP (Model Context Protocol) integration and built-in tool support. -An ML intern that autonomously researches, writes, and ships good quality ML related code using the Hugging Face ecosystem β€” with deep access to docs, papers, datasets, and cloud compute. ## Quick Start ### Installation ```bash -git clone git@github.com:huggingface/ml-intern.git -cd ml-intern -uv sync -uv tool install -e . +# Clone the repository +git clone git@github.com:huggingface/hf_agent.git +cd hf_agent ``` -#### That's it. Now `ml-intern` works from any directory: - -```bash -ml-intern -``` - -Create a `.env` file in the project root (or export these in your shell): - -```bash -ANTHROPIC_API_KEY= # if using anthropic models -OPENAI_API_KEY= # if using openai models -HF_TOKEN= -GITHUB_TOKEN= -``` -If no `HF_TOKEN` is set, the CLI will prompt you to paste one on first launch. To get a GITHUB_TOKEN follow the tutorial [here](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens#creating-a-fine-grained-personal-access-token). - -### Usage - -**Interactive mode** (start a chat session): - +#### Install recommended dependencies ```bash -ml-intern +uv sync --extra agent # or uv sync --extra all ``` -**Headless mode** (single prompt, auto-approve): +### Interactive CLI ```bash -ml-intern "fine-tune llama on my dataset" -``` - -**Options:** - -```bash -ml-intern --model anthropic/claude-opus-4-6 "your prompt" -ml-intern --model openai/gpt-5.5 "your prompt" -ml-intern --max-iterations 100 "your prompt" -ml-intern --no-stream "your prompt" -``` - -## Sharing Traces - -Every session is auto-uploaded to your **own private Hugging Face dataset** -in [Claude Code JSONL format](https://huggingface.co/changelog/agent-trace-viewer), -which the HF Agent Trace Viewer auto-detects so you can browse turns, tool -calls, and model responses directly on the Hub. - -By default the dataset is named `{your-hf-username}/ml-intern-sessions` and is -**created private**. You can flip it to public from inside the CLI: - -```bash -/share-traces # show current visibility + dataset URL -/share-traces public # publish (anyone can view) -/share-traces private # lock it back down -``` - -You can also flip visibility from the dataset page on huggingface.co β€” the -agent honours whatever you set there for subsequent uploads. - -To opt out entirely, set in your CLI config (e.g. `configs/cli_agent_config.json` -or `~/.config/ml-intern/cli_agent_config.json`): - -```json -{ "share_traces": false } -``` - -To override the destination repo, set: - -```json -{ "personal_trace_repo_template": "{hf_user}/my-custom-traces" } +uv run python -m agent.main ``` +This starts an interactive chat session with the agent. Type your messages and the agent will respond, using tools as needed. -The shared `smolagents/ml-intern-sessions` dataset is unrelated and only -receives anonymized telemetry rows used by the backend KPI scheduler. +The agent will automatically discover and register all tools from configured MCP servers. -## Supported Gateways - -ML Intern currently supports one-way notification gateways from CLI sessions. -These gateways send out-of-band status updates; they do not accept inbound chat -messages. - -### Slack - -Slack notifications use the Slack Web API to post messages when the agent needs -approval, hits an error, or completes a turn. 
Create a Slack app with a bot token -that has `chat:write`, invite the bot to the target channel, then set: +### Env Setup ```bash -SLACK_BOT_TOKEN=xoxb-... -SLACK_CHANNEL_ID=C... -``` - -The CLI automatically creates a `slack.default` destination when both variables -are present. Optional environment variables for the env-only default: - -```bash -ML_INTERN_SLACK_NOTIFICATIONS=false -ML_INTERN_SLACK_DESTINATION=slack.ops -ML_INTERN_SLACK_AUTO_EVENTS=approval_required,error,turn_complete -ML_INTERN_SLACK_ALLOW_AGENT_TOOL=true -ML_INTERN_SLACK_ALLOW_AUTO_EVENTS=true -``` - -For a persistent user-level config, put overrides in -`~/.config/ml-intern/cli_agent_config.json` or point `ML_INTERN_CLI_CONFIG` at a -JSON file: - -```json -{ - "messaging": { - "enabled": true, - "auto_event_types": ["approval_required", "error", "turn_complete"], - "destinations": { - "slack.ops": { - "provider": "slack", - "token": "${SLACK_BOT_TOKEN}", - "channel": "${SLACK_CHANNEL_ID}", - "allow_agent_tool": true, - "allow_auto_events": true - } - } - } -} +ANTHROPIC_API_KEY= +HF_TOKEN= +GITHUB_TOKEN= +HF_NAMESPACE= ``` ## Architecture @@ -167,70 +60,62 @@ JSON file: ``` β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” -β”‚ User/CLI β”‚ -β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ - β”‚ Operations β”‚ Events - ↓ (user_input, exec_approval, ↑ - submission_queue interrupt, compact, ...) event_queue - β”‚ β”‚ - ↓ β”‚ -β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ -β”‚ submission_loop (agent_loop.py) β”‚ β”‚ -β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ -β”‚ β”‚ 1. Receive Operation from queue β”‚ β”‚ β”‚ -β”‚ β”‚ 2. Route to handler (run_agent/compact/...) 
β”‚ β”‚ β”‚ -β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ -β”‚ ↓ β”‚ β”‚ -β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ -β”‚ β”‚ Handlers.run_agent() β”‚ β”œβ”€β”€β”€ -β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ Agentic Loop (max 300 iterations) β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ Session β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ ContextManager β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ β€’ Message history β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ (litellm.Message[]) β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ β€’ Auto-compaction (170k) β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ β€’ Session upload to HF β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ ToolRouter β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€ HF docs & research β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€ HF repos, datasets, β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ jobs, papers β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€ GitHub code search β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€ Sandbox & local tools β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€ Planning β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ └─ MCP server tools β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ Doom Loop Detector β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β€’ Detects repeated tool patterns β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β€’ Injects corrective prompts β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ Loop: β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ 1. LLM call (litellm.acompletion) β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ ↓ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ 2. Parse tool_calls[] β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ ↓ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ 3. Approval check β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ (jobs, sandbox, destructive ops) β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ ↓ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ 4. Execute via ToolRouter β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ ↓ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ 5. Add results to ContextManager β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ ↓ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ 6. 
Repeat if tool_calls exist β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ β”‚ -β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ -β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”˜ +β”‚ User/CLI β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + β”‚ User request β”‚ Events + ↓ ↑ + submission_queue event_queue + β”‚ β”‚ + ↓ β”‚ +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ +β”‚ submission_loop (agent_loop.py) β”‚ β”‚ +β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ +β”‚ β”‚ 1. Receive Operation from queue β”‚ β”‚ β”‚ +β”‚ β”‚ 2. Route to Handler (run_agent/compact/...) β”‚ β”‚ β”‚ +β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ +β”‚ ↓ β”‚ β”‚ +β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ +β”‚ β”‚ Handlers.run_agent() β”‚ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ +β”‚ β”‚ β”‚ β”‚ Emit β”‚ +β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ Events β”‚ +β”‚ β”‚ β”‚ Agentic Loop (max 10 iterations) β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ Session β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ ContextManager β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ β€’ Message history β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ (litellm.Message[]) β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ β€’ Auto-compaction (180k) β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ ToolRouter β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€ explore_hf_docs β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€ fetch_hf_docs β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€ find_hf_api β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€ plan_tool β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€ hf_jobs* β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€ hf_private_repos* β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€ github_* (3 tools) β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ └─ MCP tools (e.g., β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ model_search, etc.) 
β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ Loop: β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ 1. LLM call (litellm.acompletion) β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ ↓ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ 2. Parse tool_calls[] β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ ↓ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ 3. Execute via ToolRouter β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ ↓ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ 4. Add results to ContextManager β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ ↓ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ 5. Repeat if tool_calls exist β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ β”‚ +β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ ``` ### Agentic Loop Flow @@ -240,49 +125,61 @@ User Message ↓ [Add to ContextManager] ↓ - ╔═══════════════════════════════════════════╗ - β•‘ Iteration Loop (max 300) β•‘ - β•‘ β•‘ - β•‘ Get messages + tool specs β•‘ - β•‘ ↓ β•‘ - β•‘ litellm.acompletion() β•‘ - β•‘ ↓ β•‘ - β•‘ Has tool_calls? ──No──> Done β•‘ - β•‘ β”‚ β•‘ - β•‘ Yes β•‘ - β•‘ ↓ β•‘ - β•‘ Add assistant msg (with tool_calls) β•‘ - β•‘ ↓ β•‘ - β•‘ Doom loop check β•‘ - β•‘ ↓ β•‘ - β•‘ For each tool_call: β•‘ - β•‘ β€’ Needs approval? ──Yes──> Wait for β•‘ - β•‘ β”‚ user confirm β•‘ - β•‘ No β•‘ - β•‘ ↓ β•‘ - β•‘ β€’ ToolRouter.execute_tool() β•‘ - β•‘ β€’ Add result to ContextManager β•‘ - β•‘ ↓ β•‘ - β•‘ Continue loop ─────────────────┐ β•‘ - β•‘ ↑ β”‚ β•‘ - β•‘ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β•‘ - β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β• + ╔═══════════════════════════════════════╗ + β•‘ Iteration Loop (max 10) β•‘ + β•‘ β•‘ + β•‘ Get messages + tool specs β•‘ + β•‘ ↓ β•‘ + β•‘ litellm.acompletion() β•‘ + β•‘ ↓ β•‘ + β•‘ Has tool_calls? 
──No──> Done β•‘ + β•‘ β”‚ β•‘ + β•‘ Yes β•‘ + β•‘ ↓ β•‘ + β•‘ Add assistant msg (with tool_calls) β•‘ + β•‘ ↓ β•‘ + β•‘ For each tool_call: β•‘ + β•‘ β€’ ToolRouter.execute_tool() β•‘ + β•‘ β€’ Add result to ContextManager β•‘ + β•‘ ↓ β•‘ + β•‘ Continue loop ─────────────────┐ β•‘ + β•‘ ↑ β”‚ β•‘ + β•šβ•β•β•β•β•β•β•β•β•β•§β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•§β•β•β•β•β•β• +``` + +## Project Structure + +``` +agent/ +β”œβ”€β”€ config.py # Configuration models +β”œβ”€β”€ main.py # Interactive CLI entry point +β”œβ”€β”€ prompts/ +β”‚ └── system_prompt.yaml # Agent behavior and personality +β”œβ”€β”€ context_manager/ +β”‚ └── manager.py # Message history & auto-compaction +└── core/ + β”œβ”€β”€ agent_loop.py # Main agent loop and handlers + β”œβ”€β”€ session.py # Session management + β”œβ”€β”€ mcp_client.py # MCP SDK integration + └── tools.py # ToolRouter and built-in tools + +configs/ +└── main_agent_config.json # Model and MCP server configuration + +tests/ # Integration and unit tests +eval/ # Evaluation suite (see eval/README.md) ``` + ## Events The agent emits the following events via `event_queue`: - `processing` - Starting to process user input -- `ready` - Agent is ready for input -- `assistant_chunk` - Streaming token chunk -- `assistant_message` - Complete LLM response text -- `assistant_stream_end` - Token stream finished +- `assistant_message` - LLM response text - `tool_call` - Tool being called with arguments - `tool_output` - Tool execution result -- `tool_log` - Informational tool log message -- `tool_state_change` - Tool execution state transition -- `approval_required` - Requesting user approval for sensitive operations +- `approval_request` - Requesting user approval for sensitive operations - `turn_complete` - Agent finished processing - `error` - Error occurred during processing - `interrupted` - Agent was interrupted @@ -317,8 +214,7 @@ def create_builtin_tools() -> list[ToolSpec]: ### Adding MCP Servers -Edit `configs/cli_agent_config.json` for CLI defaults, or -`configs/frontend_agent_config.json` for web-session defaults: +Edit `configs/main_agent_config.json`: ```json { diff --git a/REVIEW.md b/REVIEW.md deleted file mode 100644 index 3f08c60a8a43022497a58e3d547433bd5ecb5c24..0000000000000000000000000000000000000000 --- a/REVIEW.md +++ /dev/null @@ -1,135 +0,0 @@ -# Review instructions - -These rules override the default review guidance. Treat them as the highest-priority -instruction block for any review of this repo. If something here contradicts a more -generic review habit, follow these. - -## Severity levels - -Every finding carries one of three priority labels: - -- **P0** β€” blocks merge. -- **P1** β€” worth fixing, not blocking. -- **P2** β€” informational. - -Write labels as plain text (`P0`, `P1`, `P2`) in finding headers. Do not use -emoji or colored markers. Use judgment on what belongs at which level β€” this -repo does not enumerate P0 cases; read the code and decide. - -## Default bias: rigor - -Reviews gate merges. This is an open-source repo that takes PRs from anyone; the -maintainer team is small and relies on the review to catch what they don't have -time to verify themselves. **Default bias is rigor, not speed.** When in doubt -on a P0-class concern, investigate further before deciding whether to flag β€” a -false negative ships a bug to production, a false positive costs the contributor -one round trip. - -Rigor is not nitpicking. 
The P1 cap, "do not report" skip list, and verification -bar all still apply. Rigor means going deep on a small number of real concerns, -not surfacing a large number of shallow ones. Prefer one well-investigated P0 -over three speculative P1s. - -**Hold the line on P0.** If the author pushes back on a P0 finding without a fix -that actually addresses the root cause, re-state the concern with added -citations. Only accept the pushback if the author points to code or behavior you -missed. Do not soften a P0 because the contributor is polite or new to the repo. - -For P1 and P2: if the author defers or pushes back without fixing, accept it -silently — do not re-flag on subsequent commits. P1/P2 are informational; the -author may defer to a follow-up issue at their discretion. - -If Claude and the author repeatedly disagree on the same class of finding, the -signal is that REVIEW.md is missing a rule; note it once in the PR summary as -`suggest-rule: <one-line rule>` and stop. - -## Investigate before posting - -The depth of your analysis determines the strength of your finding. For any -P0-class concern, before writing it up: - -- Read the relevant callers and callees, not just the diff. Use Read and Grep - to open files the diff doesn't touch but the changed code interacts with. -- Trace the full chain end-to-end for routing, auth, and agent-loop findings. - Cite each hop by `file:line`, not just the suspicious line. -- Check whether the codebase already has an established pattern for this kind - of change (`grep` for similar call sites, similar tool definitions, similar - route guards). If the PR introduces a new approach where an established - pattern exists, flag that — divergence from the existing pattern is usually a - regression vector even when the new code "works." -- Confirm the specific behavior you're claiming. "This breaks X" must be - grounded in either the code handling X or a test exercising X, not in - inference from naming or structure. - -A finding you "spotted" by scanning the diff is more likely to be a false -positive than a finding you verified by reading the code around it. - -## P1 cap - -Report at most **3** P1 findings per review. If you found more, say "plus N -similar items" in the summary. If everything you found is P1 or below, open the -summary with "No blocking issues." - -## Re-review convergence - -If this PR has already received a Claude review (there is a prior review comment -by the `claude` bot), suppress new P1 findings and post only P0 ones. Do not -re-post P1s that were already flagged on earlier commits. If the author pushed a -fix for a previously flagged issue, acknowledge it in one line rather than -re-flagging. - -## Do not report - -Anything in these paths — skip entirely: - -- `frontend/node_modules/**`, `**/*.lock`, `uv.lock`, `package-lock.json` -- `hf_agent.egg-info/**`, `.ruff_cache/**`, `.pytest_cache/**`, `.venv/**` -- `session_logs/**`, `reports/**` -- Anything under a `gen/` or `generated/` path - -Anything speculative — do not post: - -- "This might be slow" without a concrete complexity claim tied to a specific - input size -- Hypothetical race conditions without a concrete interleaving - -## Dependency PRs - -For PRs whose diff is only a lockfile bump, a `pyproject.toml` change, or a -new dependency, the code rules above don't apply — risks shift to provenance -and framing. Every claim in the title or body (CVE IDs, version numbers, -behavior fixes) must match what the diff actually does, and any new -transitive dep needs justification.
A PR that lies in its framing is P0 -regardless of whether the code change is safe in isolation. - -## Verification bar - -Every behavior claim in a finding must cite `file:line`. "This breaks X" is not -actionable without a line reference. If you cannot cite a line, do not post -the finding. - -## Summary shape - -Open the review body with a single-line tally and an explicit merge verdict, on -two lines: - -``` -2 P0, 3 P1 -Verdict: changes requested -``` - -Valid verdicts: - -- **Verdict: ready to merge** β€” no P0 findings, contributor can merge as-is - once any CI passes -- **Verdict: changes requested** β€” at least one P0 that must be addressed - before merging -- **Verdict: needs discussion** β€” a design-level concern the maintainer should - weigh in on before the contributor iterates (use sparingly) - -If it's a clean review, write `LGTM` followed by `Verdict: ready to merge`. - -Then a **What I checked** bullet list β€” one line per major area you examined, -regardless of whether you found anything. This gives the maintainer visible -coverage at a glance and lets them decide whether to spot-check areas you -didn't touch. diff --git a/agent/__init__.py b/agent/__init__.py index 2e301c8d7b97df90efb932a3685a5c401326232e..3528882f8728ddce586748e8256755fe5b2ea6ad 100644 --- a/agent/__init__.py +++ b/agent/__init__.py @@ -2,20 +2,6 @@ HF Agent - Main agent module """ -import litellm - -# Global LiteLLM behavior β€” set once at package import so both CLI and -# backend entries share the same config. -# drop_params: quietly drop unsupported params rather than raising -# suppress_debug_info: hide the noisy "Give Feedback" banner on errors -# modify_params: let LiteLLM patch Anthropic's tool-call requirements -# (synthesize a dummy tool spec when we call completion on a history -# that contains tool_calls but aren't passing `tools=` β€” happens -# during summarization / session seeding). -litellm.drop_params = True -litellm.suppress_debug_info = True -litellm.modify_params = True - -from agent.core.agent_loop import submission_loop # noqa: E402 +from agent.core.agent_loop import submission_loop __all__ = ["submission_loop"] diff --git a/agent/config.py b/agent/config.py index 35b095c328fe64b53eb51ef5126ebec7e6f546e4..0682c17b47b4d4755e46ab4e239df16a4bd388fb 100644 --- a/agent/config.py +++ b/agent/config.py @@ -1,7 +1,6 @@ import json import os import re -from pathlib import Path from typing import Any, Union from dotenv import load_dotenv @@ -11,14 +10,9 @@ from fastmcp.mcp_config import ( ) from pydantic import BaseModel -from agent.messaging.models import MessagingConfig - # These two are the canonical server config types for MCP servers. MCPServerConfig = Union[StdioMCPServer, RemoteMCPServer] -# Project root: two levels up from this file (agent/config.py -> project root) -_PROJECT_ROOT = Path(__file__).resolve().parent.parent - class Config(BaseModel): """Configuration manager""" @@ -26,20 +20,8 @@ class Config(BaseModel): model_name: str mcpServers: dict[str, MCPServerConfig] = {} save_sessions: bool = True - session_dataset_repo: str = "smolagents/ml-intern-sessions" - # Per-user private dataset that mirrors each session in Claude Code JSONL - # format so the HF Agent Trace Viewer auto-renders it - # (https://huggingface.co/changelog/agent-trace-viewer). Created private - # on first use; user flips it public via /share-traces. ``{hf_user}`` is - # substituted at upload time from the authenticated HF username. 
- share_traces: bool = True - personal_trace_repo_template: str = "{hf_user}/ml-intern-sessions" - auto_save_interval: int = 1 # Save every N user turns (0 = disabled) - # Mid-turn heartbeat: save + upload every N seconds while events are being - # emitted. Guards against losing trace data on long-running turns that - # crash before turn_complete (e.g. a multi-hour hf_jobs wait that OOMs). - # 0 = disabled. Consumed by agent.core.telemetry.HeartbeatSaver. - heartbeat_interval_s: int = 60 + session_dataset_repo: str = "akseljoonas/hf-agent-sessions" + auto_save_interval: int = 3 # Save every N user turns (0 = disabled) yolo_mode: bool = False # Auto-approve all tool calls without confirmation max_iterations: int = 300 # Max LLM calls per agent turn (-1 = unlimited) @@ -47,118 +29,6 @@ class Config(BaseModel): confirm_cpu_jobs: bool = True auto_file_upload: bool = False - # Reasoning effort *preference* β€” the ceiling the user wants. The probe - # on `/model` walks a cascade down from here (``max`` β†’ ``xhigh`` β†’ ``high`` - # β†’ …) and caches per-model what the provider actually accepted in - # ``Session.model_effective_effort``. Default ``max`` because we'd rather - # burn tokens thinking than ship a wrong ML recipe; the cascade lands on - # whichever level the model supports (``high`` for GPT-5 / HF router, - # ``xhigh`` or ``max`` for Anthropic 4.6 / 4.7). ``None`` = thinking off. - # Valid values: None | "minimal" | "low" | "medium" | "high" | "xhigh" | "max" - reasoning_effort: str | None = "max" - messaging: MessagingConfig = MessagingConfig() - - -USER_CONFIG_ENV_VAR = "ML_INTERN_CLI_CONFIG" -DEFAULT_USER_CONFIG_PATH = ( - Path.home() / ".config" / "ml-intern" / "cli_agent_config.json" -) -SLACK_DEFAULT_DESTINATION = "slack.default" -SLACK_DEFAULT_AUTO_EVENT_TYPES = ["approval_required", "error", "turn_complete"] - - -def _deep_merge_config( - base: dict[str, Any], override: dict[str, Any] -) -> dict[str, Any]: - merged = dict(base) - for key, value in override.items(): - current = merged.get(key) - if isinstance(current, dict) and isinstance(value, dict): - merged[key] = _deep_merge_config(current, value) - else: - merged[key] = value - return merged - - -def _load_json_config(path: Path) -> dict[str, Any]: - with open(path, "r", encoding="utf-8") as f: - data = json.load(f) - if not isinstance(data, dict): - raise ValueError(f"Config file {path} must contain a JSON object") - return data - - -def _load_user_config() -> dict[str, Any]: - raw_path = os.environ.get(USER_CONFIG_ENV_VAR) - if raw_path: - path = Path(raw_path).expanduser() - if not path.exists(): - raise FileNotFoundError( - f"{USER_CONFIG_ENV_VAR} points to missing config file: {path}" - ) - return _load_json_config(path) - - if DEFAULT_USER_CONFIG_PATH.exists(): - return _load_json_config(DEFAULT_USER_CONFIG_PATH) - return {} - - -def _env_bool(name: str, default: bool) -> bool: - value = os.environ.get(name) - if value is None: - return default - normalized = value.strip().lower() - if normalized in {"1", "true", "yes", "on"}: - return True - if normalized in {"0", "false", "no", "off"}: - return False - return default - - -def _env_list(name: str) -> list[str] | None: - value = os.environ.get(name) - if value is None: - return None - return [item.strip() for item in value.split(",") if item.strip()] - - -def apply_slack_user_defaults(raw_config: dict[str, Any]) -> dict[str, Any]: - """Enable a default Slack destination from user env vars, when present.""" - if not _env_bool("ML_INTERN_SLACK_NOTIFICATIONS", True): 
- return raw_config - - token = os.environ.get("SLACK_BOT_TOKEN") - channel = os.environ.get("SLACK_CHANNEL_ID") or os.environ.get("SLACK_CHANNEL") - if not token or not channel: - return raw_config - - config = dict(raw_config) - messaging = dict(config.get("messaging") or {}) - destinations = dict(messaging.get("destinations") or {}) - destination_name = ( - os.environ.get("ML_INTERN_SLACK_DESTINATION") or SLACK_DEFAULT_DESTINATION - ).strip() - - if destination_name not in destinations: - destinations[destination_name] = { - "provider": "slack", - "token": token, - "channel": channel, - "allow_agent_tool": _env_bool("ML_INTERN_SLACK_ALLOW_AGENT_TOOL", True), - "allow_auto_events": _env_bool("ML_INTERN_SLACK_ALLOW_AUTO_EVENTS", True), - } - - auto_events = _env_list("ML_INTERN_SLACK_AUTO_EVENTS") - if auto_events is not None: - messaging["auto_event_types"] = auto_events - elif "auto_event_types" not in messaging: - messaging["auto_event_types"] = SLACK_DEFAULT_AUTO_EVENT_TYPES - - messaging["enabled"] = True - messaging["destinations"] = destinations - config["messaging"] = messaging - return config - def substitute_env_vars(obj: Any) -> Any: """ @@ -197,25 +67,18 @@ def substitute_env_vars(obj: Any) -> Any: return obj -def load_config( - config_path: str = "config.json", - include_user_defaults: bool = False, -) -> Config: +def load_config(config_path: str = "config.json") -> Config: """ Load configuration with environment variable substitution. Use ${VAR_NAME} in your JSON for any secret. Automatically loads from .env file. """ - # Load .env from project root first (so it works from any directory), - # then CWD .env can override if present - load_dotenv(_PROJECT_ROOT / ".env") - load_dotenv(override=False) - - raw_config = _load_json_config(Path(config_path)) - if include_user_defaults: - raw_config = _deep_merge_config(raw_config, _load_user_config()) - raw_config = apply_slack_user_defaults(raw_config) + # Load environment variables from .env file + load_dotenv() + + with open(config_path, "r") as f: + raw_config = json.load(f) config_with_env = substitute_env_vars(raw_config) return Config.model_validate(config_with_env) diff --git a/agent/context_manager/manager.py b/agent/context_manager/manager.py index 85e96af0f6f3fa6d0426acddcd308281d502558b..f13c26220a27a01c35e0753f495832577bb1832b 100644 --- a/agent/context_manager/manager.py +++ b/agent/context_manager/manager.py @@ -3,7 +3,7 @@ Context management for conversation history """ import logging -import time +import os import zoneinfo from datetime import datetime from pathlib import Path @@ -13,8 +13,6 @@ import yaml from jinja2 import Template from litellm import Message, acompletion -from agent.core.prompt_caching import with_prompt_caching - logger = logging.getLogger(__name__) _HF_WHOAMI_URL = "https://huggingface.co/api/whoami-v2" @@ -70,113 +68,12 @@ def _get_hf_username(hf_token: str | None = None) -> str: return "unknown" -_COMPACT_PROMPT = ( - "Please provide a concise summary of the conversation above, focusing on " - "key decisions, the 'why' behind the decisions, problems solved, and " - "important context needed for developing further. Your summary will be " - "given to someone who has never worked on this project before and they " - "will have to be filled in." -) - -# Per-message ceiling.
If a single message in the "untouched" tail is larger -# than this, compaction can't recover even after summarizing the middle β€” -# producing the infinite compaction loop seen 2026-05-03 in pod logs (200k -# context shrinks to 200k+ because one tool output is 80k tokens). We replace -# such messages with a placeholder before compaction runs. -_MAX_TOKENS_PER_MESSAGE = 50_000 - - -class CompactionFailedError(Exception): - """Raised when compaction can't reduce context below the threshold. - - Typically means an individual preserved message (system, first user, or - untouched tail) exceeds what truncation can fix in one pass. The caller - must terminate the session β€” retrying produces an infinite loop that - burns Bedrock budget for free (~$3 per re-attempt on Opus). - """ - - -# Used when seeding a brand-new session from prior browser-cached messages. -# Here we're writing a note to *ourselves* β€” so preserve the tool-call trail, -# files produced, and planned next steps in first person. Optimized for -# continuity, not brevity. -_RESTORE_PROMPT = ( - "You're about to be restored into a fresh session with no memory of the " - "conversation above. Write a first-person note to your future self so " - "you can continue right where you left off. Include:\n" - " β€’ What the user originally asked for and what progress you've made.\n" - " β€’ Every tool you called, with arguments and a one-line result summary.\n" - " β€’ Any code, files, scripts, or artifacts you produced (with paths).\n" - " β€’ Key decisions and the reasoning behind them.\n" - " β€’ What you were planning to do next.\n\n" - "Don't be cute. Be specific. This is the only context you'll have." -) - - -async def summarize_messages( - messages: list[Message], - model_name: str, - hf_token: str | None = None, - max_tokens: int = 2000, - tool_specs: list[dict] | None = None, - prompt: str = _COMPACT_PROMPT, - session: Any = None, - kind: str = "compaction", -) -> tuple[str, int]: - """Run a summarization prompt against a list of messages. - - ``prompt`` defaults to the compaction prompt (terse, decision-focused). - Callers seeding a new session after a restart should pass ``_RESTORE_PROMPT`` - instead β€” it preserves the tool-call trail so the agent can answer - follow-up questions about what it did. - - ``session`` is optional; when provided, the call is recorded via - ``telemetry.record_llm_call`` so its cost lands in the session's - ``total_cost_usd``. Without it, the call still happens but is - invisible in telemetry β€” which used to be the case for every - compaction call until 2026-04-29 (~30-50% of Bedrock spend was - attributed to this single source of dark cost). - - Returns ``(summary_text, completion_tokens)``. 
- """ - from agent.core.llm_params import _resolve_llm_params - - prompt_messages = list(messages) + [Message(role="user", content=prompt)] - llm_params = _resolve_llm_params(model_name, hf_token, reasoning_effort="high") - prompt_messages, tool_specs = with_prompt_caching( - prompt_messages, tool_specs, llm_params.get("model") - ) - _t0 = time.monotonic() - response = await acompletion( - messages=prompt_messages, - max_completion_tokens=max_tokens, - tools=tool_specs, - **llm_params, - ) - if session is not None: - from agent.core import telemetry - - await telemetry.record_llm_call( - session, - model=model_name, - response=response, - latency_ms=int((time.monotonic() - _t0) * 1000), - finish_reason=response.choices[0].finish_reason - if response.choices - else None, - kind=kind, - ) - summary = response.choices[0].message.content or "" - completion_tokens = response.usage.completion_tokens if response.usage else 0 - return summary, completion_tokens - - class ContextManager: """Manages conversation context and message history for the agent""" def __init__( self, - model_max_tokens: int = 180_000, + max_context: int = 180_000, compact_size: float = 0.1, untouched_messages: int = 5, tool_specs: list[dict[str, Any]] | None = None, @@ -190,18 +87,11 @@ class ContextManager: hf_token=hf_token, local_mode=local_mode, ) - # The model's real input-token ceiling (from litellm.get_model_info). - # Compaction triggers at _COMPACT_THRESHOLD_RATIO below it β€” see - # the compaction_threshold property. - self.model_max_tokens = model_max_tokens - self.compact_size = int(model_max_tokens * compact_size) - # Running count of tokens the last LLM call reported. Drives the - # compaction gate; updated in add_message() with each response's - # usage.total_tokens. - self.running_context_usage = 0 + self.max_context = max_context - 10000 + self.compact_size = int(max_context * compact_size) + self.context_length = 0 # Updated after each LLM call with actual usage self.untouched_messages = untouched_messages self.items: list[Message] = [Message(role="system", content=self.system_prompt)] - self.on_message_added = None def _load_system_prompt( self, @@ -236,7 +126,6 @@ class ContextManager: # CLI-specific context for local mode if local_mode: import os - cwd = os.getcwd() local_context = ( f"\n\n# CLI / Local mode\n\n" @@ -260,10 +149,8 @@ class ContextManager: def add_message(self, message: Message, token_count: int = None) -> None: """Add a message to the history""" if token_count: - self.running_context_usage = token_count + self.context_length = token_count self.items.append(message) - if self.on_message_added: - self.on_message_added(message) def get_messages(self) -> list[Message]: """Get all messages for sending to LLM. @@ -298,53 +185,45 @@ class ContextManager: def _patch_dangling_tool_calls(self) -> None: """Add stub tool results for any tool_calls that lack a matching result. - Ensures each assistant message's tool_calls are followed immediately - by matching tool-result messages. This has to work across the whole - history, not just the most recent turn, because a cancelled tool use - in an earlier turn can still poison the next provider request. + Scans backwards to find the last assistant message with tool_calls, + which may not be items[-1] if some tool results were already added. 
""" if not self.items: return - i = 0 - while i < len(self.items): + # Find the last assistant message with tool_calls + assistant_msg = None + for i in range(len(self.items) - 1, -1, -1): msg = self.items[i] - if getattr(msg, "role", None) != "assistant" or not getattr( + if getattr(msg, "role", None) == "assistant" and getattr( msg, "tool_calls", None ): - i += 1 - continue - - self._normalize_tool_calls(msg) - - # Consume the contiguous tool-result block that immediately follows - # this assistant message. Any missing tool ids must be inserted - # before the next non-tool message to satisfy provider ordering. - j = i + 1 - immediate_ids: set[str | None] = set() - while ( - j < len(self.items) and getattr(self.items[j], "role", None) == "tool" - ): - immediate_ids.add(getattr(self.items[j], "tool_call_id", None)) - j += 1 - - missing: list[Message] = [] - for tc in msg.tool_calls: - if tc.id not in immediate_ids: - missing.append( - Message( - role="tool", - content="Tool was not executed (interrupted or error).", - tool_call_id=tc.id, - name=tc.function.name, - ) - ) + assistant_msg = msg + break + # Stop scanning once we hit a user message β€” anything before + # that belongs to a previous (complete) turn. + if getattr(msg, "role", None) == "user": + break - if missing: - self.items[j:j] = missing - j += len(missing) + if not assistant_msg: + return - i = j + self._normalize_tool_calls(assistant_msg) + answered_ids = { + getattr(m, "tool_call_id", None) + for m in self.items + if getattr(m, "role", None) == "tool" + } + for tc in assistant_msg.tool_calls: + if tc.id not in answered_ids: + self.items.append( + Message( + role="tool", + content="Tool was not executed (interrupted or error).", + tool_call_id=tc.id, + name=tc.function.name, + ) + ) def undo_last_turn(self) -> bool: """Remove the last complete turn (user msg + all assistant/tool msgs that follow). @@ -383,119 +262,11 @@ class ContextManager: count += 1 return False - # Compaction fires at 90% of model_max_tokens so there's headroom for - # the next turn's prompt + response before we actually hit the ceiling. - _COMPACT_THRESHOLD_RATIO = 0.9 - - @property - def compaction_threshold(self) -> int: - """Token count at which `compact()` kicks in.""" - return int(self.model_max_tokens * self._COMPACT_THRESHOLD_RATIO) - - @property - def needs_compaction(self) -> bool: - return self.running_context_usage > self.compaction_threshold and bool( - self.items - ) - - def _truncate_oversized( - self, messages: list[Message], model_name: str - ) -> list[Message]: - """Replace any message > _MAX_TOKENS_PER_MESSAGE with a placeholder. - - These are typically tool outputs (CSV dumps, file contents) sitting in - the untouched tail or first-user position that compaction can't shrink - β€” they pass through verbatim, keeping context above threshold and - triggering an infinite compaction retry loop. - """ - from litellm import token_counter - - out: list[Message] = [] - for msg in messages: - # System messages are sacred β€” they're the agent's instructions. - # In edge cases (items < untouched_messages), the slice math in - # compact() can let items[0] (the system message) leak into the - # recent_messages list. Defense-in-depth: never truncate it. - if msg.role == "system": - out.append(msg) - continue - try: - n = token_counter(model=model_name, messages=[msg.model_dump()]) - except Exception: - # token_counter occasionally fails on edge-case content; - # don't drop the message, just keep it as-is. 
- out.append(msg) - continue - if n <= _MAX_TOKENS_PER_MESSAGE: - out.append(msg) - continue - placeholder = ( - f"[truncated for compaction β€” original was {n} tokens, " - f"removed to keep context under {self.compaction_threshold} tokens]" - ) - logger.warning( - "Truncating %s message: %d -> %d tokens for compaction", - msg.role, - n, - len(placeholder) // 4, - ) - # Preserve all known assistant-side fields (tool_calls, thinking_blocks, - # reasoning_content, provider_specific_fields) even when content is - # replaced. Anthropic extended-thinking models reject the next request - # with "Invalid signature in thinking block" if thinking_blocks is - # dropped from a prior assistant message. - kept = { - k: getattr(msg, k, None) - for k in ( - "tool_call_id", - "tool_calls", - "name", - "thinking_blocks", - "reasoning_content", - "provider_specific_fields", - ) - if getattr(msg, k, None) is not None - } - out.append(Message(role=msg.role, content=placeholder, **kept)) - return out - - def _recompute_usage(self, model_name: str) -> None: - """Refresh ``running_context_usage`` from current items via real tokenizer.""" - from litellm import token_counter - - try: - self.running_context_usage = token_counter( - model=model_name, - messages=[m.model_dump() for m in self.items], - ) - except Exception as e: - logger.warning("token_counter failed (%s); rough estimate", e) - # Rough fallback: 4 chars per token. - self.running_context_usage = ( - sum(len(getattr(m, "content", "") or "") for m in self.items) // 4 - ) - async def compact( - self, - model_name: str, - tool_specs: list[dict] | None = None, - hf_token: str | None = None, - session: Any = None, + self, model_name: str, tool_specs: list[dict] | None = None ) -> None: - """Remove old messages to keep history under target size. - - ``session`` is optional β€” if passed, the underlying summarization - LLM call is recorded via ``telemetry.record_llm_call(kind= - "compaction")`` so its cost shows up in ``total_cost_usd``. - - Raises ``CompactionFailedError`` if the post-compact context is still - over the threshold. This happens when a preserved message (typically - a giant tool output stuck in the untouched tail) is too large for - truncation to fix. The caller must terminate the session β€” retrying - is what caused the 2026-05-03 infinite-compaction-loop pattern that - burned Bedrock budget invisibly. - """ - if not self.needs_compaction: + """Remove old messages to keep history under target size""" + if (self.context_length <= self.max_context) or not self.items: return system_msg = ( @@ -517,60 +288,33 @@ class ContextManager: idx = len(self.items) - self.untouched_messages while idx > 1 and self.items[idx].role != "user": idx -= 1 - # The real invariant is "idx must be strictly after first_user_idx, - # otherwise recent_messages overlaps with the messages we put in - # head". The walk-back's `idx > 1` guard is necessary (no system in - # recent) but insufficient (first_user is also in head and would be - # duplicated). Anthropic API rejects two consecutive user messages - # with a 400 β€” bot review on PR #213 caught this on the second clamp - # iteration. - if idx <= first_user_idx: - idx = first_user_idx + 1 recent_messages = self.items[idx:] - messages_to_summarize = self.items[first_user_idx + 1 : idx] - - # Truncate any message that's larger than _MAX_TOKENS_PER_MESSAGE in - # the parts we PRESERVE through compaction (first_user + recent_tail). 
-        # These are the only places where individual messages can defeat
-        # compaction by being intrinsically too large. Messages in
-        # ``messages_to_summarize`` are folded into the summary, so their size
-        # doesn't matter on its own.
-        if first_user_msg is not None:
-            truncated = self._truncate_oversized([first_user_msg], model_name)
-            first_user_msg = truncated[0]
-        recent_messages = self._truncate_oversized(recent_messages, model_name)
-
-        # If there's nothing to summarize but the preserved messages are now
-        # truncated and small, just rebuild and recompute. This is rare but
-        # avoids returning silently with the old (over-threshold) state.
+        messages_to_summarize = self.items[first_user_idx + 1:idx]
+
+        # improbable, the messages would have to be very long
         if not messages_to_summarize:
-            head = [system_msg] if system_msg else []
-            if first_user_msg:
-                head.append(first_user_msg)
-            self.items = head + recent_messages
-            self._recompute_usage(model_name)
-            if self.running_context_usage > self.compaction_threshold:
-                raise CompactionFailedError(
-                    f"Nothing to summarize but context ({self.running_context_usage}) "
-                    f"still over threshold ({self.compaction_threshold}) after truncation. "
-                    f"System prompt or first user message likely exceeds the budget."
-                )
             return
 
-        summary, completion_tokens = await summarize_messages(
-            messages_to_summarize,
-            model_name=model_name,
-            hf_token=hf_token,
-            max_tokens=self.compact_size,
-            tool_specs=tool_specs,
-            prompt=_COMPACT_PROMPT,
-            session=session,
-            kind="compaction",
+        messages_to_summarize.append(
+            Message(
+                role="user",
+                content="Please provide a concise summary of the conversation above, focusing on key decisions, the 'why' behind the decisions, problems solved, and important context needed for developing further. Your summary will be given to someone who has never worked on this project before and they will have to be filled in.",
+            )
+        )
+
+        hf_key = os.environ.get("INFERENCE_TOKEN")
+        response = await acompletion(
+            model=model_name,
+            messages=messages_to_summarize,
+            max_completion_tokens=self.compact_size,
+            tools=tool_specs,
+            api_key=hf_key
+            if hf_key and model_name.startswith("huggingface/")
+            else None,
         )
 
         summarized_message = Message(
-            role="assistant",
-            content=summary,
+            role="assistant", content=response.choices[0].message.content
        )
 
        # Reconstruct: system + first user msg + summary + recent messages
@@ -579,19 +323,6 @@ class ContextManager:
             head.append(first_user_msg)
 
         self.items = head + [summarized_message] + recent_messages
-        self._recompute_usage(model_name)
-
-        # Hard verify: if compaction didn't bring us below the threshold even
-        # after truncating oversized preserved messages, retrying just burns
-        # Bedrock budget on the same useless compaction call. Raise so the
-        # caller can terminate the session cleanly. Pre-2026-05-04, the
-        # caller looped indefinitely (~$3/Opus retry) until the pod was
-        # killed — invisible to the dataset because the session never
-        # finished cleanly.
-        if self.running_context_usage > self.compaction_threshold:
-            raise CompactionFailedError(
-                f"Compaction ineffective: {self.running_context_usage} tokens "
-                f"still over threshold {self.compaction_threshold} after summarize "
-                f"and truncation. Likely the system prompt + first user + summary "
-                f"+ truncated tail still exceeds budget."
-        )
+        self.context_length = (
+            len(self.system_prompt) // 4 + response.usage.completion_tokens
+        )
diff --git a/agent/core/agent_loop.py b/agent/core/agent_loop.py
index 0eaa6e9d64b7bdb6e1addc09a4a837e80dff2cd2..dd12f6f2be1c760e3bc3be3f8acdf3e23098dba3 100644
--- a/agent/core/agent_loop.py
+++ b/agent/core/agent_loop.py
@@ -5,94 +5,55 @@ Main agent implementation with integrated tool system and MCP support
 import asyncio
 import json
 import logging
-import time
-from dataclasses import dataclass, field
-from typing import Any
-
-from litellm import (
-    ChatCompletionMessageToolCall,
-    Message,
-    acompletion,
-    stream_chunk_builder,
-)
+import os
+from dataclasses import dataclass
+
+from litellm import ChatCompletionMessageToolCall, Message, acompletion
 from litellm.exceptions import ContextWindowExceededError
 
 from agent.config import Config
-from agent.core.approval_policy import (
-    is_scheduled_operation,
-    normalize_tool_operation,
-)
-from agent.core.cost_estimation import CostEstimate, estimate_tool_cost
-from agent.messaging.gateway import NotificationGateway
-from agent.core import telemetry
 from agent.core.doom_loop import check_for_doom_loop
-from agent.core.hub_artifacts import start_session_artifact_collection_task
-from agent.core.llm_params import _resolve_llm_params
-from agent.core.prompt_caching import with_prompt_caching
 from agent.core.session import Event, OpType, Session
 from agent.core.tools import ToolRouter
 from agent.tools.jobs_tool import CPU_FLAVORS
-from agent.tools.sandbox_tool import DEFAULT_CPU_SANDBOX_HARDWARE
 
 logger = logging.getLogger(__name__)
 
 ToolCall = ChatCompletionMessageToolCall
 
+# Explicit inference token for LLM API calls (separate from user OAuth tokens).
+_INFERENCE_API_KEY = os.environ.get("INFERENCE_TOKEN")
 
-_MALFORMED_TOOL_PREFIX = "ERROR: Tool call to '"
-_MALFORMED_TOOL_SUFFIX = "' had malformed JSON arguments"
-
-
-def _malformed_tool_name(message: Message) -> str | None:
-    """Return the tool name for malformed-json tool-result messages."""
-    if getattr(message, "role", None) != "tool":
-        return None
-    content = getattr(message, "content", None)
-    if not isinstance(content, str):
-        return None
-    if not content.startswith(_MALFORMED_TOOL_PREFIX):
-        return None
-    end = content.find(_MALFORMED_TOOL_SUFFIX, len(_MALFORMED_TOOL_PREFIX))
-    if end == -1:
-        return None
-    return content[len(_MALFORMED_TOOL_PREFIX) : end]
-
-
-def _detect_repeated_malformed(
-    items: list[Message],
-    threshold: int = 2,
-) -> str | None:
-    """Return the repeated malformed tool name if the tail contains a streak.
-
-    Walk backward over the current conversation tail. A streak counts only
-    consecutive malformed tool-result messages for the same tool; any other
-    tool result breaks it.
-    """
-    if threshold <= 0:
-        return None
-
-    streak_tool: str | None = None
-    streak = 0
-
-    for item in reversed(items):
-        if getattr(item, "role", None) != "tool":
-            continue
-        malformed_tool = _malformed_tool_name(item)
-        if malformed_tool is None:
-            break
-
-        if streak_tool is None:
-            streak_tool = malformed_tool
-            streak = 1
-        elif malformed_tool == streak_tool:
-            streak += 1
-        else:
-            break
-
-    if streak >= threshold:
-        return streak_tool
-
-    return None
+def _resolve_hf_router_params(model_name: str) -> dict:
+    """
+    Build LiteLLM kwargs for HuggingFace Router models.
+
+    api-inference.huggingface.co is deprecated; the new router lives at
+    router.huggingface.co/<provider>/v3/openai. LiteLLM's built-in
+    ``huggingface/`` provider still targets the old endpoint, so we
+    rewrite model names to ``openai/<org>/<model>`` and supply the
+    correct api_base.
+
+    Input format: huggingface/<provider>/<org>/<model>
+    Example: huggingface/novita/moonshotai/kimi-k2.5 (becomes
+    ``openai/moonshotai/kimi-k2.5`` with api_base
+    ``https://router.huggingface.co/novita/v3/openai``)
+    """
+    if not model_name.startswith("huggingface/"):
+        return {"model": model_name}
+
+    parts = model_name.split(
+        "/", 2
+    )  # ['huggingface', 'novita', 'moonshotai/kimi-k2.5']
+    if len(parts) < 3:
+        return {"model": model_name}
+
+    router_provider = parts[1]
+    actual_model = parts[2]
+    api_key = _INFERENCE_API_KEY
+
+    return {
+        "model": f"openai/{actual_model}",
+        "api_base": f"https://router.huggingface.co/{router_provider}/v3/openai",
+        "api_key": api_key,
+    }
 
 
 def _validate_tool_args(tool_args: dict) -> tuple[bool, str | None]:
@@ -117,42 +78,13 @@ def _validate_tool_args(tool_args: dict) -> tuple[bool, str | None]:
     return True, None
 
 
-_IMMEDIATE_HF_JOB_RUNS = {"run", "uv"}
-
-
-@dataclass(frozen=True)
-class ApprovalDecision:
-    requires_approval: bool
-    auto_approved: bool = False
-    auto_approval_blocked: bool = False
-    block_reason: str | None = None
-    estimated_cost_usd: float | None = None
-    remaining_cap_usd: float | None = None
-    billable: bool = False
-
-
-def _operation(tool_args: dict) -> str:
-    return normalize_tool_operation(tool_args.get("operation"))
-
-
-def _is_immediate_hf_job_run(tool_name: str, tool_args: dict) -> bool:
-    return tool_name == "hf_jobs" and _operation(tool_args) in _IMMEDIATE_HF_JOB_RUNS
-
-
-def _is_scheduled_hf_job_run(tool_name: str, tool_args: dict) -> bool:
-    return tool_name == "hf_jobs" and is_scheduled_operation(_operation(tool_args))
-
-
-def _is_budgeted_auto_approval_target(tool_name: str, tool_args: dict) -> bool:
-    return tool_name == "sandbox_create" or _is_immediate_hf_job_run(
-        tool_name, tool_args
-    )
-
-
-def _base_needs_approval(
+def _needs_approval(
     tool_name: str, tool_args: dict, config: Config | None = None
 ) -> bool:
-    """Check if a tool call requires approval before YOLO policy is applied."""
+    """Check if a tool call requires user approval before execution."""
+    # Yolo mode: skip all approvals
+    if config and config.yolo_mode:
+        return False
 
     # If args are malformed, skip approval (validation error will be shown later)
     args_valid, _ = _validate_tool_args(tool_args)
@@ -160,14 +92,11 @@
         return False
 
     if tool_name == "sandbox_create":
-        hardware = tool_args.get("hardware") or DEFAULT_CPU_SANDBOX_HARDWARE
-        return hardware != DEFAULT_CPU_SANDBOX_HARDWARE
+        return True
 
     if tool_name == "hf_jobs":
-        operation = _operation(tool_args)
-        if is_scheduled_operation(operation):
-            return True
-        if operation not in _IMMEDIATE_HF_JOB_RUNS:
+        operation = tool_args.get("operation", "")
+        if operation not in ["run", "uv", "scheduled run", "scheduled uv"]:
             return False
 
         # Check if this is a CPU-only job
@@ -219,405 +148,51 @@
     return False
 
 
-def _needs_approval(
-    tool_name: str, tool_args: dict, config: Config | None = None
-) -> bool:
-    """Legacy sync approval predicate used by tests and CLI display helpers."""
-    if _is_scheduled_hf_job_run(tool_name, tool_args):
-        return True
-    if config and config.yolo_mode:
-        return False
-    return _base_needs_approval(tool_name, tool_args, config)
-
-
-def _session_auto_approval_enabled(session: Session | None) -> bool:
-    return bool(session and getattr(session, "auto_approval_enabled", False))
-
-
-def _effective_yolo_enabled(session: Session | None, config: Config | None) -> bool:
-    return bool(
-        (config and
config.yolo_mode) or _session_auto_approval_enabled(session) - ) - - -def _remaining_budget_after_reservations( - session: Session | None, reserved_spend_usd: float -) -> float | None: - if not session or getattr(session, "auto_approval_cost_cap_usd", None) is None: - return None - cap = float(getattr(session, "auto_approval_cost_cap_usd") or 0.0) - spent = float(getattr(session, "auto_approval_estimated_spend_usd", 0.0) or 0.0) - return round(max(0.0, cap - spent - reserved_spend_usd), 4) - - -def _budget_block_reason( - estimate: CostEstimate, - *, - remaining_cap_usd: float | None, -) -> str | None: - if estimate.estimated_cost_usd is None: - return estimate.block_reason or "Could not estimate the cost safely." - if ( - remaining_cap_usd is not None - and estimate.estimated_cost_usd > remaining_cap_usd - ): - return ( - f"Estimated cost ${estimate.estimated_cost_usd:.2f} exceeds " - f"remaining YOLO cap ${remaining_cap_usd:.2f}." - ) - return None - - -async def _approval_decision( - tool_name: str, - tool_args: dict, - session: Session, - *, - reserved_spend_usd: float = 0.0, -) -> ApprovalDecision: - """Return the approval decision for one parsed tool call.""" - config = session.config - base_requires_approval = _base_needs_approval(tool_name, tool_args, config) - - # Scheduled jobs are recurring/unbounded enough that YOLO never bypasses - # the human confirmation, including legacy config.yolo_mode. - if _is_scheduled_hf_job_run(tool_name, tool_args): - return ApprovalDecision( - requires_approval=True, - auto_approval_blocked=_effective_yolo_enabled(session, config), - block_reason="Scheduled HF jobs always require manual approval.", - ) - - yolo_enabled = _effective_yolo_enabled(session, config) - budgeted_target = _is_budgeted_auto_approval_target(tool_name, tool_args) - - # Cost caps are a session-scoped web policy. Legacy config.yolo_mode - # remains uncapped for CLI/headless, except for scheduled jobs above. 
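-    # Worked example (hypothetical numbers): cap=$10.00, spent=$6.50,
-    # reserved=$2.00 earlier in this batch -> remaining=$1.50, so a
-    # $3.00 sandbox_create estimate gets routed to manual approval.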
- session_yolo_enabled = _session_auto_approval_enabled(session) - if yolo_enabled and budgeted_target and session_yolo_enabled: - estimate = await estimate_tool_cost(tool_name, tool_args, session=session) - remaining = _remaining_budget_after_reservations(session, reserved_spend_usd) - reason = _budget_block_reason(estimate, remaining_cap_usd=remaining) - if reason: - return ApprovalDecision( - requires_approval=True, - auto_approval_blocked=True, - block_reason=reason, - estimated_cost_usd=estimate.estimated_cost_usd, - remaining_cap_usd=remaining, - billable=estimate.billable, - ) - if base_requires_approval: - return ApprovalDecision( - requires_approval=False, - auto_approved=True, - estimated_cost_usd=estimate.estimated_cost_usd, - remaining_cap_usd=remaining, - billable=estimate.billable, - ) - return ApprovalDecision( - requires_approval=False, - estimated_cost_usd=estimate.estimated_cost_usd, - remaining_cap_usd=remaining, - billable=estimate.billable, - ) - - if base_requires_approval and yolo_enabled: - return ApprovalDecision(requires_approval=False, auto_approved=True) - - return ApprovalDecision(requires_approval=base_requires_approval) - - -def _record_estimated_spend(session: Session, decision: ApprovalDecision) -> None: - if not decision.billable or decision.estimated_cost_usd is None: - return - if hasattr(session, "add_auto_approval_estimated_spend"): - session.add_auto_approval_estimated_spend(decision.estimated_cost_usd) - else: - session.auto_approval_estimated_spend_usd = round( - float(getattr(session, "auto_approval_estimated_spend_usd", 0.0) or 0.0) - + float(decision.estimated_cost_usd), - 4, - ) - - -async def _record_manual_approved_spend_if_needed( - session: Session, - tool_name: str, - tool_args: dict, -) -> None: - if not _session_auto_approval_enabled(session): - return - if not _is_budgeted_auto_approval_target(tool_name, tool_args): - return - estimate = await estimate_tool_cost(tool_name, tool_args, session=session) - _record_estimated_spend( - session, - ApprovalDecision( - requires_approval=False, - billable=estimate.billable, - estimated_cost_usd=estimate.estimated_cost_usd, - ), - ) - - # -- LLM retry constants -------------------------------------------------- _MAX_LLM_RETRIES = 3 _LLM_RETRY_DELAYS = [5, 15, 30] # seconds between retries -_LLM_RATE_LIMIT_RETRY_DELAYS = [30, 60] # exceed Bedrock's ~60s TPM bucket window - - -def _is_rate_limit_error(error: Exception) -> bool: - """Return True for rate-limit / quota-bucket style provider errors.""" - err_str = str(error).lower() - rate_limit_patterns = [ - "429", - "rate limit", - "rate_limit", - "too many requests", - "too many tokens", - "request limit", - "throttl", - ] - return any(pattern in err_str for pattern in rate_limit_patterns) - - -def _is_context_overflow_error(error: Exception) -> bool: - """Return True when the prompt exceeded the model's context window.""" - if isinstance(error, ContextWindowExceededError): - return True - - err_str = str(error).lower() - overflow_patterns = [ - "context window exceeded", - "maximum context length", - "max context length", - "prompt is too long", - "context length exceeded", - "too many input tokens", - "input is too long", - ] - return any(pattern in err_str for pattern in overflow_patterns) - - -def _retry_delay_for(error: Exception, attempt_index: int) -> int | None: - """Return the delay for this retry attempt, or None if it should not retry.""" - if _is_rate_limit_error(error): - schedule = _LLM_RATE_LIMIT_RETRY_DELAYS - elif 
_is_transient_error(error): - schedule = _LLM_RETRY_DELAYS - else: - return None - - if attempt_index >= len(schedule): - return None - return schedule[attempt_index] def _is_transient_error(error: Exception) -> bool: """Return True for errors that are likely transient and worth retrying.""" err_str = str(error).lower() transient_patterns = [ - "timeout", - "timed out", - "503", - "service unavailable", - "502", - "bad gateway", - "500", - "internal server error", - "overloaded", - "capacity", - "connection reset", - "connection refused", - "connection error", - "eof", - "broken pipe", + "timeout", "timed out", + "429", "rate limit", "rate_limit", + "503", "service unavailable", + "502", "bad gateway", + "500", "internal server error", + "overloaded", "capacity", + "connection reset", "connection refused", "connection error", + "eof", "broken pipe", ] - return _is_rate_limit_error(error) or any( - pattern in err_str for pattern in transient_patterns - ) - - -def _is_effort_config_error(error: Exception) -> bool: - """Catch the two 400s the effort probe also handles β€” thinking - unsupported for this model, or the specific effort level invalid. - - This is our safety net for the case where ``/effort`` was changed - mid-conversation (which clears the probe cache) and the new level - doesn't work for the current model. We heal the cache and retry once. - """ - from agent.core.effort_probe import _is_invalid_effort, _is_thinking_unsupported - - return _is_thinking_unsupported(error) or _is_invalid_effort(error) - - -async def _heal_effort_and_rebuild_params( - session: Session, - error: Exception, - llm_params: dict, -) -> dict: - """Update the session's effort cache based on ``error`` and return new - llm_params. Called only when ``_is_effort_config_error(error)`` is True. - - Two branches: - β€’ thinking-unsupported β†’ cache ``None`` for this model, next call - strips thinking entirely - β€’ invalid-effort β†’ re-run the full cascade probe; the result lands - in the cache - """ - from agent.core.effort_probe import ( - ProbeInconclusive, - _is_thinking_unsupported, - probe_effort, - ) - - model = session.config.model_name - if _is_thinking_unsupported(error): - session.model_effective_effort[model] = None - logger.info("healed: %s doesn't support thinking β€” stripped", model) - else: - try: - outcome = await probe_effort( - model, - session.config.reasoning_effort, - session.hf_token, - session=session, - ) - session.model_effective_effort[model] = outcome.effective_effort - logger.info( - "healed: %s effort cascade β†’ %s", - model, - outcome.effective_effort, - ) - except ProbeInconclusive: - # Transient during healing β€” strip thinking for safety, next - # call will either succeed or surface the real error. 
-            session.model_effective_effort[model] = None
-            logger.info("healed: %s probe inconclusive — stripped", model)
-
-    return _resolve_llm_params(
-        model,
-        session.hf_token,
-        reasoning_effort=session.effective_effort_for(model),
-    )
-
-
-def _friendly_error_message(error: Exception) -> str | None:
-    """Return a user-friendly message for known error types, or None to fall back to traceback."""
-    err_str = str(error).lower()
-
-    if (
-        "authentication" in err_str
-        or "unauthorized" in err_str
-        or "invalid x-api-key" in err_str
-    ):
-        return (
-            "Authentication failed — your API key is missing or invalid.\n\n"
-            "To fix this, set the API key for your model provider:\n"
-            "  • Anthropic: export ANTHROPIC_API_KEY=sk-...\n"
-            "  • OpenAI: export OPENAI_API_KEY=sk-...\n"
-            "  • HF Router: export HF_TOKEN=hf_...\n\n"
-            "You can also add it to a .env file in the project root.\n"
-            "To switch models, use the /model command."
-        )
-
-    if "insufficient" in err_str and "credit" in err_str:
-        return (
-            "Insufficient API credits. Please check your account balance "
-            "at your model provider's dashboard."
-        )
-
-    if "not supported by provider" in err_str or "no provider supports" in err_str:
-        return (
-            "The model isn't served by the provider you pinned.\n\n"
-            "Drop the ':<provider>' suffix to let the HF router auto-pick a "
-            "provider, or use '/model' (no arg) to see which providers host "
-            "which models."
-        )
-
-    if "model_not_found" in err_str or (
-        "model" in err_str and ("not found" in err_str or "does not exist" in err_str)
-    ):
-        return (
-            "Model not found. Use '/model' to list suggestions, or paste an "
-            "HF model id like 'MiniMaxAI/MiniMax-M2.7'. Availability is shown "
-            "when you switch."
-        )
-
-    return None
 
 
 async def _compact_and_notify(session: Session) -> None:
-    """Run compaction and send event if context was reduced.
-
-    Catches ``CompactionFailedError`` and ends the session cleanly instead
-    of letting the caller retry. Pre-2026-05-04 the caller looped on
-    ContextWindowExceededError → compact → re-trigger, burning Bedrock
-    budget at ~$3/Opus retry while the session never reached the upload
-    path (so the cost was invisible in the dataset).
-    """
-    from agent.context_manager.manager import CompactionFailedError
-
-    cm = session.context_manager
-    old_usage = cm.running_context_usage
+    """Run compaction and send event if context was reduced."""
+    old_length = session.context_manager.context_length
+    max_ctx = session.context_manager.max_context
     logger.debug(
-        "Compaction check: usage=%d, max=%d, threshold=%d, needs_compact=%s",
-        old_usage,
-        cm.model_max_tokens,
-        cm.compaction_threshold,
-        cm.needs_compaction,
+        "Compaction check: context_length=%d, max_context=%d, needs_compact=%s",
+        old_length, max_ctx, old_length > max_ctx,
     )
-    try:
-        await cm.compact(
-            model_name=session.config.model_name,
-            tool_specs=session.tool_router.get_tool_specs_for_llm(),
-            hf_token=session.hf_token,
-            session=session,
-        )
-    except CompactionFailedError as e:
-        logger.error(
-            "Compaction failed for session %s: %s — terminating session",
-            session.session_id,
-            e,
-        )
-        # Persist the failure event so the dataset has a record of WHY this
-        # session ended (and the cost it incurred up to that point) even if
-        # save_and_upload_detached has issues downstream.
- await session.send_event( - Event( - event_type="session_terminated", - data={ - "reason": "compaction_failed", - "context_usage": cm.running_context_usage, - "context_threshold": cm.compaction_threshold, - "error": str(e)[:300], - "user_message": ( - "Your conversation has grown too large to continue. " - "The work you've done is saved β€” start a new session to keep going." - ), - }, - ) - ) - # Stop the agent loop; the finally in _run_session will fire - # cleanup_sandbox + save_trajectory so the dataset captures - # everything that did happen. - session.is_running = False - return - - new_usage = cm.running_context_usage - if new_usage != old_usage: + tool_specs = session.tool_router.get_tool_specs_for_llm() + await session.context_manager.compact( + model_name=session.config.model_name, + tool_specs=tool_specs, + ) + new_length = session.context_manager.context_length + if new_length != old_length: logger.warning( "Context compacted: %d -> %d tokens (max=%d, %d messages)", - old_usage, - new_usage, - cm.model_max_tokens, - len(cm.items), + old_length, new_length, max_ctx, + len(session.context_manager.items), ) await session.send_event( Event( event_type="compacted", - data={"old_tokens": old_usage, "new_tokens": new_usage}, + data={"old_tokens": old_length, "new_tokens": new_length}, ) ) @@ -651,171 +226,15 @@ async def _cleanup_on_cancel(session: Session) -> None: @dataclass class LLMResult: """Result from an LLM call (streaming or non-streaming).""" - content: str | None tool_calls_acc: dict[int, dict] token_count: int finish_reason: str | None - usage: dict = field(default_factory=dict) - thinking_blocks: list[dict[str, Any]] | None = None - reasoning_content: str | None = None - - -def _extract_thinking_state( - message: Any, -) -> tuple[list[dict[str, Any]] | None, str | None]: - """Return provider reasoning fields that must be replayed after tool calls.""" - provider_fields = getattr(message, "provider_specific_fields", None) - if not isinstance(provider_fields, dict): - provider_fields = {} - - thinking_blocks = ( - getattr(message, "thinking_blocks", None) - or provider_fields.get("thinking_blocks") - or None - ) - reasoning_content = ( - getattr(message, "reasoning_content", None) - or provider_fields.get("reasoning_content") - or None - ) - return thinking_blocks, reasoning_content - - -def _should_replay_thinking_state(model_name: str | None) -> bool: - """Only Anthropic's native adapter accepts replayed thinking metadata.""" - return bool(model_name and model_name.startswith("anthropic/")) - - -def _is_invalid_thinking_signature_error(exc: Exception) -> bool: - """Return True when Anthropic rejected replayed extended-thinking state.""" - text = str(exc) - return ( - "Invalid `signature` in `thinking` block" in text - or "Invalid signature in thinking block" in text - ) - -def _strip_thinking_state_from_messages(messages: list[Any]) -> int: - """Remove replayed thinking metadata from assistant history messages.""" - stripped = 0 - - for message in messages: - role = ( - message.get("role") - if isinstance(message, dict) - else getattr(message, "role", None) - ) - if role != "assistant": - continue - if isinstance(message, dict): - if message.pop("thinking_blocks", None) is not None: - stripped += 1 - if message.pop("reasoning_content", None) is not None: - stripped += 1 - provider_fields = message.get("provider_specific_fields") - content = message.get("content") - else: - if getattr(message, "thinking_blocks", None) is not None: - message.thinking_blocks = None - 
stripped += 1 - if getattr(message, "reasoning_content", None) is not None: - message.reasoning_content = None - stripped += 1 - provider_fields = getattr(message, "provider_specific_fields", None) - content = getattr(message, "content", None) - - if isinstance(provider_fields, dict): - cleaned_fields = dict(provider_fields) - if cleaned_fields.pop("thinking_blocks", None) is not None: - stripped += 1 - if cleaned_fields.pop("reasoning_content", None) is not None: - stripped += 1 - if cleaned_fields != provider_fields: - if isinstance(message, dict): - message["provider_specific_fields"] = cleaned_fields - else: - message.provider_specific_fields = cleaned_fields - - if isinstance(content, list): - cleaned_content = [ - block - for block in content - if not ( - isinstance(block, dict) - and block.get("type") in {"thinking", "redacted_thinking"} - ) - ] - if len(cleaned_content) != len(content): - stripped += len(content) - len(cleaned_content) - if isinstance(message, dict): - message["content"] = cleaned_content - else: - message.content = cleaned_content - - return stripped - - -async def _maybe_heal_invalid_thinking_signature( - session: Session, - messages: list[Any], - exc: Exception, - *, - already_healed: bool, -) -> bool: - if already_healed or not _is_invalid_thinking_signature_error(exc): - return False - - stripped = _strip_thinking_state_from_messages(messages) - if not stripped: - return False - - await session.send_event( - Event( - event_type="tool_log", - data={ - "tool": "system", - "log": ( - "Anthropic rejected stale thinking signatures; retrying " - "without replayed thinking metadata." - ), - }, - ) - ) - return True - - -def _assistant_message_from_result( - llm_result: LLMResult, - *, - model_name: str | None, - tool_calls: list[ToolCall] | None = None, -) -> Message: - """Build an assistant history message without dropping reasoning state.""" - kwargs: dict[str, Any] = { - "role": "assistant", - "content": llm_result.content, - } - if tool_calls is not None: - kwargs["tool_calls"] = tool_calls - if _should_replay_thinking_state(model_name): - if llm_result.thinking_blocks: - kwargs["thinking_blocks"] = llm_result.thinking_blocks - if llm_result.reasoning_content: - kwargs["reasoning_content"] = llm_result.reasoning_content - return Message(**kwargs) - - -async def _call_llm_streaming( - session: Session, messages, tools, llm_params -) -> LLMResult: +async def _call_llm_streaming(session: Session, messages, tools, llm_params) -> LLMResult: """Call the LLM with streaming, emitting assistant_chunk events.""" response = None - _healed_effort = False # one-shot safety net per call - _healed_thinking_signature = False - messages, tools = with_prompt_caching(messages, tools, llm_params.get("model")) - t_start = time.monotonic() for _llm_attempt in range(_MAX_LLM_RETRIES): try: response = await acompletion( @@ -831,49 +250,16 @@ async def _call_llm_streaming( except ContextWindowExceededError: raise except Exception as e: - if _is_context_overflow_error(e): - raise ContextWindowExceededError(str(e)) from e - if not _healed_effort and _is_effort_config_error(e): - _healed_effort = True - llm_params = await _heal_effort_and_rebuild_params( - session, e, llm_params - ) - await session.send_event( - Event( - event_type="tool_log", - data={ - "tool": "system", - "log": "Reasoning effort not supported for this model β€” adjusting and retrying.", - }, - ) - ) - continue - if await _maybe_heal_invalid_thinking_signature( - session, - messages, - e, - 
already_healed=_healed_thinking_signature, - ): - _healed_thinking_signature = True - continue - _delay = _retry_delay_for(e, _llm_attempt) - if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None: + if _llm_attempt < _MAX_LLM_RETRIES - 1 and _is_transient_error(e): + _delay = _LLM_RETRY_DELAYS[_llm_attempt] logger.warning( "Transient LLM error (attempt %d/%d): %s β€” retrying in %ds", - _llm_attempt + 1, - _MAX_LLM_RETRIES, - e, - _delay, - ) - await session.send_event( - Event( - event_type="tool_log", - data={ - "tool": "system", - "log": f"LLM connection error, retrying in {_delay}s...", - }, - ) + _llm_attempt + 1, _MAX_LLM_RETRIES, e, _delay, ) + await session.send_event(Event( + event_type="tool_log", + data={"tool": "system", "log": f"LLM connection error, retrying in {_delay}s..."}, + )) await asyncio.sleep(_delay) continue raise @@ -882,12 +268,8 @@ async def _call_llm_streaming( tool_calls_acc: dict[int, dict] = {} token_count = 0 finish_reason = None - final_usage_chunk = None - chunks = [] - should_replay_thinking = _should_replay_thinking_state(llm_params.get("model")) async for chunk in response: - chunks.append(chunk) if session.is_cancelled: tool_calls_acc.clear() break @@ -896,7 +278,6 @@ async def _call_llm_streaming( if not choice: if hasattr(chunk, "usage") and chunk.usage: token_count = chunk.usage.total_tokens - final_usage_chunk = chunk continue delta = choice.delta @@ -914,66 +295,31 @@ async def _call_llm_streaming( idx = tc_delta.index if idx not in tool_calls_acc: tool_calls_acc[idx] = { - "id": "", - "type": "function", + "id": "", "type": "function", "function": {"name": "", "arguments": ""}, } if tc_delta.id: tool_calls_acc[idx]["id"] = tc_delta.id if tc_delta.function: if tc_delta.function.name: - tool_calls_acc[idx]["function"]["name"] += ( - tc_delta.function.name - ) + tool_calls_acc[idx]["function"]["name"] += tc_delta.function.name if tc_delta.function.arguments: - tool_calls_acc[idx]["function"]["arguments"] += ( - tc_delta.function.arguments - ) + tool_calls_acc[idx]["function"]["arguments"] += tc_delta.function.arguments if hasattr(chunk, "usage") and chunk.usage: token_count = chunk.usage.total_tokens - final_usage_chunk = chunk - - usage = await telemetry.record_llm_call( - session, - model=llm_params.get("model", session.config.model_name), - response=final_usage_chunk, - latency_ms=int((time.monotonic() - t_start) * 1000), - finish_reason=finish_reason, - ) - thinking_blocks = None - reasoning_content = None - if chunks and should_replay_thinking: - try: - rebuilt = stream_chunk_builder(chunks, messages=messages) - if rebuilt and getattr(rebuilt, "choices", None): - rebuilt_msg = rebuilt.choices[0].message - thinking_blocks, reasoning_content = _extract_thinking_state( - rebuilt_msg - ) - except Exception: - logger.debug("Failed to rebuild streaming thinking state", exc_info=True) return LLMResult( content=full_content or None, tool_calls_acc=tool_calls_acc, token_count=token_count, finish_reason=finish_reason, - usage=usage, - thinking_blocks=thinking_blocks, - reasoning_content=reasoning_content, ) -async def _call_llm_non_streaming( - session: Session, messages, tools, llm_params -) -> LLMResult: +async def _call_llm_non_streaming(session: Session, messages, tools, llm_params) -> LLMResult: """Call the LLM without streaming, emit assistant_message at the end.""" response = None - _healed_effort = False - _healed_thinking_signature = False - messages, tools = with_prompt_caching(messages, tools, llm_params.get("model")) - t_start = 
time.monotonic() for _llm_attempt in range(_MAX_LLM_RETRIES): try: response = await acompletion( @@ -988,49 +334,16 @@ async def _call_llm_non_streaming( except ContextWindowExceededError: raise except Exception as e: - if _is_context_overflow_error(e): - raise ContextWindowExceededError(str(e)) from e - if not _healed_effort and _is_effort_config_error(e): - _healed_effort = True - llm_params = await _heal_effort_and_rebuild_params( - session, e, llm_params - ) - await session.send_event( - Event( - event_type="tool_log", - data={ - "tool": "system", - "log": "Reasoning effort not supported for this model β€” adjusting and retrying.", - }, - ) - ) - continue - if await _maybe_heal_invalid_thinking_signature( - session, - messages, - e, - already_healed=_healed_thinking_signature, - ): - _healed_thinking_signature = True - continue - _delay = _retry_delay_for(e, _llm_attempt) - if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None: + if _llm_attempt < _MAX_LLM_RETRIES - 1 and _is_transient_error(e): + _delay = _LLM_RETRY_DELAYS[_llm_attempt] logger.warning( "Transient LLM error (attempt %d/%d): %s β€” retrying in %ds", - _llm_attempt + 1, - _MAX_LLM_RETRIES, - e, - _delay, - ) - await session.send_event( - Event( - event_type="tool_log", - data={ - "tool": "system", - "log": f"LLM connection error, retrying in {_delay}s...", - }, - ) + _llm_attempt + 1, _MAX_LLM_RETRIES, e, _delay, ) + await session.send_event(Event( + event_type="tool_log", + data={"tool": "system", "log": f"LLM connection error, retrying in {_delay}s..."}, + )) await asyncio.sleep(_delay) continue raise @@ -1040,7 +353,6 @@ async def _call_llm_non_streaming( content = message.content or None finish_reason = choice.finish_reason token_count = response.usage.total_tokens if response.usage else 0 - thinking_blocks, reasoning_content = _extract_thinking_state(message) # Build tool_calls_acc in the same format as streaming tool_calls_acc: dict[int, dict] = {} @@ -1061,22 +373,11 @@ async def _call_llm_non_streaming( Event(event_type="assistant_message", data={"content": content}) ) - usage = await telemetry.record_llm_call( - session, - model=llm_params.get("model", session.config.model_name), - response=response, - latency_ms=int((time.monotonic() - t_start) * 1000), - finish_reason=finish_reason, - ) - return LLMResult( content=content, tool_calls_acc=tool_calls_acc, token_count=token_count, finish_reason=finish_reason, - usage=usage, - thinking_blocks=thinking_blocks, - reasoning_content=reasoning_content, ) @@ -1123,8 +424,7 @@ class Handlers: @staticmethod async def run_agent( - session: Session, - text: str, + session: Session, text: str, ) -> str | None: """ Handle user input (like user_input_or_turn in codex.rs:1291) @@ -1159,15 +459,8 @@ class Handlers: if session.is_cancelled: break - # Compact before calling the LLM if context is near the limit. - # When _compact_and_notify catches CompactionFailedError it sets - # session.is_running = False; we MUST exit the loop here, otherwise - # the LLM call below fires with an over-threshold context, hits - # ContextWindowExceededError, and we end up looping again on the - # except path β€” exactly the bug this PR is supposed to fix. 
+ # Compact before calling the LLM if context is near the limit await _compact_and_notify(session) - if not session.is_running: - break # Doom-loop detection: break out of repeated tool call patterns doom_prompt = check_for_doom_loop(session.context_manager.items) @@ -1175,28 +468,12 @@ class Handlers: session.context_manager.add_message( Message(role="user", content=doom_prompt) ) - - malformed_tool = _detect_repeated_malformed(session.context_manager.items) - if malformed_tool: - recovery_prompt = ( - "[SYSTEM: Repeated malformed tool arguments detected for " - f"'{malformed_tool}'. Stop retrying the same tool call shape. " - "Use a different strategy that produces smaller, valid JSON. " - "For large file writes, prefer bash with a heredoc or split the " - "edit into multiple smaller tool calls.]" - ) - session.context_manager.add_message( - Message(role="user", content=recovery_prompt) - ) await session.send_event( Event( event_type="tool_log", data={ "tool": "system", - "log": ( - "Repeated malformed tool arguments detected β€” " - f"forcing a different strategy for {malformed_tool}" - ), + "log": "Doom loop detected β€” injecting corrective prompt", }, ) ) @@ -1205,24 +482,11 @@ class Handlers: tools = session.tool_router.get_tool_specs_for_llm() try: # ── Call the LLM (streaming or non-streaming) ── - # Pull the per-model probed effort from the session cache when - # available; fall back to the raw preference for models we - # haven't probed yet (e.g. research sub-model). - llm_params = _resolve_llm_params( - session.config.model_name, - session.hf_token, - reasoning_effort=session.effective_effort_for( - session.config.model_name - ), - ) + llm_params = _resolve_hf_router_params(session.config.model_name) if session.stream: - llm_result = await _call_llm_streaming( - session, messages, tools, llm_params - ) + llm_result = await _call_llm_streaming(session, messages, tools, llm_params) else: - llm_result = await _call_llm_non_streaming( - session, messages, tools, llm_params - ) + llm_result = await _call_llm_non_streaming(session, messages, tools, llm_params) content = llm_result.content tool_calls_acc = llm_result.tool_calls_acc @@ -1254,10 +518,7 @@ class Handlers: " β€’ For other tools: reduce the size of your arguments or use bash." ) if content: - assistant_msg = _assistant_message_from_result( - llm_result, - model_name=llm_params.get("model"), - ) + assistant_msg = Message(role="assistant", content=content) session.context_manager.add_message(assistant_msg, token_count) session.context_manager.add_message( Message(role="user", content=f"[SYSTEM: {truncation_hint}]") @@ -1269,10 +530,7 @@ class Handlers: await session.send_event( Event( event_type="tool_log", - data={ - "tool": "system", - "log": f"Output truncated β€” retrying with smaller content ({dropped_names})", - }, + data={"tool": "system", "log": f"Output truncated β€” retrying with smaller content ({dropped_names})"}, ) ) iteration += 1 @@ -1301,25 +559,36 @@ class Handlers: # If no tool calls, add assistant message and we're done if not tool_calls: - logger.debug( + logger.warning( "Agent loop ending: no tool calls. 
" "finish_reason=%s, token_count=%d, " - "usage=%d, model_max_tokens=%d, " + "context_length=%d, max_context=%d, " "iteration=%d/%d, " "response_text=%s", finish_reason, token_count, - session.context_manager.running_context_usage, - session.context_manager.model_max_tokens, + session.context_manager.context_length, + session.context_manager.max_context, iteration, max_iterations, (content or "")[:500], ) - if content: - assistant_msg = _assistant_message_from_result( - llm_result, - model_name=llm_params.get("model"), + await session.send_event( + Event( + event_type="tool_log", + data={ + "tool": "system", + "log": ( + f"Loop exit: no tool calls. " + f"finish_reason={finish_reason}, " + f"tokens={token_count}/{session.context_manager.max_context}, " + f"iter={iteration}/{max_iterations}" + ), + }, ) + ) + if content: + assistant_msg = Message(role="assistant", content=content) session.context_manager.add_message(assistant_msg, token_count) final_response = content break @@ -1335,16 +604,15 @@ class Handlers: except (json.JSONDecodeError, TypeError, ValueError): logger.warning( "Malformed arguments for tool_call %s (%s) β€” skipping", - tc.id, - tc.function.name, + tc.id, tc.function.name, ) tc.function.arguments = "{}" bad_tools.append(tc) # Add assistant message with all tool calls to context - assistant_msg = _assistant_message_from_result( - llm_result, - model_name=llm_params.get("model"), + assistant_msg = Message( + role="assistant", + content=content, tool_calls=tool_calls, ) session.context_manager.add_message(assistant_msg, token_count) @@ -1357,92 +625,48 @@ class Handlers: f"arguments and was NOT executed. Retry with smaller content β€” " f"for 'write', split into multiple smaller writes using 'edit'." ) - session.context_manager.add_message( - Message( - role="tool", - content=error_msg, - tool_call_id=tc.id, - name=tc.function.name, - ) - ) - await session.send_event( - Event( - event_type="tool_call", - data={ - "tool": tc.function.name, - "arguments": {}, - "tool_call_id": tc.id, - }, - ) - ) - await session.send_event( - Event( - event_type="tool_output", - data={ - "tool": tc.function.name, - "tool_call_id": tc.id, - "output": error_msg, - "success": False, - }, - ) - ) + session.context_manager.add_message(Message( + role="tool", + content=error_msg, + tool_call_id=tc.id, + name=tc.function.name, + )) + await session.send_event(Event( + event_type="tool_call", + data={"tool": tc.function.name, "arguments": {}, "tool_call_id": tc.id}, + )) + await session.send_event(Event( + event_type="tool_output", + data={"tool": tc.function.name, "tool_call_id": tc.id, "output": error_msg, "success": False}, + )) # ── Cancellation check: before tool execution ── if session.is_cancelled: break - # Separate good tools into approval-required vs auto-execute. - # Track reserved spend while classifying a batch so two - # auto-approved jobs in one model response cannot jointly - # exceed the remaining session cap. 
- approval_required_tools: list[ - tuple[ToolCall, str, dict, ApprovalDecision] - ] = [] - non_approval_tools: list[ - tuple[ToolCall, str, dict, ApprovalDecision] - ] = [] - reserved_auto_spend_usd = 0.0 + # Separate good tools into approval-required vs auto-execute + approval_required_tools: list[tuple[ToolCall, str, dict]] = [] + non_approval_tools: list[tuple[ToolCall, str, dict]] = [] for tc, tool_name, tool_args in good_tools: - decision = await _approval_decision( - tool_name, - tool_args, - session, - reserved_spend_usd=reserved_auto_spend_usd, - ) - if decision.requires_approval: - approval_required_tools.append( - (tc, tool_name, tool_args, decision) - ) + if _needs_approval(tool_name, tool_args, session.config): + approval_required_tools.append((tc, tool_name, tool_args)) else: - non_approval_tools.append((tc, tool_name, tool_args, decision)) - if ( - decision.auto_approved - and decision.billable - and decision.estimated_cost_usd is not None - ): - reserved_auto_spend_usd += decision.estimated_cost_usd + non_approval_tools.append((tc, tool_name, tool_args)) # Execute non-approval tools (in parallel when possible) if non_approval_tools: # 1. Validate args upfront parsed_tools: list[ - tuple[ToolCall, str, dict, ApprovalDecision, bool, str] + tuple[ToolCall, str, dict, bool, str] ] = [] - for tc, tool_name, tool_args, decision in non_approval_tools: + for tc, tool_name, tool_args in non_approval_tools: args_valid, error_msg = _validate_tool_args(tool_args) parsed_tools.append( - (tc, tool_name, tool_args, decision, args_valid, error_msg) + (tc, tool_name, tool_args, args_valid, error_msg) ) # 2. Send all tool_call events upfront (so frontend shows them all) - for ( - tc, - tool_name, - tool_args, - _decision, - args_valid, - _, - ) in parsed_tools: + for tc, tool_name, tool_args, args_valid, _ in parsed_tools: if args_valid: await session.send_event( Event( @@ -1460,27 +684,22 @@ class Handlers: tc: ToolCall, name: str, args: dict, - decision: ApprovalDecision, valid: bool, err: str, ) -> tuple[ToolCall, str, dict, str, bool]: if not valid: return (tc, name, args, err, False) - if decision.billable: - _record_estimated_spend(session, decision) out, ok = await session.tool_router.call_tool( - name, args, session=session, tool_call_id=tc.id + name, args, session=session ) return (tc, name, args, out, ok) - gather_task = asyncio.ensure_future( - asyncio.gather( - *[ - _exec_tool(tc, name, args, decision, valid, err) - for tc, name, args, decision, valid, err in parsed_tools - ] - ) - ) + gather_task = asyncio.ensure_future(asyncio.gather( + *[ + _exec_tool(tc, name, args, valid, err) + for tc, name, args, valid, err in parsed_tools + ] + )) cancel_task = asyncio.ensure_future(session._cancelled.wait()) done, _ = await asyncio.wait( @@ -1495,18 +714,12 @@ class Handlers: except asyncio.CancelledError: pass # Notify frontend that in-flight tools were cancelled - for tc, name, _args, _decision, valid, _ in parsed_tools: + for tc, name, _args, valid, _ in parsed_tools: if valid: - await session.send_event( - Event( - event_type="tool_state_change", - data={ - "tool_call_id": tc.id, - "tool": name, - "state": "cancelled", - }, - ) - ) + await session.send_event(Event( + event_type="tool_state_change", + data={"tool_call_id": tc.id, "tool": name, "state": "cancelled"}, + )) await _cleanup_on_cancel(session) break @@ -1539,60 +752,30 @@ class Handlers: if approval_required_tools: # Prepare batch approval data tools_data = [] - blocked_payloads = [] - for tc, tool_name, tool_args, decision 
in approval_required_tools: + for tc, tool_name, tool_args in approval_required_tools: # Resolve sandbox file paths for hf_jobs scripts so the # frontend can display & edit the actual file content. - if tool_name == "hf_jobs" and isinstance( - tool_args.get("script"), str - ): + if tool_name == "hf_jobs" and isinstance(tool_args.get("script"), str): from agent.tools.sandbox_tool import resolve_sandbox_script - sandbox = getattr(session, "sandbox", None) - resolved, _ = await resolve_sandbox_script( - sandbox, tool_args["script"] - ) + resolved, _ = await resolve_sandbox_script(sandbox, tool_args["script"]) if resolved: tool_args = {**tool_args, "script": resolved} - tool_payload = { + tools_data.append({ "tool": tool_name, "arguments": tool_args, "tool_call_id": tc.id, - } - if decision.auto_approval_blocked: - tool_payload.update( - { - "auto_approval_blocked": True, - "block_reason": decision.block_reason, - "estimated_cost_usd": decision.estimated_cost_usd, - "remaining_cap_usd": decision.remaining_cap_usd, - } - ) - blocked_payloads.append(tool_payload) - tools_data.append(tool_payload) - - event_data = {"tools": tools_data, "count": len(tools_data)} - if blocked_payloads: - first = blocked_payloads[0] - event_data.update( - { - "auto_approval_blocked": True, - "block_reason": first.get("block_reason"), - "estimated_cost_usd": first.get("estimated_cost_usd"), - "remaining_cap_usd": first.get("remaining_cap_usd"), - } - ) - await session.send_event( - Event( - event_type="approval_required", - data=event_data, - ) - ) + }) + + await session.send_event(Event( + event_type="approval_required", + data={"tools": tools_data, "count": len(tools_data)}, + )) # Store all approval-requiring tools (ToolCall objects for execution) session.pending_approval = { - "tool_calls": [tc for tc, _, _, _ in approval_required_tools], + "tool_calls": [tc for tc, _, _ in approval_required_tools], } # Return early - wait for EXEC_APPROVAL operation @@ -1601,37 +784,28 @@ class Handlers: iteration += 1 except ContextWindowExceededError: - # Force compact and retry this iteration. - cm = session.context_manager + # Force compact and retry this iteration logger.warning( "ContextWindowExceededError at iteration %d β€” forcing compaction " - "(usage=%d, model_max_tokens=%d, messages=%d)", + "(context_length=%d, max_context=%d, messages=%d)", iteration, - cm.running_context_usage, - cm.model_max_tokens, - len(cm.items), + session.context_manager.context_length, + session.context_manager.max_context, + len(session.context_manager.items), + ) + session.context_manager.context_length = ( + session.context_manager.max_context + 1 ) - cm.running_context_usage = cm.model_max_tokens + 1 await _compact_and_notify(session) - # Same guard as the top of the loop: if compaction couldn't - # bring us under threshold, _compact_and_notify has already - # emitted session_terminated and set is_running=False. Continue - # would just re-call the LLM with the same too-big context. 
- if not session.is_running: - break continue except Exception as e: import traceback - error_msg = _friendly_error_message(e) - if error_msg is None: - error_msg = str(e) + "\n" + traceback.format_exc() - await session.send_event( Event( event_type="error", - data={"error": error_msg}, + data={"error": str(e) + "\n" + traceback.format_exc()}, ) ) errored = True @@ -1644,12 +818,7 @@ class Handlers: await session.send_event( Event( event_type="turn_complete", - data={ - "history_size": len(session.context_manager.items), - "final_response": final_response - if isinstance(final_response, str) - else None, - }, + data={"history_size": len(session.context_manager.items)}, ) ) @@ -1737,9 +906,6 @@ class Handlers: tool_args["script"] = edited_script was_edited = True logger.info(f"Using user-edited script for {tool_name} ({tc.id})") - selected_namespace = approval_decision.get("namespace") - if selected_namespace and tool_name == "hf_jobs": - tool_args["namespace"] = selected_namespace approved_tasks.append((tc, tool_name, tool_args, was_edited)) else: rejected_tasks.append((tc, tool_name, approval_decision)) @@ -1791,8 +957,6 @@ class Handlers: ) ) - await _record_manual_approved_spend_if_needed(session, tool_name, tool_args) - output, success = await session.tool_router.call_tool( tool_name, tool_args, session=session, tool_call_id=tc.id ) @@ -1801,15 +965,13 @@ class Handlers: # Execute all approved tools concurrently (cancellable) if approved_tasks: - gather_task = asyncio.ensure_future( - asyncio.gather( - *[ - execute_tool(tc, tool_name, tool_args, was_edited) - for tc, tool_name, tool_args, was_edited in approved_tasks - ], - return_exceptions=True, - ) - ) + gather_task = asyncio.ensure_future(asyncio.gather( + *[ + execute_tool(tc, tool_name, tool_args, was_edited) + for tc, tool_name, tool_args, was_edited in approved_tasks + ], + return_exceptions=True, + )) cancel_task = asyncio.ensure_future(session._cancelled.wait()) done, _ = await asyncio.wait( @@ -1825,16 +987,10 @@ class Handlers: pass # Notify frontend that approved tools were cancelled for tc, tool_name, _args, _was_edited in approved_tasks: - await session.send_event( - Event( - event_type="tool_state_change", - data={ - "tool_call_id": tc.id, - "tool": tool_name, - "state": "cancelled", - }, - ) - ) + await session.send_event(Event( + event_type="tool_state_change", + data={"tool_call_id": tc.id, "tool": tool_name, "state": "cancelled"}, + )) await _cleanup_on_cancel(session) await session.send_event(Event(event_type="interrupted")) session.increment_turn() @@ -1968,16 +1124,12 @@ async def process_submission(session: Session, submission) -> bool: async def submission_loop( submission_queue: asyncio.Queue, event_queue: asyncio.Queue, - config: Config, + config: Config | None = None, tool_router: ToolRouter | None = None, session_holder: list | None = None, hf_token: str | None = None, - user_id: str | None = None, local_mode: bool = False, stream: bool = True, - notification_gateway: NotificationGateway | None = None, - notification_destinations: list[str] | None = None, - defer_turn_complete_notification: bool = False, ) -> None: """ Main agent loop - processes submissions and dispatches to handlers. 
@@ -1986,30 +1138,17 @@ async def submission_loop( # Create session with tool router session = Session( - event_queue, - config=config, - tool_router=tool_router, - hf_token=hf_token, - user_id=user_id, - local_mode=local_mode, - stream=stream, - notification_gateway=notification_gateway, - notification_destinations=notification_destinations, - defer_turn_complete_notification=defer_turn_complete_notification, + event_queue, config=config, tool_router=tool_router, hf_token=hf_token, + local_mode=local_mode, stream=stream, ) if session_holder is not None: session_holder[0] = session - start_session_artifact_collection_task(session, token=hf_token) logger.info("Agent loop started") - # Retry any failed uploads from previous sessions (fire-and-forget). - # Includes the personal trace repo when enabled so a session that failed - # to publish to the user's HF dataset gets a fresh attempt on next run. + # Retry any failed uploads from previous sessions (fire-and-forget) if config and config.save_sessions: Session.retry_failed_uploads_detached( - directory="session_logs", - repo_id=config.session_dataset_repo, - personal_repo_id=session._personal_trace_repo_id(), + directory="session_logs", repo_id=config.session_dataset_repo ) try: @@ -2017,13 +1156,7 @@ async def submission_loop( async with tool_router: # Emit ready event after initialization await session.send_event( - Event( - event_type="ready", - data={ - "message": "Agent initialized", - "tool_count": len(tool_router.tools), - }, - ) + Event(event_type="ready", data={"message": "Agent initialized"}) ) while session.is_running: diff --git a/agent/core/approval_policy.py b/agent/core/approval_policy.py deleted file mode 100644 index 73098ca61dffca66929984bd5b5c34e532106f18..0000000000000000000000000000000000000000 --- a/agent/core/approval_policy.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Shared predicates for approval-gated tool operations.""" - -from typing import Any - - -def normalize_tool_operation(operation: Any) -> str: - return str(operation or "").strip().lower() - - -def is_scheduled_operation(operation: Any) -> bool: - return normalize_tool_operation(operation).startswith("scheduled ") diff --git a/agent/core/cost_estimation.py b/agent/core/cost_estimation.py deleted file mode 100644 index a41ad196efec7495c7ca9141d2f7f3a4f38e6dbd..0000000000000000000000000000000000000000 --- a/agent/core/cost_estimation.py +++ /dev/null @@ -1,282 +0,0 @@ -"""Conservative cost estimates for auto-approved infrastructure actions.""" - -import os -import re -import time -from dataclasses import dataclass -from typing import Any - -import httpx - -OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co") -JOBS_HARDWARE_URL = f"{OPENID_PROVIDER_URL}/api/jobs/hardware" -JOBS_PRICE_CACHE_TTL_S = 6 * 60 * 60 - -DEFAULT_JOB_TIMEOUT_HOURS = 0.5 -DEFAULT_SANDBOX_RESERVATION_HOURS = 1.0 - -# Static fallback prices are intentionally conservative enough for a budget -# guard. The live /api/jobs/hardware catalog wins whenever it is reachable. 
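The comment above summarizes the pricing strategy: static fallback prices act as a floor, the live `/api/jobs/hardware` catalog overrides them whenever it is reachable, and a job is billed for its whole timeout window. A small sketch of that merge-then-multiply logic; the prices here are made up, not real HF Jobs rates:

```python
# Conservative budget-guard pricing: live catalog entries win, static
# entries survive as a safety net, unknown hardware defers to a human.
STATIC_PRICES = {"cpu-basic": 0.05, "t4-small": 0.60}


def effective_prices(live: dict[str, float]) -> dict[str, float]:
    return {**STATIC_PRICES, **live}


def estimate_job_cost(
    flavor: str, timeout_hours: float, live: dict[str, float]
) -> float | None:
    price = effective_prices(live).get(flavor)
    if price is None:
        return None  # no price known: fall back to a human decision
    return round(price * timeout_hours, 4)


assert estimate_job_cost("t4-small", 0.5, {}) == 0.30
assert estimate_job_cost("t4-small", 0.5, {"t4-small": 0.80}) == 0.40
assert estimate_job_cost("h100x8", 1.0, {}) is None
```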
-HF_JOBS_PRICE_USD_PER_HOUR: dict[str, float] = { - "cpu-basic": 0.05, - "cpu-upgrade": 0.25, - "cpu-performance": 0.50, - "cpu-xl": 1.00, - "t4-small": 0.60, - "t4-medium": 0.90, - "l4x1": 1.00, - "l4x4": 4.00, - "l40sx1": 2.00, - "l40sx4": 8.00, - "l40sx8": 16.00, - "a10g-small": 1.00, - "a10g-large": 2.00, - "a10g-largex2": 4.00, - "a10g-largex4": 8.00, - "a100-large": 4.00, - "a100x4": 16.00, - "a100x8": 32.00, - "h200": 10.00, - "h200x2": 20.00, - "h200x4": 40.00, - "h200x8": 80.00, - "inf2x6": 6.00, -} - -SPACE_PRICE_USD_PER_HOUR: dict[str, float] = { - "cpu-basic": 0.0, - "cpu-upgrade": 0.05, - "cpu-performance": 0.50, - "cpu-xl": 1.00, - "t4-small": 0.60, - "t4-medium": 0.90, - "l4x1": 1.00, - "l4x4": 4.00, - "l40sx1": 2.00, - "l40sx4": 8.00, - "l40sx8": 16.00, - "a10g-small": 1.00, - "a10g-large": 2.00, - "a10g-largex2": 4.00, - "a10g-largex4": 8.00, - "a100-large": 4.00, - "a100x4": 16.00, - "a100x8": 32.00, - "h200": 10.00, - "h200x2": 20.00, - "h200x4": 40.00, - "h200x8": 80.00, - "inf2x6": 6.00, -} - -_DURATION_RE = re.compile(r"^\s*(\d+(?:\.\d+)?)\s*([smhd]?)\s*$", re.IGNORECASE) -_PRICE_RE = re.compile(r"(\d+(?:\.\d+)?)") -_jobs_price_cache: tuple[float, dict[str, float]] | None = None - - -@dataclass(frozen=True) -class CostEstimate: - """Estimated cost for a tool call. - - ``estimated_cost_usd=None`` means the call may be billable but we could not - estimate it safely, so auto-approval should fall back to a human decision. - """ - - estimated_cost_usd: float | None - billable: bool - block_reason: str | None = None - label: str | None = None - - -def parse_timeout_hours( - value: Any, *, default_hours: float = DEFAULT_JOB_TIMEOUT_HOURS -) -> float | None: - """Parse HF timeout values into hours. - - Strings accept ``s``, ``m``, ``h``, or ``d`` suffixes. Numeric values are - treated as seconds, matching the Hub client's typed timeout parameter. 
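Before the implementation, a miniature re-creation of the documented parsing rules may help: bare numbers are seconds, and `s`/`m`/`h`/`d` suffixes scale accordingly. The helper below mirrors that behaviour in a few lines; the real parser additionally handles booleans, empty values, and defaults:

```python
import re

_DUR = re.compile(r"^\s*(\d+(?:\.\d+)?)\s*([smhd]?)\s*$", re.IGNORECASE)


def hours(value: int | float | str) -> float | None:
    # Bare numbers are seconds, matching the Hub client's typed timeout.
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        return value / 3600 if value > 0 else None
    match = _DUR.match(value)
    if not match:
        return None
    amount, unit = float(match.group(1)), (match.group(2).lower() or "s")
    if amount <= 0:
        return None
    return {"s": amount / 3600, "m": amount / 60, "h": amount, "d": amount * 24}[unit]


assert hours(1800) == 0.5   # plain seconds
assert hours("30m") == 0.5  # minutes suffix
assert hours("2h") == 2.0
assert hours("1d") == 24.0
assert hours("soon") is None
```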
- """ - if value is None or value == "": - return default_hours - if isinstance(value, bool): - return None - if isinstance(value, int | float): - seconds = float(value) - return seconds / 3600 if seconds > 0 else None - if not isinstance(value, str): - return None - - match = _DURATION_RE.match(value) - if not match: - return None - amount = float(match.group(1)) - unit = match.group(2).lower() or "s" - if amount <= 0: - return None - if unit == "s": - return amount / 3600 - if unit == "m": - return amount / 60 - if unit == "h": - return amount - if unit == "d": - return amount * 24 - return None - - -def _extract_flavor(item: dict[str, Any]) -> str | None: - for key in ("flavor", "name", "id", "value", "hardware", "hardware_flavor"): - value = item.get(key) - if isinstance(value, str) and value: - return value - return None - - -def _coerce_price(value: Any) -> float | None: - if isinstance(value, bool) or value is None: - return None - if isinstance(value, int | float): - return float(value) if value >= 0 else None - if isinstance(value, str): - match = _PRICE_RE.search(value.replace(",", "")) - if match: - return float(match.group(1)) - return None - - -def _extract_hourly_price(item: dict[str, Any]) -> float | None: - for key in ( - "price", - "price_usd", - "priceUsd", - "price_per_hour", - "pricePerHour", - "hourly_price", - "hourlyPrice", - "usd_per_hour", - "usdPerHour", - ): - price = _coerce_price(item.get(key)) - if price is not None: - return price - for key in ("pricing", "billing", "cost"): - nested = item.get(key) - if isinstance(nested, dict): - price = _extract_hourly_price(nested) - if price is not None: - return price - return None - - -def _iter_hardware_items(payload: Any): - if isinstance(payload, list): - for item in payload: - yield from _iter_hardware_items(item) - elif isinstance(payload, dict): - if _extract_flavor(payload): - yield payload - for key in ("hardware", "flavors", "items", "data", "jobs"): - child = payload.get(key) - if child is not None: - yield from _iter_hardware_items(child) - - -def _parse_jobs_price_catalog(payload: Any) -> dict[str, float]: - prices: dict[str, float] = {} - for item in _iter_hardware_items(payload): - flavor = _extract_flavor(item) - price = _extract_hourly_price(item) - if flavor and price is not None: - prices[flavor] = price - return prices - - -async def hf_jobs_price_catalog() -> dict[str, float]: - """Return live HF Jobs hourly prices, falling back to static prices.""" - global _jobs_price_cache - now = time.monotonic() - if _jobs_price_cache and now - _jobs_price_cache[0] < JOBS_PRICE_CACHE_TTL_S: - return dict(_jobs_price_cache[1]) - - prices: dict[str, float] = {} - try: - async with httpx.AsyncClient(timeout=3.0) as client: - response = await client.get(JOBS_HARDWARE_URL) - if response.status_code == 200: - prices = _parse_jobs_price_catalog(response.json()) - except (httpx.HTTPError, ValueError): - prices = {} - - if not prices: - prices = dict(HF_JOBS_PRICE_USD_PER_HOUR) - else: - prices = {**HF_JOBS_PRICE_USD_PER_HOUR, **prices} - - _jobs_price_cache = (now, prices) - return dict(prices) - - -async def estimate_hf_job_cost(args: dict[str, Any]) -> CostEstimate: - flavor = str( - args.get("hardware_flavor") - or args.get("flavor") - or args.get("hardware") - or "cpu-basic" - ) - timeout_hours = parse_timeout_hours(args.get("timeout")) - if timeout_hours is None: - return CostEstimate( - estimated_cost_usd=None, - billable=True, - block_reason=f"Could not parse HF job timeout: {args.get('timeout')!r}.", - 
label=flavor, - ) - - prices = await hf_jobs_price_catalog() - price = prices.get(flavor) - if price is None: - return CostEstimate( - estimated_cost_usd=None, - billable=True, - block_reason=f"No price is available for HF job hardware '{flavor}'.", - label=flavor, - ) - - return CostEstimate( - estimated_cost_usd=round(price * timeout_hours, 4), - billable=price > 0, - label=flavor, - ) - - -async def estimate_sandbox_cost( - args: dict[str, Any], *, session: Any = None -) -> CostEstimate: - if session is not None and getattr(session, "sandbox", None): - return CostEstimate(estimated_cost_usd=0.0, billable=False, label="existing") - - hardware = str(args.get("hardware") or "cpu-basic") - price = SPACE_PRICE_USD_PER_HOUR.get(hardware) - if price is None: - return CostEstimate( - estimated_cost_usd=None, - billable=True, - block_reason=f"No price is available for sandbox hardware '{hardware}'.", - label=hardware, - ) - - return CostEstimate( - estimated_cost_usd=round(price * DEFAULT_SANDBOX_RESERVATION_HOURS, 4), - billable=price > 0, - label=hardware, - ) - - -async def estimate_tool_cost( - tool_name: str, args: dict[str, Any], *, session: Any = None -) -> CostEstimate: - if tool_name == "sandbox_create": - return await estimate_sandbox_cost(args, session=session) - if tool_name == "hf_jobs": - return await estimate_hf_job_cost(args) - return CostEstimate(estimated_cost_usd=0.0, billable=False) diff --git a/agent/core/doom_loop.py b/agent/core/doom_loop.py index 3b57fe2cc3cffd07b466db9ac98cc0d0b665de79..5050d75509ba3a46bd7d0106a1e0f735d6dbda62 100644 --- a/agent/core/doom_loop.py +++ b/agent/core/doom_loop.py @@ -17,58 +17,25 @@ logger = logging.getLogger(__name__) @dataclass(frozen=True) class ToolCallSignature: - """Hashable signature for a single tool call plus its observed result.""" + """Hashable signature for a single tool call (name + args hash).""" name: str args_hash: str - result_hash: str | None = None - - -def _normalize_args(args_str: str) -> str: - """Canonicalise a tool-call arguments string before hashing. - - LLMs can emit semantically-identical JSON for the same call with different - key orderings (``{"a": 1, "b": 2}`` vs ``{"b": 2, "a": 1}``) or whitespace - (``{"a":1}`` vs ``{"a": 1}``). Hashing the raw bytes makes the doom-loop - detector miss those repeats. We parse-and-redump with ``sort_keys=True`` - plus the most compact separators so trivially-different spellings collapse - to the same canonical form. - - Falls back to the original string if the input isn't valid JSON (e.g. a - handful of providers occasionally pass a bare string for ``arguments``); - that path keeps the legacy behaviour and never raises. - """ - if not args_str: - return "" - try: - return json.dumps(json.loads(args_str), sort_keys=True, separators=(",", ":")) - except (json.JSONDecodeError, TypeError, ValueError): - return args_str def _hash_args(args_str: str) -> str: - """Return a short hash of the JSON arguments string. - - The input is normalised via :func:`_normalize_args` first so that - semantically-identical tool calls produce the same hash regardless of key - order or whitespace. - """ - return hashlib.md5(_normalize_args(args_str).encode()).hexdigest()[:12] + """Return a short hash of the JSON arguments string.""" + return hashlib.md5(args_str.encode()).hexdigest()[:12] def extract_recent_tool_signatures( messages: list[Message], lookback: int = 30 ) -> list[ToolCallSignature]: - """Extract tool call signatures from recent assistant messages. 
-
-    Includes the immediate tool result hash when present. This prevents
-    legitimate polling from being classified as a doom loop when the poll
-    arguments stay constant but the observed result keeps changing.
-    """
+    """Extract tool call signatures from recent assistant messages."""
     signatures: list[ToolCallSignature] = []
     recent = messages[-lookback:] if len(messages) > lookback else messages
-    for idx, msg in enumerate(recent):
+    for msg in recent:
         if getattr(msg, "role", None) != "assistant":
             continue
         tool_calls = getattr(msg, "tool_calls", None)
@@ -80,23 +47,7 @@ def extract_recent_tool_signatures(
             continue
         name = getattr(fn, "name", "") or ""
         args_str = getattr(fn, "arguments", "") or ""
-        result_hash = None
-        for follow in recent[idx + 1 :]:
-            role = getattr(follow, "role", None)
-            if role == "tool" and getattr(follow, "tool_call_id", None) == getattr(
-                tc, "id", None
-            ):
-                result_hash = _hash_args(str(getattr(follow, "content", "") or ""))
-                break
-            if role in {"assistant", "user"}:
-                break
-        signatures.append(
-            ToolCallSignature(
-                name=name,
-                args_hash=_hash_args(args_str),
-                result_hash=result_hash,
-            )
-        )
+        signatures.append(ToolCallSignature(name=name, args_hash=_hash_args(args_str)))

     return signatures

@@ -158,13 +109,9 @@ def check_for_doom_loop(messages: list[Message]) -> str | None:
     # Check for identical consecutive calls
     tool_name = detect_identical_consecutive(signatures, threshold=3)
     if tool_name:
-        logger.warning(
-            "Repetition guard activated: %d+ identical consecutive calls to '%s'",
-            3,
-            tool_name,
-        )
+        logger.warning("Doom loop detected: %d+ identical consecutive calls to '%s'", 3, tool_name)
         return (
-            f"[SYSTEM: REPETITION GUARD] You have called '{tool_name}' with the same "
+            f"[SYSTEM: DOOM LOOP DETECTED] You have called '{tool_name}' with the same "
             f"arguments multiple times in a row, getting the same result each time. "
             f"STOP repeating this approach — it is not working. "
             f"Step back and try a fundamentally different strategy. "
@@ -176,11 +123,9 @@
     pattern = detect_repeating_sequence(signatures)
     if pattern:
         pattern_desc = " → ".join(s.name for s in pattern)
-        logger.warning(
-            "Repetition guard activated: repeating sequence [%s]", pattern_desc
-        )
+        logger.warning("Doom loop detected: repeating sequence [%s]", pattern_desc)
         return (
-            f"[SYSTEM: REPETITION GUARD] You are stuck in a repeating cycle of tool calls: "
+            f"[SYSTEM: DOOM LOOP DETECTED] You are stuck in a repeating cycle of tool calls: "
             f"[{pattern_desc}]. This pattern has repeated multiple times without progress. "
             f"STOP this cycle and try a fundamentally different approach. "
             f"Consider: breaking down the problem differently, using alternative tools, "
diff --git a/agent/core/effort_probe.py b/agent/core/effort_probe.py
deleted file mode 100644
index dbad4c3da95e939ec9d2dae5c6c7408bcc6ea156..0000000000000000000000000000000000000000
--- a/agent/core/effort_probe.py
+++ /dev/null
@@ -1,284 +0,0 @@
-"""Probe-and-cascade for reasoning effort on /model switch.
-
-We don't maintain a per-model capability table. Instead, the first time a
-user picks a model we fire a 1-token ping with the same params we'd use
-for real and walk down a cascade (``max`` -> ``xhigh`` -> ``high`` -> ...)
-until the provider stops rejecting us. The result is cached per-model on
-the session, so real messages don't pay the probe cost again.
-
-Three outcomes, classified from the 400 error text:
-
-* success -> cache the effort that worked
-* ``"thinking ... not supported"`` -> model doesn't do thinking at all;
-  cache ``None`` so we stop sending thinking params
-* ``"effort ... invalid"`` / synonyms -> cascade walks down and retries
-
-Transient errors (5xx, timeout, connection reset) bubble out as
-``ProbeInconclusive`` so the caller can complete the switch with a
-warning instead of blocking on a flaky provider.
-"""
-
-from __future__ import annotations
-
-import asyncio
-import logging
-import time
-from dataclasses import dataclass
-from typing import Any
-
-from litellm import acompletion
-
-from agent.core.llm_params import UnsupportedEffortError, _resolve_llm_params
-
-logger = logging.getLogger(__name__)
-
-
-# Cascade: for each user-stated preference, the ordered list of levels to
-# try. First success wins. ``max`` is Anthropic-only; ``xhigh`` is also
-# supported on current OpenAI GPT-5 models. Providers that don't accept a
-# requested level raise ``UnsupportedEffortError`` synchronously (no wasted
-# network round-trip) and we advance to the next level.
-_EFFORT_CASCADE: dict[str, list[str]] = {
-    "max": ["max", "xhigh", "high", "medium", "low"],
-    "xhigh": ["xhigh", "high", "medium", "low"],
-    "high": ["high", "medium", "low"],
-    "medium": ["medium", "low"],
-    "minimal": ["minimal", "low"],
-    "low": ["low"],
-}
-
-_PROBE_TIMEOUT = 15.0
-# Keep the probe cheap, but high enough that frontier reasoning models can
-# finish a trivial reply instead of tripping a false "output limit reached"
-# error during capability detection.
-_PROBE_MAX_TOKENS = 64
-
-
-class ProbeInconclusive(Exception):
-    """The probe couldn't reach a verdict (transient network / provider error).
-
-    Caller should complete the switch with a warning - the next real call
-    will re-surface the error if it's persistent.
-    """
-
-
-@dataclass
-class ProbeOutcome:
-    """What the probe learned. ``effective_effort`` semantics match the cache:
-
-    * str -> send this level
-    * None -> model doesn't support thinking; strip it
-    """
-
-    effective_effort: str | None
-    attempts: int
-    elapsed_ms: int
-    note: str | None = None  # e.g. "max not supported, falling back"
-
-
-def _is_thinking_unsupported(e: Exception) -> bool:
-    """Model rejected any thinking config.
-
-    Matches Anthropic's 'thinking.type.enabled is not supported for this
-    model' as well as the adaptive variant. Substring-match because the
-    exact wording shifts across API versions.
-    """
-    s = str(e).lower()
-    return "thinking" in s and "not supported" in s
-
-
-def _is_invalid_effort(e: Exception) -> bool:
-    """The requested effort level isn't accepted for this model.
-
-    Covers both API responses (Anthropic/OpenAI 400 with "invalid", "must
-    be one of", etc.) and LiteLLM's local validation that fires *before*
-    the request (e.g. "effort='max' is only supported by Claude Opus 4.6"
-    - LiteLLM knows max is Opus-4.6-only and raises synchronously). The
-    cascade walks down on either.
-
-    Explicitly returns False when the message is really about thinking
-    itself (e.g. Anthropic's 4.7 error mentions ``output_config.effort``
-    in its fix hint, but the actual failure is ``thinking.type.enabled``
-    being unsupported). That case is caught by ``_is_thinking_unsupported``.
- """ - if _is_thinking_unsupported(e): - return False - s = str(e).lower() - if "effort" not in s and "output_config" not in s: - return False - return any( - phrase in s - for phrase in ( - "invalid", - "not supported", - "must be one of", - "not a valid", - "unrecognized", - "unknown", - # LiteLLM's own pre-flight validation phrasing. - "only supported by", - "is only supported", - ) - ) - - -def _is_transient(e: Exception) -> bool: - """Network / provider-side flake. Keep in sync with agent_loop's list. - - Also matches by type for ``asyncio.TimeoutError`` β€” its ``str(e)`` is - empty, so substring matching alone misses it. - """ - if isinstance(e, (asyncio.TimeoutError, TimeoutError)): - return True - s = str(e).lower() - return any( - p in s - for p in ( - "timeout", - "timed out", - "429", - "rate limit", - "503", - "service unavailable", - "502", - "bad gateway", - "500", - "internal server error", - "overloaded", - "capacity", - "connection reset", - "connection refused", - "connection error", - "eof", - "broken pipe", - ) - ) - - -async def probe_effort( - model_name: str, - preference: str | None, - hf_token: str | None, - session: Any = None, -) -> ProbeOutcome: - """Walk the cascade for ``preference`` on ``model_name``. - - Returns the first effort the provider accepts, or ``None`` if it - rejects thinking altogether. Raises ``ProbeInconclusive`` only for - transient errors (5xx, timeout) β€” persistent 4xx that aren't thinking/ - effort related bubble as the original exception so callers can surface - them (auth, model-not-found, quota, etc.). - - ``session`` is optional; when provided, each successful probe attempt - is recorded via ``telemetry.record_llm_call(kind="effort_probe")`` so - the cost shows up in the session's ``total_cost_usd``. Failed probes - (rejected by the provider) typically aren't billed, so we only record - on success. - """ - loop = asyncio.get_event_loop() - start = loop.time() - attempts = 0 - - if not preference: - # User explicitly turned effort off β€” nothing to probe. A bare - # ping with no thinking params is pointless; just report "off". - return ProbeOutcome(effective_effort=None, attempts=0, elapsed_ms=0) - - cascade = _EFFORT_CASCADE.get(preference, [preference]) - skipped: list[str] = [] # levels the provider rejected synchronously - - last_error: Exception | None = None - for effort in cascade: - try: - params = _resolve_llm_params( - model_name, - hf_token, - reasoning_effort=effort, - strict=True, - ) - except UnsupportedEffortError: - # Provider can't even accept this effort name (e.g. "max" on - # HF router). Skip without a network call. - skipped.append(effort) - continue - - attempts += 1 - try: - _t0 = time.monotonic() - response = await asyncio.wait_for( - acompletion( - messages=[{"role": "user", "content": "ping"}], - max_tokens=_PROBE_MAX_TOKENS, - stream=False, - **params, - ), - timeout=_PROBE_TIMEOUT, - ) - if session is not None: - # Best-effort telemetry β€” never let a logging blip propagate - # out of the probe and break model switching. 
- try: - from agent.core import telemetry - - await telemetry.record_llm_call( - session, - model=model_name, - response=response, - latency_ms=int((time.monotonic() - _t0) * 1000), - finish_reason=response.choices[0].finish_reason - if response.choices - else None, - kind="effort_probe", - ) - except Exception as _telem_err: - logger.debug("effort_probe telemetry failed: %s", _telem_err) - except Exception as e: - last_error = e - if _is_thinking_unsupported(e): - elapsed = int((loop.time() - start) * 1000) - return ProbeOutcome( - effective_effort=None, - attempts=attempts, - elapsed_ms=elapsed, - note="model doesn't support reasoning, dropped", - ) - if _is_invalid_effort(e): - logger.debug( - "probe: %s rejected effort=%s, trying next", model_name, effort - ) - continue - if _is_transient(e): - raise ProbeInconclusive(str(e)) from e - # Persistent non-thinking 4xx (auth, quota, model-not-found) β€” - # let the caller classify & surface. - raise - else: - elapsed = int((loop.time() - start) * 1000) - note = None - if effort != preference: - note = f"{preference} not supported, using {effort}" - return ProbeOutcome( - effective_effort=effort, - attempts=attempts, - elapsed_ms=elapsed, - note=note, - ) - - # Cascade exhausted without a success. This only happens when every - # level was either rejected synchronously (``UnsupportedEffortError``, - # e.g. preference=max on HF and we also somehow filtered all others) - # or the provider 400'd ``invalid effort`` on every level. - elapsed = int((loop.time() - start) * 1000) - if last_error is not None and not _is_invalid_effort(last_error): - raise last_error - note = ( - "no effort level accepted β€” proceeding without thinking" - if not skipped - else f"provider rejected all efforts ({', '.join(skipped)})" - ) - return ProbeOutcome( - effective_effort=None, - attempts=attempts, - elapsed_ms=elapsed, - note=note, - ) diff --git a/agent/core/hf_access.py b/agent/core/hf_access.py deleted file mode 100644 index 254a9c73161df9be2866f7cd2574dc7701934f08..0000000000000000000000000000000000000000 --- a/agent/core/hf_access.py +++ /dev/null @@ -1,172 +0,0 @@ -"""Helpers for Hugging Face account / org access decisions. - -HF Jobs are gated by *credits*, not by HF Pro subscriptions. Any user who -has credits β€” on their personal account or on an org they belong to β€” can -launch jobs under that namespace. The picker UI lets the caller choose -which wallet to bill. -""" - -from __future__ import annotations - -import asyncio -import os -import re -from dataclasses import dataclass -from typing import Any - -import httpx - -OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co") - - -@dataclass(frozen=True) -class JobsAccess: - """Namespaces the caller may bill HF Jobs to.""" - - username: str | None - org_names: list[str] - eligible_namespaces: list[str] - default_namespace: str | None - access_known: bool = True - - -class JobsAccessError(Exception): - """Structured jobs-namespace error. - - ``namespace_required`` fires when the caller belongs to more than one - eligible namespace and the UI must prompt them to pick one. There is no - longer an ``upgrade_required`` state β€” Pro is irrelevant; HF Jobs are - gated on per-wallet credits, surfaced separately when the API returns - a billing error at job-creation time. 
- """ - - def __init__( - self, - message: str, - *, - access: JobsAccess | None = None, - namespace_required: bool = False, - ) -> None: - super().__init__(message) - self.access = access - self.namespace_required = namespace_required - - -def _extract_username(whoami: dict[str, Any]) -> str | None: - for key in ("name", "user", "preferred_username"): - value = whoami.get(key) - if isinstance(value, str) and value: - return value - return None - - -def _org_names(whoami: dict[str, Any]) -> list[str]: - """All orgs the caller belongs to. - - Plan/tier is ignored β€” credits live on the namespace itself, so any - org the user belongs to can host a job as long as it has credits. - """ - names: list[str] = [] - orgs = whoami.get("orgs") or [] - if not isinstance(orgs, list): - return names - for org in orgs: - if not isinstance(org, dict): - continue - name = org.get("name") - if isinstance(name, str) and name: - names.append(name) - return sorted(set(names)) - - -def jobs_access_from_whoami(whoami: dict[str, Any]) -> JobsAccess: - username = _extract_username(whoami) - org_names = _org_names(whoami) - eligible: list[str] = [] - if username: - eligible.append(username) - eligible.extend(org_names) - default = username if username else (org_names[0] if org_names else None) - return JobsAccess( - username=username, - org_names=org_names, - eligible_namespaces=eligible, - default_namespace=default, - ) - - -async def fetch_whoami_v2(token: str, timeout: float = 5.0) -> dict[str, Any] | None: - if not token: - return None - async with httpx.AsyncClient(timeout=timeout) as client: - try: - response = await client.get( - f"{OPENID_PROVIDER_URL}/api/whoami-v2", - headers={"Authorization": f"Bearer {token}"}, - ) - if response.status_code != 200: - return None - payload = response.json() - return payload if isinstance(payload, dict) else None - except (httpx.HTTPError, ValueError): - return None - - -async def get_jobs_access(token: str) -> JobsAccess | None: - whoami = await fetch_whoami_v2(token) - if whoami is None: - return None - return jobs_access_from_whoami(whoami) - - -async def resolve_jobs_namespace( - token: str, - requested_namespace: str | None = None, -) -> tuple[str, JobsAccess | None]: - """Return the namespace to use for jobs. - - If whoami-v2 is unavailable, fall back to the token owner's username. - """ - access = await get_jobs_access(token) - if access: - if requested_namespace: - if requested_namespace in access.eligible_namespaces: - return requested_namespace, access - raise JobsAccessError( - f"You can only run jobs under your own account or an org you belong to. " - f"Allowed namespaces: {', '.join(access.eligible_namespaces) or '(none)'}", - access=access, - ) - if access.default_namespace: - return access.default_namespace, access - raise JobsAccessError( - "Couldn't resolve a Hugging Face namespace for this token.", - access=access, - ) - - # Fallback: whoami-v2 unavailable. Don't block the call pre-emptively. 
- from huggingface_hub import HfApi - - username = None - if token: - whoami = await asyncio.to_thread(HfApi(token=token).whoami) - username = whoami.get("name") - if not username: - raise JobsAccessError("No HF token available to resolve a jobs namespace.") - return requested_namespace or username, None - - -_BILLING_PATTERNS = re.compile( - r"\b(insufficient[_\s-]?credits?|out\s+of\s+credits?|payment\s+required|" - r"billing|no\s+credits?|add\s+credits?|requires?\s+credits?)\b", - re.IGNORECASE, -) - - -def is_billing_error(message: str) -> bool: - """True if an HF API error message looks like an out-of-credits / billing error.""" - if not message: - return False - if "402" in message: - return True - return bool(_BILLING_PATTERNS.search(message)) diff --git a/agent/core/hf_router_catalog.py b/agent/core/hf_router_catalog.py deleted file mode 100644 index 625ccf4fb85498e229fe63dc0faac56628d0be39..0000000000000000000000000000000000000000 --- a/agent/core/hf_router_catalog.py +++ /dev/null @@ -1,131 +0,0 @@ -"""Fetch and cache the HF Inference Router model catalog. - -The router exposes an OpenAI-compatible listing at -``https://router.huggingface.co/v1/models`` with per-provider availability, -pricing, context length, and tool-use support. We use it to: - - β€’ Validate ``/model`` switches with live data instead of a hard-coded allowlist. - β€’ Show the user which providers serve a model, at what price, and whether they - support tool calls. - β€’ Derive a reasonable context-window limit for any routed model. - -The listing is cached in-memory for a few minutes so repeated lookups during a -session are free. On fetch failure we return stale data if we have it, or an -empty catalog otherwise. -""" - -import logging -import time -from dataclasses import dataclass -from difflib import get_close_matches -from typing import Optional - -import httpx - -logger = logging.getLogger(__name__) - -_CATALOG_URL = "https://router.huggingface.co/v1/models" -_CACHE_TTL_SECONDS = 300 -_HTTP_TIMEOUT_SECONDS = 5.0 - -_cache: Optional[dict] = None -_cache_time: float = 0.0 - - -@dataclass -class ProviderInfo: - provider: str - status: str - context_length: Optional[int] - input_price: Optional[float] - output_price: Optional[float] - supports_tools: bool - supports_structured_output: bool - - -@dataclass -class ModelInfo: - id: str - providers: list[ProviderInfo] - - @property - def live_providers(self) -> list[ProviderInfo]: - return [p for p in self.providers if p.status == "live"] - - @property - def max_context_length(self) -> Optional[int]: - lengths = [p.context_length for p in self.live_providers if p.context_length] - return max(lengths) if lengths else None - - @property - def any_supports_tools(self) -> bool: - return any(p.supports_tools for p in self.live_providers) - - -def _fetch_catalog(force: bool = False) -> dict: - global _cache, _cache_time - now = time.time() - if not force and _cache is not None and now - _cache_time < _CACHE_TTL_SECONDS: - return _cache - try: - resp = httpx.get(_CATALOG_URL, timeout=_HTTP_TIMEOUT_SECONDS) - resp.raise_for_status() - _cache = resp.json() - _cache_time = now - except Exception as e: - logger.warning("Failed to fetch HF router catalog: %s", e) - if _cache is None: - _cache = {"data": []} - _cache_time = now - return _cache - - -def _parse_entry(entry: dict) -> ModelInfo: - providers = [] - for p in entry.get("providers", []) or []: - pricing = p.get("pricing") or {} - providers.append( - ProviderInfo( - provider=p.get("provider", ""), - 
                status=p.get("status", ""),
                context_length=p.get("context_length"),
                input_price=pricing.get("input"),
                output_price=pricing.get("output"),
                supports_tools=bool(p.get("supports_tools", False)),
                supports_structured_output=bool(
                    p.get("supports_structured_output", False)
                ),
            )
        )
    return ModelInfo(id=entry.get("id", ""), providers=providers)


-def lookup(model_id: str) -> Optional[ModelInfo]:
-    """Find a model in the router catalog.
-
-    Accepts ``<org>/<model>`` or ``<org>/<model>:<tag>`` - the tag is stripped
-    for lookup. Returns ``None`` if the model isn't listed.
-    """
-    bare = model_id.split(":", 1)[0]
-    catalog = _fetch_catalog()
-    for entry in catalog.get("data", []):
-        if entry.get("id") == bare:
-            return _parse_entry(entry)
-    return None
-
-
-def fuzzy_suggest(model_id: str, limit: int = 3) -> list[str]:
-    """Return the closest model ids from the catalog."""
-    bare = model_id.split(":", 1)[0]
-    catalog = _fetch_catalog()
-    ids = [e.get("id", "") for e in catalog.get("data", []) if e.get("id")]
-    return get_close_matches(bare, ids, n=limit, cutoff=0.4)
-
-
-def prewarm() -> None:
-    """Fetch the catalog so subsequent lookups are instant. Safe to call from
-    a background task - swallows failures."""
-    try:
-        _fetch_catalog(force=False)
-    except Exception:
-        pass
diff --git a/agent/core/hf_tokens.py b/agent/core/hf_tokens.py
deleted file mode 100644
index 3e72ccc128a9d9aaecb661c4c2ba3850a10b5dc0..0000000000000000000000000000000000000000
--- a/agent/core/hf_tokens.py
+++ /dev/null
@@ -1,85 +0,0 @@
-"""Hugging Face token resolution helpers."""
-
-from __future__ import annotations
-
-import os
-from typing import Any
-
-
-def clean_hf_token(token: str | None) -> str | None:
-    """Normalize token strings the same way huggingface_hub does."""
-    if token is None:
-        return None
-    return token.replace("\r", "").replace("\n", "").strip() or None
-
-
-def get_cached_hf_token() -> str | None:
-    """Return the token from huggingface_hub's normal env/cache lookup."""
-    try:
-        from huggingface_hub import get_token
-
-        return get_token()
-    except Exception:
-        return None
-
-
-def resolve_hf_token(
-    *candidates: str | None,
-    include_cached: bool = True,
-) -> str | None:
-    """Return the first non-empty explicit token, then optionally HF cache."""
-    for token in candidates:
-        cleaned = clean_hf_token(token)
-        if cleaned:
-            return cleaned
-    if include_cached:
-        return get_cached_hf_token()
-    return None
-
-
-def resolve_hf_router_token(session_hf_token: str | None = None) -> str | None:
-    """Resolve the token used for Hugging Face Router LLM calls.
-
-    App-specific precedence:
-    1. INFERENCE_TOKEN: shared hosted-Space inference token.
-    2. session_hf_token: the active user/session token.
-    3. huggingface_hub.get_token(): HF_TOKEN/HUGGING_FACE_HUB_TOKEN or
-       local ``hf auth login`` cache.
- """ - return resolve_hf_token(os.environ.get("INFERENCE_TOKEN"), session_hf_token) - - -def get_hf_bill_to() -> str | None: - """Return X-HF-Bill-To only when a shared inference token is active.""" - if clean_hf_token(os.environ.get("INFERENCE_TOKEN")): - return os.environ.get("HF_BILL_TO", "smolagents") - return None - - -def bearer_token_from_header(auth_header: str | None) -> str | None: - """Extract a cleaned bearer token from an Authorization header.""" - if not auth_header or not auth_header.startswith("Bearer "): - return None - return clean_hf_token(auth_header[7:]) - - -def resolve_hf_request_token( - request: Any, - *, - include_env_fallback: bool = True, -) -> str | None: - """Resolve a user token from a FastAPI request. - - This intentionally does not use the local ``hf auth login`` cache. Backend - request paths should act as the browser user from Authorization/cookie, or - fall back only to an explicit server ``HF_TOKEN`` in dev/server contexts. - """ - token = bearer_token_from_header(request.headers.get("Authorization", "")) - if token: - return token - token = clean_hf_token(request.cookies.get("hf_access_token")) - if token: - return token - if include_env_fallback: - return clean_hf_token(os.environ.get("HF_TOKEN")) - return None diff --git a/agent/core/hub_artifacts.py b/agent/core/hub_artifacts.py deleted file mode 100644 index bf6002677ac568aa2edad9c77ab404a08787c2dc..0000000000000000000000000000000000000000 --- a/agent/core/hub_artifacts.py +++ /dev/null @@ -1,790 +0,0 @@ -"""Best-effort Hub metadata for artifacts generated by ML Intern sessions.""" - -import asyncio -import base64 -import logging -import re -import shlex -import tempfile -import textwrap -from datetime import datetime -from pathlib import Path -from typing import Any - -from huggingface_hub import HfApi, hf_hub_download -from huggingface_hub.repocard import metadata_load, metadata_save -from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError - -logger = logging.getLogger(__name__) - -ML_INTERN_TAG = "ml-intern" -SUPPORTED_REPO_TYPES = {"model", "dataset", "space"} -PROVENANCE_MARKER = "" -_COLLECTION_TITLE_PREFIX = "ml-intern-artifacts" -_COLLECTION_TITLE_MAX_LENGTH = 59 -_UUID_SESSION_ID_RE = re.compile( - r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-" - r"[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$" -) -_KNOWN_ARTIFACTS_ATTR = "_ml_intern_known_hub_artifacts" -_REGISTERED_ARTIFACTS_ATTR = "_ml_intern_registered_hub_artifacts" -_COLLECTION_SLUG_ATTR = "_ml_intern_artifact_collection_slug" -_COLLECTION_TASK_ATTR = "_ml_intern_artifact_collection_task" -_SESSION_ARTIFACT_SET_FALLBACK: dict[tuple[int, str], set[str]] = {} -_USAGE_HEADING_RE = re.compile( - r"^#{2,6}\s+(usage|how to use|using this (model|dataset)|use this (model|dataset))\b", - re.IGNORECASE | re.MULTILINE, -) -_FRONT_MATTER_RE = re.compile(r"\A---\s*\n.*?\n---\s*\n?", re.DOTALL) - - -def _safe_session_id(session: Any) -> str: - raw = str(getattr(session, "session_id", "") or "unknown-session") - safe = re.sub(r"[^A-Za-z0-9._-]+", "-", raw).strip("-") - return safe or "unknown-session" - - -def session_artifact_date(session: Any) -> str: - """Return the YYYY-MM-DD partition date for a session.""" - raw = getattr(session, "session_start_time", None) - if raw: - try: - return datetime.fromisoformat(str(raw).replace("Z", "+00:00")).strftime( - "%Y-%m-%d" - ) - except ValueError: - logger.debug("Could not parse session_start_time=%r", raw) - return datetime.utcnow().strftime("%Y-%m-%d") - - -def 
_collection_session_id_fragment(session: Any) -> str: - safe_id = _safe_session_id(session) - if _UUID_SESSION_ID_RE.match(safe_id): - return safe_id[:8] - stem = f"{_COLLECTION_TITLE_PREFIX}-{session_artifact_date(session)}-" - max_id_length = max(1, _COLLECTION_TITLE_MAX_LENGTH - len(stem)) - if len(safe_id) <= max_id_length: - return safe_id - return safe_id[:max_id_length].rstrip("-._") or safe_id[:max_id_length] - - -def artifact_collection_title(session: Any) -> str: - return ( - f"{_COLLECTION_TITLE_PREFIX}-{session_artifact_date(session)}-" - f"{_collection_session_id_fragment(session)}" - ) - - -def _artifact_key(repo_id: str, repo_type: str | None) -> str: - return f"{repo_type or 'model'}:{repo_id}" - - -def _sandbox_space_name_pattern() -> str: - from agent.tools.sandbox_tool import SANDBOX_SPACE_NAME_RE - - return SANDBOX_SPACE_NAME_RE.pattern - - -def is_sandbox_hub_repo(repo_id: str | None, repo_type: str | None) -> bool: - """Return True for ML Intern's ephemeral sandbox Space repos.""" - if (repo_type or "model") != "space" or not repo_id: - return False - repo_name = str(repo_id).rsplit("/", 1)[-1] - return bool(re.fullmatch(_sandbox_space_name_pattern(), repo_name)) - - -def _session_artifact_set(session: Any, attr: str) -> set[str]: - current = getattr(session, attr, None) - if isinstance(current, set): - return current - current = set() - try: - setattr(session, attr, current) - except Exception: - logger.warning( - "Could not attach %s to session; using process-local fallback state", - attr, - ) - return _SESSION_ARTIFACT_SET_FALLBACK.setdefault((id(session), attr), set()) - return current - - -def remember_hub_artifact(session: Any, repo_id: str, repo_type: str | None) -> None: - if session is None or not repo_id: - return - _session_artifact_set(session, _KNOWN_ARTIFACTS_ATTR).add( - _artifact_key(repo_id, repo_type) - ) - - -def is_known_hub_artifact(session: Any, repo_id: str, repo_type: str | None) -> bool: - if session is None or not repo_id: - return False - return _artifact_key(repo_id, repo_type) in _session_artifact_set( - session, _KNOWN_ARTIFACTS_ATTR - ) - - -def _merge_tags(metadata: dict[str, Any], tag: str = ML_INTERN_TAG) -> dict[str, Any]: - merged = dict(metadata) - raw_tags = merged.get("tags") - if raw_tags is None: - tags: list[str] = [] - elif isinstance(raw_tags, str): - tags = [raw_tags] - elif isinstance(raw_tags, list): - tags = [str(item) for item in raw_tags] - else: - tags = [str(raw_tags)] - - if tag not in tags: - tags.append(tag) - merged["tags"] = tags - return merged - - -def _metadata_from_content(content: str) -> dict[str, Any]: - with tempfile.TemporaryDirectory() as tmp_dir: - path = Path(tmp_dir) / "README.md" - path.write_text(content, encoding="utf-8") - return metadata_load(path) or {} - - -def _content_with_metadata(content: str, metadata: dict[str, Any]) -> str: - with tempfile.TemporaryDirectory() as tmp_dir: - path = Path(tmp_dir) / "README.md" - path.write_text(content, encoding="utf-8") - metadata_save(path, metadata) - return path.read_text(encoding="utf-8") - - -def _body_without_metadata(content: str) -> str: - return _FRONT_MATTER_RE.sub("", content, count=1).strip() - - -def _append_section(content: str, section: str) -> str: - base = content.rstrip() - if base: - return f"{base}\n\n{section.strip()}\n" - return f"{section.strip()}\n" - - -def _provenance_section(repo_type: str) -> str: - label = {"model": "model", "dataset": "dataset"}.get(repo_type, "Hub") - return f"""{PROVENANCE_MARKER} -## Generated by ML 
Intern - -This {label} repository was generated by [ML Intern](https://github.com/huggingface/ml-intern), an agent for machine learning research and development on the Hugging Face Hub. - -- Try ML Intern: https://smolagents-ml-intern.hf.space -- Source code: https://github.com/huggingface/ml-intern -""" - - -def _usage_section(repo_id: str, repo_type: str) -> str: - if repo_type == "dataset": - return f"""## Usage - -```python -from datasets import load_dataset - -dataset = load_dataset("{repo_id}") -``` -""" - - return f"""## Usage - -```python -from transformers import AutoModelForCausalLM, AutoTokenizer - -model_id = "{repo_id}" -tokenizer = AutoTokenizer.from_pretrained(model_id) -model = AutoModelForCausalLM.from_pretrained(model_id) -``` - -For non-causal architectures, replace `AutoModelForCausalLM` with the appropriate `AutoModel` class. -""" - - -def augment_repo_card_content( - content: str | None, - repo_id: str, - repo_type: str = "model", - *, - extra_metadata: dict[str, Any] | None = None, -) -> str: - """Return README content with ML Intern metadata and provenance added.""" - repo_type = repo_type or "model" - content = content or "" - metadata = _metadata_from_content(content) - if extra_metadata: - metadata = {**extra_metadata, **metadata} - metadata = _merge_tags(metadata) - updated = _content_with_metadata(content, metadata) - - if not _body_without_metadata(updated): - updated = _append_section(updated, f"# {repo_id}") - - if repo_type in {"model", "dataset"} and PROVENANCE_MARKER not in updated: - updated = _append_section(updated, _provenance_section(repo_type)) - if not _USAGE_HEADING_RE.search(content): - updated = _append_section(updated, _usage_section(repo_id, repo_type)) - - return updated - - -def _read_remote_readme( - api: Any, - repo_id: str, - repo_type: str, - *, - token: str | bool | None = None, -) -> str: - token_value = token if token is not None else getattr(api, "token", None) - try: - readme_path = hf_hub_download( - repo_id=repo_id, - filename="README.md", - repo_type=repo_type, - token=token_value, - ) - except (EntryNotFoundError, RepositoryNotFoundError): - return "" - return Path(readme_path).read_text(encoding="utf-8") - - -def _update_repo_card( - api: Any, - repo_id: str, - repo_type: str, - *, - token: str | bool | None = None, - extra_metadata: dict[str, Any] | None = None, -) -> None: - current = _read_remote_readme(api, repo_id, repo_type, token=token) - updated = augment_repo_card_content( - current, - repo_id, - repo_type, - extra_metadata=extra_metadata, - ) - if updated == current: - return - api.upload_file( - path_or_fileobj=updated.encode("utf-8"), - path_in_repo="README.md", - repo_id=repo_id, - repo_type=repo_type, - token=token, - commit_message="Update ML Intern artifact metadata", - ) - - -def _ensure_collection_slug( - api: Any, - session: Any, - *, - token: str | bool | None = None, -) -> str | None: - slug = getattr(session, _COLLECTION_SLUG_ATTR, None) - if slug: - return slug - - title = artifact_collection_title(session) - collection = api.create_collection( - title=title, - description=( - f"Artifacts generated by ML Intern session {_safe_session_id(session)} " - f"on {session_artifact_date(session)}." 
- ), - private=True, - exists_ok=True, - token=token, - ) - slug = getattr(collection, "slug", None) - if slug: - setattr(session, _COLLECTION_SLUG_ATTR, slug) - return slug - - -async def ensure_session_artifact_collection( - session: Any, - *, - token: str | bool | None = None, -) -> str | None: - """Create/cache the per-session artifact collection without raising.""" - if session is None or not getattr(session, "session_id", None): - return None - token_value = token if token is not None else getattr(session, "hf_token", None) - if not token_value: - return None - - try: - api = HfApi(token=token_value) - return await asyncio.to_thread( - _ensure_collection_slug, - api, - session, - token=token_value, - ) - except Exception as e: - logger.warning( - "ML Intern session collection creation failed for %s: %s", - _safe_session_id(session), - e, - ) - return None - - -def start_session_artifact_collection_task( - session: Any, - *, - token: str | bool | None = None, -) -> asyncio.Task | None: - """Schedule best-effort collection creation for a newly started session.""" - if session is None or not getattr(session, "session_id", None): - return None - if getattr(session, _COLLECTION_SLUG_ATTR, None): - return None - - token_value = token if token is not None else getattr(session, "hf_token", None) - if not token_value: - return None - - existing = getattr(session, _COLLECTION_TASK_ATTR, None) - if isinstance(existing, asyncio.Task) and not existing.done(): - return existing - - try: - loop = asyncio.get_running_loop() - except RuntimeError: - return None - - async def _run() -> None: - await ensure_session_artifact_collection(session, token=token_value) - - task = loop.create_task(_run()) - try: - setattr(session, _COLLECTION_TASK_ATTR, task) - except Exception: - logger.debug("Could not attach ML Intern collection task to session") - return task - - -def _add_to_collection( - api: Any, - session: Any, - repo_id: str, - repo_type: str, - *, - token: str | bool | None = None, -) -> None: - slug = _ensure_collection_slug(api, session, token=token) - if not slug: - return - api.add_collection_item( - collection_slug=slug, - item_id=repo_id, - item_type=repo_type, - note=( - f"Generated by ML Intern session {_safe_session_id(session)} " - f"on {session_artifact_date(session)}." 
- ), - exists_ok=True, - token=token, - ) - - -def register_hub_artifact( - api: Any, - repo_id: str, - repo_type: str = "model", - *, - session: Any = None, - token: str | bool | None = None, - extra_metadata: dict[str, Any] | None = None, - force: bool = False, -) -> bool: - """Tag, card, and collection-register a Hub artifact without raising.""" - if session is None or not repo_id: - return False - repo_type = repo_type or "model" - if repo_type not in SUPPORTED_REPO_TYPES: - return False - if is_sandbox_hub_repo(repo_id, repo_type): - return False - - key = _artifact_key(repo_id, repo_type) - remember_hub_artifact(session, repo_id, repo_type) - registered = _session_artifact_set(session, _REGISTERED_ARTIFACTS_ATTR) - if key in registered and not force: - return True - - token_value = token if token is not None else getattr(api, "token", None) - card_updated = False - collection_updated = False - try: - _update_repo_card( - api, - repo_id, - repo_type, - token=token_value, - extra_metadata=extra_metadata, - ) - card_updated = True - except Exception as e: - logger.debug("ML Intern repo-card update failed for %s: %s", repo_id, e) - - try: - _add_to_collection(api, session, repo_id, repo_type, token=token_value) - collection_updated = True - except Exception as e: - logger.debug("ML Intern collection update failed for %s: %s", repo_id, e) - - if card_updated and collection_updated: - registered.add(key) - return True - return False - - -def build_hub_artifact_sitecustomize(session: Any) -> str: - """Build standalone sitecustomize.py code for HF Jobs Python processes.""" - if session is None or not getattr(session, "session_id", None): - return "" - - session_id = _safe_session_id(session) - session_date = session_artifact_date(session) - collection_title = artifact_collection_title(session) - collection_slug = getattr(session, _COLLECTION_SLUG_ATTR, None) - - return ( - textwrap.dedent( - f""" - # Auto-generated by ML Intern. Best-effort Hub artifact metadata only. 
- def _install_ml_intern_artifact_hooks(): - import os - import re - import tempfile - from pathlib import Path - - try: - import huggingface_hub as _hub - from huggingface_hub import HfApi, hf_hub_download - from huggingface_hub.repocard import metadata_load, metadata_save - from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError - except Exception: - return - - session_id = {session_id!r} - session_date = {session_date!r} - collection_title = {collection_title!r} - tag = {ML_INTERN_TAG!r} - marker = {PROVENANCE_MARKER!r} - supported = {sorted(SUPPORTED_REPO_TYPES)!r} - sandbox_space_re = re.compile({_sandbox_space_name_pattern()!r}) - registering = False - collection_slug = {collection_slug!r} - registered = set() - usage_re = re.compile( - r"^#{{2,6}}\\s+(usage|how to use|using this (model|dataset)|use this (model|dataset))\\b", - re.IGNORECASE | re.MULTILINE, - ) - front_matter_re = re.compile(r"\\A---\\s*\\n.*?\\n---\\s*\\n?", re.DOTALL) - - def _token(value=None, api=None): - if isinstance(value, str) and value: - return value - api_token = getattr(api, "token", None) - if isinstance(api_token, str) and api_token: - return api_token - return ( - os.environ.get("HF_TOKEN") - or os.environ.get("HUGGINGFACE_HUB_TOKEN") - or None - ) - - def _merge_tags(metadata): - metadata = dict(metadata or {{}}) - raw_tags = metadata.get("tags") - if raw_tags is None: - tags = [] - elif isinstance(raw_tags, str): - tags = [raw_tags] - elif isinstance(raw_tags, list): - tags = [str(item) for item in raw_tags] - else: - tags = [str(raw_tags)] - if tag not in tags: - tags.append(tag) - metadata["tags"] = tags - return metadata - - def _metadata_from_content(content): - with tempfile.TemporaryDirectory() as tmp_dir: - path = Path(tmp_dir) / "README.md" - path.write_text(content or "", encoding="utf-8") - return metadata_load(path) or {{}} - - def _content_with_metadata(content, metadata): - with tempfile.TemporaryDirectory() as tmp_dir: - path = Path(tmp_dir) / "README.md" - path.write_text(content or "", encoding="utf-8") - metadata_save(path, metadata) - return path.read_text(encoding="utf-8") - - def _body_without_metadata(content): - return front_matter_re.sub("", content or "", count=1).strip() - - def _append_section(content, section): - base = (content or "").rstrip() - if base: - return base + "\\n\\n" + section.strip() + "\\n" - return section.strip() + "\\n" - - def _provenance(repo_type): - label = {{"model": "model", "dataset": "dataset"}}.get( - repo_type, "Hub" - ) - return ( - marker - + "\\n## Generated by ML Intern\\n\\n" - + f"This {{label}} repository was generated by [ML Intern](https://github.com/huggingface/ml-intern), an agent for machine learning research and development on the Hugging Face Hub.\\n\\n" - + "- Try ML Intern: https://smolagents-ml-intern.hf.space\\n" - + "- Source code: https://github.com/huggingface/ml-intern\\n" - ) - - def _usage(repo_id, repo_type): - if repo_type == "dataset": - return ( - "## Usage\\n\\n" - "```python\\n" - "from datasets import load_dataset\\n\\n" - f"dataset = load_dataset({{repo_id!r}})\\n" - "```\\n" - ) - return ( - "## Usage\\n\\n" - "```python\\n" - "from transformers import AutoModelForCausalLM, AutoTokenizer\\n\\n" - f"model_id = {{repo_id!r}}\\n" - "tokenizer = AutoTokenizer.from_pretrained(model_id)\\n" - "model = AutoModelForCausalLM.from_pretrained(model_id)\\n" - "```\\n\\n" - "For non-causal architectures, replace `AutoModelForCausalLM` with the appropriate `AutoModel` class.\\n" - ) - - def 
_augment(content, repo_id, repo_type, extra_metadata=None): - metadata = _metadata_from_content(content or "") - if extra_metadata: - metadata = {{**extra_metadata, **metadata}} - updated = _content_with_metadata(content or "", _merge_tags(metadata)) - if not _body_without_metadata(updated): - updated = _append_section(updated, f"# {{repo_id}}") - if repo_type in {{"model", "dataset"}} and marker not in updated: - updated = _append_section(updated, _provenance(repo_type)) - if not usage_re.search(content or ""): - updated = _append_section(updated, _usage(repo_id, repo_type)) - return updated - - def _readme(api, repo_id, repo_type, token_value): - try: - path = hf_hub_download( - repo_id=repo_id, - filename="README.md", - repo_type=repo_type, - token=token_value, - ) - except (EntryNotFoundError, RepositoryNotFoundError): - return "" - return Path(path).read_text(encoding="utf-8") - - def _ensure_collection(api, token_value): - nonlocal collection_slug - if collection_slug: - return collection_slug - collection = api.create_collection( - title=collection_title, - description=( - f"Artifacts generated by ML Intern session {{session_id}} " - f"on {{session_date}}." - ), - private=True, - exists_ok=True, - token=token_value, - ) - collection_slug = getattr(collection, "slug", None) - return collection_slug - - def _register( - repo_id, - repo_type="model", - token_value=None, - extra_metadata=None, - force=False, - ): - nonlocal registering - if registering or not repo_id: - return - repo_type = repo_type or "model" - if repo_type not in supported: - return - if _is_sandbox_repo(repo_id, repo_type): - return - key = f"{{repo_type}}:{{repo_id}}" - if key in registered and not force: - return - registering = True - try: - token_value = _token(token_value) - api = HfApi(token=token_value) - try: - current = _readme(api, repo_id, repo_type, token_value) - updated = _augment( - current, repo_id, repo_type, extra_metadata=extra_metadata - ) - if updated != current: - _original_upload_file( - api, - path_or_fileobj=updated.encode("utf-8"), - path_in_repo="README.md", - repo_id=repo_id, - repo_type=repo_type, - token=token_value, - commit_message="Update ML Intern artifact metadata", - ) - except Exception: - pass - try: - slug = _ensure_collection(api, token_value) - if slug: - api.add_collection_item( - collection_slug=slug, - item_id=repo_id, - item_type=repo_type, - note=( - f"Generated by ML Intern session {{session_id}} " - f"on {{session_date}}." 
- ), - exists_ok=True, - token=token_value, - ) - except Exception: - pass - registered.add(key) - finally: - registering = False - - _original_create_repo = HfApi.create_repo - _original_upload_file = HfApi.upload_file - _original_upload_folder = getattr(HfApi, "upload_folder", None) - _original_create_commit = getattr(HfApi, "create_commit", None) - - def _repo_id(args, kwargs): - return kwargs.get("repo_id") or (args[0] if args else None) - - def _repo_type(kwargs): - return kwargs.get("repo_type") or "model" - - def _is_sandbox_repo(repo_id, repo_type): - if (repo_type or "model") != "space" or not repo_id: - return False - repo_name = str(repo_id).rsplit("/", 1)[-1] - return bool(sandbox_space_re.fullmatch(repo_name)) - - def _patched_create_repo(self, *args, **kwargs): - result = _original_create_repo(self, *args, **kwargs) - repo_id = _repo_id(args, kwargs) - repo_type = _repo_type(kwargs) - extra = None - if repo_type == "space" and kwargs.get("space_sdk"): - extra = {{"sdk": kwargs.get("space_sdk")}} - _register(repo_id, repo_type, _token(kwargs.get("token"), self), extra) - return result - - def _patched_upload_file(self, *args, **kwargs): - result = _original_upload_file(self, *args, **kwargs) - if not kwargs.get("create_pr"): - force = kwargs.get("path_in_repo") == "README.md" - _register( - kwargs.get("repo_id"), - _repo_type(kwargs), - _token(kwargs.get("token"), self), - force=force, - ) - return result - - def _patched_upload_folder(self, *args, **kwargs): - result = _original_upload_folder(self, *args, **kwargs) - if not kwargs.get("create_pr"): - _register( - kwargs.get("repo_id"), - _repo_type(kwargs), - _token(kwargs.get("token"), self), - force=True, - ) - return result - - def _patched_create_commit(self, *args, **kwargs): - result = _original_create_commit(self, *args, **kwargs) - if not kwargs.get("create_pr"): - _register( - _repo_id(args, kwargs), - _repo_type(kwargs), - _token(kwargs.get("token"), self), - force=True, - ) - return result - - HfApi.create_repo = _patched_create_repo - HfApi.upload_file = _patched_upload_file - if _original_upload_folder is not None: - HfApi.upload_folder = _patched_upload_folder - if _original_create_commit is not None: - HfApi.create_commit = _patched_create_commit - - def _patch_module_func(name, method_name): - original = getattr(_hub, name, None) - if original is None: - return - method = getattr(HfApi, method_name) - - def _patched(*args, **kwargs): - api = HfApi(token=_token(kwargs.get("token"))) - return method(api, *args, **kwargs) - - setattr(_hub, name, _patched) - - _patch_module_func("create_repo", "create_repo") - _patch_module_func("upload_file", "upload_file") - if _original_upload_folder is not None: - _patch_module_func("upload_folder", "upload_folder") - if _original_create_commit is not None: - _patch_module_func("create_commit", "create_commit") - - try: - _install_ml_intern_artifact_hooks() - except Exception: - pass - """ - ).strip() - + "\n" - ) - - -def wrap_shell_command_with_hub_artifact_bootstrap( - command: str, - session: Any, -) -> str: - """Prefix a shell command so child Python processes load Hub hooks.""" - sitecustomize = build_hub_artifact_sitecustomize(session) - if not sitecustomize or not command: - return command - - encoded = base64.b64encode(sitecustomize.encode("utf-8")).decode("ascii") - bootstrap = ( - '_ml_intern_artifacts_dir="$(mktemp -d 2>/dev/null)" ' - f"&& printf %s {shlex.quote(encoded)} | base64 -d " - '> "$_ml_intern_artifacts_dir/sitecustomize.py" ' - '&& export 
PYTHONPATH="$_ml_intern_artifacts_dir${PYTHONPATH:+:$PYTHONPATH}"' - ) - return f"{bootstrap}; {command}" diff --git a/agent/core/llm_params.py b/agent/core/llm_params.py deleted file mode 100644 index f95695fb88ff2d6664f3a5be357c97f8b83131d8..0000000000000000000000000000000000000000 --- a/agent/core/llm_params.py +++ /dev/null @@ -1,270 +0,0 @@ -"""LiteLLM kwargs resolution for the model ids this agent accepts. - -Kept separate from ``agent_loop`` so tools (research, context compaction, etc.) -can import it without pulling in the whole agent loop / tool router and -creating circular imports. -""" - -import os - -from agent.core.hf_tokens import get_hf_bill_to, resolve_hf_router_token -from agent.core.local_models import ( - LOCAL_MODEL_API_KEY_DEFAULT, - LOCAL_MODEL_API_KEY_ENV, - LOCAL_MODEL_BASE_URL_ENV, - is_reserved_local_model_id, - local_model_name, - local_model_provider, -) - - -def _resolve_hf_router_token(session_hf_token: str | None = None) -> str | None: - """Backward-compatible private wrapper used by tests and older imports.""" - return resolve_hf_router_token(session_hf_token) - - -def _patch_litellm_effort_validation() -> None: - """Neuter LiteLLM 1.83's hardcoded effort-level validation. - - Context: at ``litellm/llms/anthropic/chat/transformation.py:~1443`` the - Anthropic adapter validates ``output_config.effort ∈ {high, medium, - low, max}`` and gates ``max`` behind an ``_is_opus_4_6_model`` check - that only matches the substring ``opus-4-6`` / ``opus_4_6``. Result: - - * ``xhigh`` β€” valid on Anthropic's real API for Claude 4.7 β€” is - rejected pre-flight with "Invalid effort value: xhigh". - * ``max`` on Opus 4.7 is rejected with "effort='max' is only supported - by Claude Opus 4.6", even though Opus 4.7 accepts it in practice. - - We don't want to maintain a parallel model table, so we let the - Anthropic API itself be the validator: widen ``_is_opus_4_6_model`` - to also match ``opus-4-7``+ families, and drop the valid-effort-set - check entirely. If Anthropic rejects an effort level, we see a 400 - and the cascade walks down β€” exactly the behavior we want for any - future model family. - - Removable once litellm ships 1.83.8-stable (which merges PR #25867, - "Litellm day 0 opus 4.7 support") β€” see commit 0868a82 on their main - branch. Until then, this one-time patch is the escape hatch. - """ - try: - from litellm.llms.anthropic.chat import transformation as _t - except Exception: - return - - cfg = getattr(_t, "AnthropicConfig", None) - if cfg is None: - return - - original = getattr(cfg, "_is_opus_4_6_model", None) - if original is None or getattr(original, "_hf_agent_patched", False): - return - - def _widened(model: str) -> bool: - m = model.lower() - # Original 4.6 match plus any future Opus >= 4.6. We only need this - # to return True for families where "max" / "xhigh" are acceptable - # at the API; the cascade handles the case when they're not. - return any( - v in m - for v in ( - "opus-4-6", - "opus_4_6", - "opus-4.6", - "opus_4.6", - "opus-4-7", - "opus_4_7", - "opus-4.7", - "opus_4.7", - ) - ) - - _widened._hf_agent_patched = True # type: ignore[attr-defined] - cfg._is_opus_4_6_model = staticmethod(_widened) - - -_patch_litellm_effort_validation() - - -# Effort levels accepted on the wire. 
-# Anthropic (4.6+): low | medium | high | xhigh | max (output_config.effort) -# OpenAI direct: minimal | low | medium | high | xhigh (reasoning_effort top-level) -# HF router: low | medium | high (extra_body.reasoning_effort) -# -# We validate *shape* here and let the probe cascade walk down on rejection; -# we deliberately do NOT maintain a per-model capability table. -_ANTHROPIC_EFFORTS = {"low", "medium", "high", "xhigh", "max"} -_OPENAI_EFFORTS = {"minimal", "low", "medium", "high", "xhigh"} -_HF_EFFORTS = {"low", "medium", "high"} - - -class UnsupportedEffortError(ValueError): - """The requested effort isn't valid for this provider's API surface. - - Raised synchronously before any network call so the probe cascade can - skip levels the provider can't accept (e.g. ``max`` on HF router). - """ - - -def _local_api_base(base_url: str) -> str: - base = base_url.strip().rstrip("/") - if base.endswith("/v1"): - return base - return f"{base}/v1" - - -def _resolve_local_model_params( - model_name: str, - reasoning_effort: str | None = None, - strict: bool = False, -) -> dict: - if reasoning_effort and strict: - raise UnsupportedEffortError( - "Local OpenAI-compatible endpoints don't accept reasoning_effort" - ) - - local_name = local_model_name(model_name) - if local_name is None: - raise ValueError(f"Unsupported local model id: {model_name}") - - provider = local_model_provider(model_name) - assert provider is not None - raw_base = ( - os.environ.get(provider["base_url_env"]) - or os.environ.get(LOCAL_MODEL_BASE_URL_ENV) - or provider["base_url_default"] - ) - api_key = ( - os.environ.get(provider["api_key_env"]) - or os.environ.get(LOCAL_MODEL_API_KEY_ENV) - or LOCAL_MODEL_API_KEY_DEFAULT - ) - return { - "model": f"openai/{local_name}", - "api_base": _local_api_base(raw_base), - "api_key": api_key, - } - - -def _resolve_llm_params( - model_name: str, - session_hf_token: str | None = None, - reasoning_effort: str | None = None, - strict: bool = False, -) -> dict: - """ - Build LiteLLM kwargs for a given model id. - - β€’ ``anthropic/`` β€” native thinking config. We bypass LiteLLM's - ``reasoning_effort`` β†’ ``thinking`` mapping (which lags new Claude - releases like 4.7 and sends the wrong API shape). Instead we pass - both ``thinking={"type": "adaptive"}`` and ``output_config= - {"effort": }`` as top-level kwargs β€” LiteLLM's Anthropic - adapter forwards unknown top-level kwargs into the request body - verbatim (confirmed by live probe; ``extra_body`` does NOT work - here because Anthropic's API rejects it as "Extra inputs are not - permitted"). This is the stable API for 4.6 and 4.7. Older - extended-thinking models that only accept ``thinking.type.enabled`` - will reject this; the probe's cascade catches that and falls back - to no thinking. - - β€’ ``openai/`` β€” ``reasoning_effort`` forwarded as a top-level - kwarg (GPT-5 / o-series). LiteLLM uses the user's ``OPENAI_API_KEY``. - - β€’ ``ollama/``, ``vllm/``, ``lm_studio/``, and - ``llamacpp/`` β€” local OpenAI-compatible endpoints. The id prefix - selects a configurable localhost base URL, and the model suffix is sent - to LiteLLM as ``openai/``. These endpoints don't receive - ``reasoning_effort``. - - β€’ Anything else is treated as a HuggingFace router id. We hit the - auto-routing OpenAI-compatible endpoint at - ``https://router.huggingface.co/v1``. The id can be bare or carry an - HF routing suffix (``:fastest`` / ``:cheapest`` / ``:``). - A leading ``huggingface/`` is stripped. 
``reasoning_effort`` is - forwarded via ``extra_body`` (LiteLLM's OpenAI adapter refuses it as - a top-level kwarg for non-OpenAI models). "minimal" normalizes to - "low". - - ``strict=True`` raises ``UnsupportedEffortError`` when the requested - effort isn't in the provider's accepted set, instead of silently - dropping it. The probe cascade uses strict mode so it can walk down - (``max`` β†’ ``xhigh`` β†’ ``high`` …) without making an API call. Regular - runtime callers leave ``strict=False``, so a stale cached effort - can't crash a turn β€” it just doesn't get sent. - - Token precedence (first non-empty wins): - 1. INFERENCE_TOKEN env β€” shared key on the hosted Space (inference is - free for users, billed to the Space owner via ``X-HF-Bill-To``). - 2. session.hf_token β€” the user's own token (CLI / OAuth / cache file). - 3. huggingface_hub cache β€” ``HF_TOKEN`` / ``HUGGING_FACE_HUB_TOKEN`` / - local ``hf auth login`` cache. - """ - if model_name.startswith("anthropic/"): - params: dict = {"model": model_name} - if reasoning_effort: - level = reasoning_effort - if level == "minimal": - level = "low" - if level not in _ANTHROPIC_EFFORTS: - if strict: - raise UnsupportedEffortError( - f"Anthropic doesn't accept effort={level!r}" - ) - else: - # Adaptive thinking + output_config.effort is the stable - # Anthropic API for Claude 4.6 / 4.7. Both kwargs are - # passed top-level: LiteLLM forwards unknown params into - # the request body for Anthropic, so ``output_config`` - # reaches the API. ``extra_body`` does NOT work here β€” - # Anthropic rejects it as "Extra inputs are not - # permitted". - params["thinking"] = {"type": "adaptive"} - params["output_config"] = {"effort": level} - return params - - if model_name.startswith("bedrock/"): - # LiteLLM routes ``bedrock/...`` through the Converse adapter, which - # picks up AWS credentials from the standard env vars - # (``AWS_ACCESS_KEY_ID`` / ``AWS_SECRET_ACCESS_KEY`` / ``AWS_REGION``). - # The Anthropic thinking/effort shape is not forwarded through Converse - # the same way, so we leave it off for now. 
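The ``strict=True`` contract described above is what lets the probe cascade rule out impossible levels without a network round-trip. A minimal sketch of such a walk-down (``first_shape_valid_effort`` is a hypothetical helper; the real cascade lives in ``agent.core.effort_probe``):

    _CASCADE = ["max", "xhigh", "high", "medium", "low"]

    def first_shape_valid_effort(model_id: str, preferred: str) -> str | None:
        # Strict mode raises UnsupportedEffortError synchronously for
        # shapes the provider's API surface can't accept (e.g. "max" on
        # the HF router), so this loop never makes an API call.
        start = _CASCADE.index(preferred) if preferred in _CASCADE else 0
        for level in _CASCADE[start:]:
            try:
                _resolve_llm_params(model_id, reasoning_effort=level, strict=True)
            except UnsupportedEffortError:
                continue
            return level  # shape-valid; the live 1-token probe has the final say
        return None  # send no thinking params at all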
- return {"model": model_name} - - if model_name.startswith("openai/"): - params = {"model": model_name} - if reasoning_effort: - if reasoning_effort not in _OPENAI_EFFORTS: - if strict: - raise UnsupportedEffortError( - f"OpenAI doesn't accept effort={reasoning_effort!r}" - ) - else: - params["reasoning_effort"] = reasoning_effort - return params - - if is_reserved_local_model_id(model_name): - raise ValueError(f"Unsupported local model id: {model_name}") - - if local_model_provider(model_name) is not None: - return _resolve_local_model_params(model_name, reasoning_effort, strict) - - hf_model = model_name.removeprefix("huggingface/") - api_key = _resolve_hf_router_token(session_hf_token) - params = { - "model": f"openai/{hf_model}", - "api_base": "https://router.huggingface.co/v1", - "api_key": api_key, - } - if bill_to := get_hf_bill_to(): - params["extra_headers"] = {"X-HF-Bill-To": bill_to} - if reasoning_effort: - hf_level = "low" if reasoning_effort == "minimal" else reasoning_effort - if hf_level not in _HF_EFFORTS: - if strict: - raise UnsupportedEffortError( - f"HF router doesn't accept effort={hf_level!r}" - ) - else: - params["extra_body"] = {"reasoning_effort": hf_level} - return params diff --git a/agent/core/local_models.py b/agent/core/local_models.py deleted file mode 100644 index 9f8a9491d635dd3892388ebfdd0f8384ac78144f..0000000000000000000000000000000000000000 --- a/agent/core/local_models.py +++ /dev/null @@ -1,59 +0,0 @@ -"""Helpers for CLI local OpenAI-compatible model ids.""" - -LOCAL_MODEL_PROVIDERS: dict[str, dict[str, str]] = { - "ollama/": { - "base_url_env": "OLLAMA_BASE_URL", - "base_url_default": "http://localhost:11434", - "api_key_env": "OLLAMA_API_KEY", - }, - "vllm/": { - "base_url_env": "VLLM_BASE_URL", - "base_url_default": "http://localhost:8000", - "api_key_env": "VLLM_API_KEY", - }, - "lm_studio/": { - "base_url_env": "LMSTUDIO_BASE_URL", - "base_url_default": "http://127.0.0.1:1234", - "api_key_env": "LMSTUDIO_API_KEY", - }, - "llamacpp/": { - "base_url_env": "LLAMACPP_BASE_URL", - "base_url_default": "http://localhost:8080", - "api_key_env": "LLAMACPP_API_KEY", - }, -} - -LOCAL_MODEL_PREFIXES = tuple(LOCAL_MODEL_PROVIDERS) -RESERVED_LOCAL_MODEL_PREFIXES = ("openai-compat/",) -LOCAL_MODEL_BASE_URL_ENV = "LOCAL_LLM_BASE_URL" -LOCAL_MODEL_API_KEY_ENV = "LOCAL_LLM_API_KEY" -LOCAL_MODEL_API_KEY_DEFAULT = "sk-local-no-key-required" - - -def local_model_provider(model_id: str) -> dict[str, str] | None: - """Return provider config for a local model id, if it uses a local prefix.""" - for prefix, config in LOCAL_MODEL_PROVIDERS.items(): - if model_id.startswith(prefix): - return config - return None - - -def local_model_name(model_id: str) -> str | None: - """Return the backend model name with the local provider prefix removed.""" - for prefix in LOCAL_MODEL_PREFIXES: - if model_id.startswith(prefix): - name = model_id[len(prefix) :] - return name or None - return None - - -def is_local_model_id(model_id: str) -> bool: - """Return True for non-empty, whitespace-free local model ids.""" - if not model_id or any(char.isspace() for char in model_id): - return False - return local_model_name(model_id) is not None - - -def is_reserved_local_model_id(model_id: str) -> bool: - """Return True for local-style prefixes intentionally not supported.""" - return model_id.startswith(RESERVED_LOCAL_MODEL_PREFIXES) diff --git a/agent/core/model_switcher.py b/agent/core/model_switcher.py deleted file mode 100644 index 
34eaccdd1f127253bec68b4ccdd1159c7a3c4a0a..0000000000000000000000000000000000000000 --- a/agent/core/model_switcher.py +++ /dev/null @@ -1,292 +0,0 @@ -"""Model-switching logic for the interactive CLI's ``/model`` command. - -Split out of ``agent.main`` so the REPL dispatcher stays focused on input -parsing. Exposes: - -* ``SUGGESTED_MODELS`` β€” the short list shown by ``/model`` with no arg. -* ``is_valid_model_id`` β€” loose format check on user input. -* ``probe_and_switch_model`` β€” async: checks routing, fires a 1-token - probe to resolve the effort cascade, then commits the switch (or - rejects it on hard error). - -The probe's cascade lives in ``agent.core.effort_probe``; this module -glues it to CLI output + session state. -""" - -from __future__ import annotations - -import asyncio - -from litellm import acompletion - -from agent.core.effort_probe import ProbeInconclusive, probe_effort -from agent.core.llm_params import _resolve_llm_params -from agent.core.local_models import ( - LOCAL_MODEL_PREFIXES, - is_local_model_id, - is_reserved_local_model_id, -) - - -# Suggested models shown by `/model` (not a gate). Users can paste any HF -# model id (e.g. "MiniMaxAI/MiniMax-M2.7") or an `anthropic/` / `openai/` -# prefix for direct API access. For HF ids, append ":fastest" / -# ":cheapest" / ":preferred" / ":" to override the default -# routing policy (auto = fastest with failover). -SUGGESTED_MODELS = [ - {"id": "openai/gpt-5.5", "label": "GPT-5.5"}, - {"id": "openai/gpt-5.4", "label": "GPT-5.4"}, - {"id": "anthropic/claude-opus-4-7", "label": "Claude Opus 4.7"}, - {"id": "anthropic/claude-opus-4-6", "label": "Claude Opus 4.6"}, - { - "id": "bedrock/us.anthropic.claude-opus-4-6-v1", - "label": "Claude Opus 4.6 via Bedrock", - }, - {"id": "MiniMaxAI/MiniMax-M2.7", "label": "MiniMax M2.7"}, - {"id": "moonshotai/Kimi-K2.6", "label": "Kimi K2.6"}, - {"id": "zai-org/GLM-5.1", "label": "GLM 5.1"}, - {"id": "deepseek-ai/DeepSeek-V4-Pro:deepinfra", "label": "DeepSeek V4 Pro"}, -] - - -_ROUTING_POLICIES = {"fastest", "cheapest", "preferred"} -_DIRECT_PREFIXES = ("anthropic/", "openai/", *LOCAL_MODEL_PREFIXES) -_LOCAL_PROBE_TIMEOUT = 15.0 - - -def is_valid_model_id(model_id: str) -> bool: - """Loose format check β€” lets users pick any model id. - - Accepts: - β€’ anthropic/ - β€’ openai/ - β€’ ollama/, vllm/, lm_studio/, llamacpp/ - β€’ /[:] (HF router; tag = provider or policy) - β€’ huggingface//[:] (same, accepts legacy prefix) - - Actual availability is verified against the HF router catalog on - switch, and by the provider on the probe's ping call. - """ - if not model_id: - return False - if is_local_model_id(model_id): - return True - if is_reserved_local_model_id(model_id): - return False - if any(model_id.startswith(prefix) for prefix in LOCAL_MODEL_PREFIXES): - return False - if "/" not in model_id: - return False - head = model_id.split(":", 1)[0] - parts = head.split("/") - return len(parts) >= 2 and all(parts) - - -def _print_hf_routing_info(model_id: str, console) -> bool: - """Show HF router catalog info (providers, price, context, tool support) - for an HF-router model id. Returns ``True`` to signal the caller can - proceed with the switch, ``False`` to indicate a hard problem the user - should notice before we fire the effort probe. - - Anthropic / OpenAI ids return ``True`` without printing anything β€” - the probe below covers "does this model exist". 
- """ - if model_id.startswith(_DIRECT_PREFIXES): - return True - - from agent.core import hf_router_catalog as cat - - bare, _, tag = model_id.partition(":") - info = cat.lookup(bare) - if info is None: - console.print( - f"[bold red]Warning:[/bold red] '{bare}' isn't in the HF router " - "catalog. Checking anyway β€” first call may fail." - ) - suggestions = cat.fuzzy_suggest(bare) - if suggestions: - console.print(f"[dim]Did you mean: {', '.join(suggestions)}[/dim]") - return True - - live = info.live_providers - if not live: - console.print( - f"[bold red]Warning:[/bold red] '{bare}' has no live providers " - "right now. First call will likely fail." - ) - return True - - if tag and tag not in _ROUTING_POLICIES: - matched = [p for p in live if p.provider == tag] - if not matched: - names = ", ".join(p.provider for p in live) - console.print( - f"[bold red]Warning:[/bold red] provider '{tag}' doesn't serve " - f"'{bare}'. Live providers: {names}. Checking anyway." - ) - - if not info.any_supports_tools: - console.print( - f"[bold red]Warning:[/bold red] no provider for '{bare}' advertises " - "tool-call support. This agent relies on tool calls β€” expect errors." - ) - - if tag in _ROUTING_POLICIES: - policy = tag - elif tag: - policy = f"pinned to {tag}" - else: - policy = "auto (fastest)" - console.print(f" [dim]routing: {policy}[/dim]") - for p in live: - price = ( - f"${p.input_price:g}/${p.output_price:g} per M tok" - if p.input_price is not None and p.output_price is not None - else "price n/a" - ) - ctx = f"{p.context_length:,} ctx" if p.context_length else "ctx n/a" - tools = "tools" if p.supports_tools else "no tools" - console.print(f" [dim]{p.provider}: {price}, {ctx}, {tools}[/dim]") - return True - - -def print_model_listing(config, console) -> None: - """Render the default ``/model`` (no-arg) view: current + suggested.""" - current = config.model_name if config else "" - console.print("[bold]Current model:[/bold]") - console.print(f" {current}") - console.print("\n[bold]Suggested:[/bold]") - for m in SUGGESTED_MODELS: - marker = " [dim]<-- current[/dim]" if m["id"] == current else "" - console.print(f" {m['id']} [dim]({m['label']})[/dim]{marker}") - console.print( - "\n[dim]Paste any HF model id (e.g. 'MiniMaxAI/MiniMax-M2.7').\n" - "Add ':fastest', ':cheapest', ':preferred', or ':' to override routing.\n" - "Use 'anthropic/' or 'openai/' for direct API access.\n" - "Use 'ollama/', 'vllm/', 'lm_studio/', or " - "'llamacpp/' for local OpenAI-compatible endpoints.[/dim]" - ) - - -def print_invalid_id(arg: str, console) -> None: - console.print(f"[bold red]Invalid model id format:[/bold red] {arg}") - console.print( - "[dim]Expected:\n" - " β€’ /[:tag] (HF router β€” paste from huggingface.co)\n" - " β€’ anthropic/\n" - " β€’ openai/\n" - " β€’ ollama/ | vllm/ | lm_studio/ | llamacpp/[/dim]" - ) - - -async def _probe_local_model(model_id: str) -> None: - params = _resolve_llm_params(model_id) - await asyncio.wait_for( - acompletion( - messages=[{"role": "user", "content": "ping"}], - max_tokens=1, - stream=False, - **params, - ), - timeout=_LOCAL_PROBE_TIMEOUT, - ) - - -async def probe_and_switch_model( - model_id: str, - config, - session, - console, - hf_token: str | None, -) -> None: - """Validate model+effort with a 1-token ping, cache the effective effort, - then commit the switch. 
- - Three visible outcomes: - - * βœ“ ``effort: `` β€” model accepted the preferred effort (or a - fallback from the cascade; the note explains if so) - * βœ“ ``effort: off`` β€” model doesn't support thinking; we'll strip it - * βœ— hard error (auth, model-not-found, quota) β€” we reject the switch - and keep the current model so the user isn't stranded - - For non-local models, transient errors (5xx, timeout) complete the switch - with a yellow warning; the next real call re-surfaces the error if it's - persistent. Local models reject every probe error, including timeouts, and - keep the current model. - """ - if is_local_model_id(model_id): - console.print(f"[dim]checking local model {model_id}...[/dim]") - try: - await _probe_local_model(model_id) - except Exception as e: - console.print(f"[bold red]Switch failed:[/bold red] {e}") - console.print(f"[dim]Keeping current model: {config.model_name}[/dim]") - return - - _commit_switch(model_id, config, session, effective=None, cache=True) - console.print( - f"[green]Model switched to {model_id}[/green] [dim](effort: off)[/dim]" - ) - return - - preference = config.reasoning_effort - if not _print_hf_routing_info(model_id, console): - return - - if not preference: - # Nothing to validate with a ping that we couldn't validate on the - # first real call just as cheaply. Skip the probe entirely. - _commit_switch(model_id, config, session, effective=None, cache=False) - console.print( - f"[green]Model switched to {model_id}[/green] [dim](effort: off)[/dim]" - ) - return - - console.print(f"[dim]checking {model_id} (effort: {preference})...[/dim]") - try: - outcome = await probe_effort(model_id, preference, hf_token, session=session) - except ProbeInconclusive as e: - _commit_switch(model_id, config, session, effective=None, cache=False) - console.print( - f"[yellow]Model switched to {model_id}[/yellow] " - f"[dim](couldn't validate: {e}; will verify on first message)[/dim]" - ) - return - except Exception as e: - # Hard persistent error β€” auth, unknown model, quota. Don't switch. - console.print(f"[bold red]Switch failed:[/bold red] {e}") - console.print(f"[dim]Keeping current model: {config.model_name}[/dim]") - return - - _commit_switch( - model_id, - config, - session, - effective=outcome.effective_effort, - cache=True, - ) - effort_label = outcome.effective_effort or "off" - suffix = f" β€” {outcome.note}" if outcome.note else "" - console.print( - f"[green]Model switched to {model_id}[/green] " - f"[dim](effort: {effort_label}{suffix}, {outcome.elapsed_ms}ms)[/dim]" - ) - - -def _commit_switch(model_id, config, session, effective, cache: bool) -> None: - """Apply the switch to the session (or bare config if no session yet). - - ``effective`` is the probe's resolved effort; ``cache=True`` stores it - in the session's per-model cache so real calls use the resolved level - instead of re-probing. ``cache=False`` (inconclusive probe / effort - off) leaves the cache untouched β€” next call falls back to preference. 
- """ - if session is not None: - session.update_model(model_id) - if cache: - session.model_effective_effort[model_id] = effective - else: - session.model_effective_effort.pop(model_id, None) - else: - config.model_name = model_id diff --git a/agent/core/prompt_caching.py b/agent/core/prompt_caching.py deleted file mode 100644 index b30edd9fc4845738c08e972fdab712bf2ae3988d..0000000000000000000000000000000000000000 --- a/agent/core/prompt_caching.py +++ /dev/null @@ -1,65 +0,0 @@ -"""Anthropic prompt caching breakpoints for outgoing LLM requests. - -Caching is GA on Anthropic's API and natively supported by litellm >=1.83 -via ``cache_control`` blocks. We apply two breakpoints (out of 4 allowed): - - 1. The tool block β€” caches all tool definitions as a single prefix. - 2. The system message β€” caches the rendered system prompt. - -Together these cover the ~4-5K static tokens that were being re-billed on -every turn. Subsequent turns within the 5-minute TTL hit cache_read pricing -(~10% of input cost) instead of full input. - -Non-Anthropic models (HF router, OpenAI) are passed through unchanged. -""" - -from typing import Any - - -def with_prompt_caching( - messages: list[Any], - tools: list[dict] | None, - model_name: str | None, -) -> tuple[list[Any], list[dict] | None]: - """Return (messages, tools) with cache_control breakpoints for Anthropic. - - No-op for non-Anthropic models. Original objects are not mutated; a fresh - list with replaced first message and last tool is returned, so callers - that share the underlying ``ContextManager.items`` list don't see their - persisted history rewritten. - """ - if not model_name or "anthropic" not in model_name: - return messages, tools - - if tools: - new_tools = list(tools) - last = dict(new_tools[-1]) - last["cache_control"] = {"type": "ephemeral"} - new_tools[-1] = last - tools = new_tools - - if messages: - first = messages[0] - role = ( - first.get("role") - if isinstance(first, dict) - else getattr(first, "role", None) - ) - if role == "system": - content = ( - first.get("content") - if isinstance(first, dict) - else getattr(first, "content", None) - ) - if isinstance(content, str) and content: - cached_block = [ - { - "type": "text", - "text": content, - "cache_control": {"type": "ephemeral"}, - } - ] - new_first = {"role": "system", "content": cached_block} - messages = [new_first] + list(messages[1:]) - - return messages, tools diff --git a/agent/core/redact.py b/agent/core/redact.py deleted file mode 100644 index 8978942c8a027b56e51acdd0f6485d4e9e0fbbf2..0000000000000000000000000000000000000000 --- a/agent/core/redact.py +++ /dev/null @@ -1,68 +0,0 @@ -"""Secret scrubbing for session trajectories before upload. - -Users frequently paste HF / API / GitHub tokens into the chat, or scripts echo -them via env dumps. This module applies regex-based redaction to any string -value found recursively in a trajectory payload. The goal is best-effort β€” -strict formats are matched; we won't catch free-form leaks like "my password -is hunter2". -""" - -from __future__ import annotations - -import re -from typing import Any - -# Each entry: (compiled regex, replacement placeholder). -# Patterns are conservative: they only match tokens with the canonical prefix -# and a minimum body length so we don't paint over normal text. 
-_PATTERNS: list[tuple[re.Pattern, str]] = [ - # Hugging Face tokens: hf_[A-Za-z0-9]{30,} - (re.compile(r"hf_[A-Za-z0-9]{30,}"), "[REDACTED_HF_TOKEN]"), - # Anthropic: sk-ant-[A-Za-z0-9_\-]{20,} - (re.compile(r"sk-ant-[A-Za-z0-9_\-]{20,}"), "[REDACTED_ANTHROPIC_KEY]"), - # OpenAI: sk-[A-Za-z0-9]{40,} (legacy + proj keys) - (re.compile(r"sk-(?!ant-)[A-Za-z0-9_\-]{40,}"), "[REDACTED_OPENAI_KEY]"), - # GitHub classic PATs: ghp_, gho_, ghu_, ghs_, ghr_ followed by 36+ chars - (re.compile(r"gh[pousr]_[A-Za-z0-9]{36,}"), "[REDACTED_GITHUB_TOKEN]"), - # GitHub fine-grained PATs: github_pat_ - (re.compile(r"github_pat_[A-Za-z0-9_]{36,}"), "[REDACTED_GITHUB_TOKEN]"), - # AWS access key IDs: AKIA / ASIA + 16 uppercase alnum - (re.compile(r"\b(?:AKIA|ASIA)[A-Z0-9]{16}\b"), "[REDACTED_AWS_KEY_ID]"), - # Generic 'Bearer ' header values - (re.compile(r"(?i)bearer\s+[A-Za-z0-9_\-\.=]{20,}"), "Bearer [REDACTED]"), -] - -# Env-var-like exports: we scrub the value but keep the name so callers can -# still see which secret was referenced. Covers `KEY=value` and `KEY: value` -# when the key looks secret-y. -_SECRETY_NAMES = re.compile( - r"(?i)\b(HF_TOKEN|HUGGINGFACEHUB_API_TOKEN|ANTHROPIC_API_KEY|OPENAI_API_KEY|" - r"GITHUB_TOKEN|AWS_SECRET_ACCESS_KEY|AWS_ACCESS_KEY_ID|PASSWORD|SECRET|API_KEY)" - r"\s*[:=]\s*([^\s\"']+)" -) - - -def scrub_string(s: str) -> str: - """Apply all redaction patterns to a single string. Safe on non-strings.""" - if not isinstance(s, str) or not s: - return s - out = s - for pat, repl in _PATTERNS: - out = pat.sub(repl, out) - out = _SECRETY_NAMES.sub(lambda m: f"{m.group(1)}=[REDACTED]", out) - return out - - -def scrub(obj: Any) -> Any: - """Recursively scrub every string value in a nested dict/list structure. - - Returns a new object β€” inputs are not mutated.""" - if isinstance(obj, str): - return scrub_string(obj) - if isinstance(obj, dict): - return {k: scrub(v) for k, v in obj.items()} - if isinstance(obj, list): - return [scrub(v) for v in obj] - if isinstance(obj, tuple): - return tuple(scrub(v) for v in obj) - return obj diff --git a/agent/core/session.py b/agent/core/session.py index fb08c75f8a4d5aff0e86cdfb7bb276d585b93176..78e9fa7d7b32d39e41bbbc5a8c1ac02f069c238d 100644 --- a/agent/core/session.py +++ b/agent/core/session.py @@ -1,7 +1,6 @@ import asyncio import json import logging -import os import subprocess import sys import uuid @@ -13,45 +12,47 @@ from typing import Any, Optional from agent.config import Config from agent.context_manager.manager import ContextManager -from agent.messaging.gateway import NotificationGateway -from agent.messaging.models import NotificationRequest logger = logging.getLogger(__name__) +# Local max-token lookup β€” avoids litellm.get_max_tokens() which can hang +# on network calls for certain providers (known litellm issue). 
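One caveat on the fallback further below: a plain ``try/except`` around ``litellm.get_max_tokens`` catches failures but cannot, by itself, bound a call that hangs. A sketch of a genuinely time-boxed lookup (hypothetical helper, not part of this diff; the single-worker pool means a stuck call delays later lookups instead of blocking session startup):

    import concurrent.futures

    _LOOKUP_POOL = concurrent.futures.ThreadPoolExecutor(max_workers=1)

    def _get_max_tokens_with_timeout(model: str, timeout_s: float = 2.0) -> int | None:
        from litellm import get_max_tokens

        future = _LOOKUP_POOL.submit(get_max_tokens, model)
        try:
            result = future.result(timeout=timeout_s)
        except Exception:
            return None  # timeout or lookup error; caller falls back to the default
        return result if isinstance(result, int) and result > 0 else None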
+_MAX_TOKENS_MAP: dict[str, int] = { + # Anthropic + "anthropic/claude-opus-4-6": 200_000, + "anthropic/claude-opus-4-5-20251101": 200_000, + "anthropic/claude-sonnet-4-5-20250929": 200_000, + "anthropic/claude-sonnet-4-20250514": 200_000, + "anthropic/claude-haiku-3-5-20241022": 200_000, + "anthropic/claude-3-5-sonnet-20241022": 200_000, + "anthropic/claude-3-opus-20240229": 200_000, + "huggingface/fireworks-ai/MiniMaxAI/MiniMax-M2.5": 200_000, + "huggingface/novita/minimax/minimax-m2.1": 196_608, + "huggingface/novita/moonshotai/kimi-k2.5": 262_144, + "huggingface/novita/zai-org/glm-5": 200_000, +} _DEFAULT_MAX_TOKENS = 200_000 -_TURN_COMPLETE_NOTIFICATION_CHARS = 39000 def _get_max_tokens_safe(model_name: str) -> int: - """Return the max input-context tokens for a model. - - Primary source: ``litellm.get_model_info(model)['max_input_tokens']`` β€” - LiteLLM maintains an upstream catalog that knows Claude Opus 4.6 is - 1M, GPT-5 is 272k, Sonnet 4.5 is 200k, and so on. Strips any HF routing - suffix / huggingface/ prefix so tagged ids ('moonshotai/Kimi-K2.6:cheapest') - look up the bare model. Falls back to a conservative 200k default for - models not in the catalog (typically HF-router-only models). - """ - from litellm import get_model_info - - candidates = [model_name] - stripped = model_name.removeprefix("huggingface/").split(":", 1)[0] - if stripped != model_name: - candidates.append(stripped) - for candidate in candidates: - try: - info = get_model_info(candidate) - max_input = info.get("max_input_tokens") if info else None - if isinstance(max_input, int) and max_input > 0: - return max_input - except Exception: - continue - logger.info( - "No litellm.get_model_info entry for %s, falling back to %d", - model_name, - _DEFAULT_MAX_TOKENS, - ) - return _DEFAULT_MAX_TOKENS + """Return the max context window for a model without network calls.""" + tokens = _MAX_TOKENS_MAP.get(model_name) + if tokens: + return tokens + # Fallback: try litellm but with a short timeout via threading + try: + from litellm import get_max_tokens + + result = get_max_tokens(model_name) + if result and isinstance(result, int): + return result + logger.warning( + f"get_max_tokens returned {result} for {model_name}, using default" + ) + return _DEFAULT_MAX_TOKENS + except Exception as e: + logger.warning(f"get_max_tokens failed for {model_name}, using default: {e}") + return _DEFAULT_MAX_TOKENS class OpType(Enum): @@ -67,7 +68,6 @@ class OpType(Enum): class Event: event_type: str data: Optional[dict[str, Any]] = None - seq: Optional[int] = None class Session: @@ -79,31 +79,19 @@ class Session: def __init__( self, event_queue: asyncio.Queue, - config: Config, + config: Config | None = None, tool_router=None, context_manager: ContextManager | None = None, hf_token: str | None = None, local_mode: bool = False, stream: bool = True, - notification_gateway: NotificationGateway | None = None, - notification_destinations: list[str] | None = None, - defer_turn_complete_notification: bool = False, - session_id: str | None = None, - user_id: str | None = None, - hf_username: str | None = None, - persistence_store: Any | None = None, ): self.hf_token: Optional[str] = hf_token - self.user_id: Optional[str] = user_id - self.hf_username: Optional[str] = hf_username - self.persistence_store = persistence_store self.tool_router = tool_router self.stream = stream - if config is None: - raise ValueError("Session requires a Config") tool_specs = tool_router.get_tool_specs_for_llm() if tool_router else [] self.context_manager = 
context_manager or ContextManager( - model_max_tokens=_get_max_tokens_safe(config.model_name), + max_context=_get_max_tokens_safe(config.model_name), compact_size=0.1, untouched_messages=5, tool_specs=tool_specs, @@ -111,48 +99,26 @@ class Session: local_mode=local_mode, ) self.event_queue = event_queue - self.session_id = session_id or str(uuid.uuid4()) - self.config = config + self.session_id = str(uuid.uuid4()) + self.config = config or Config( + model_name="anthropic/claude-sonnet-4-5-20250929", + ) self.is_running = True self._cancelled = asyncio.Event() self.pending_approval: Optional[dict[str, Any]] = None self.sandbox = None - self.sandbox_hardware: Optional[str] = None - self.sandbox_preload_task: Optional[asyncio.Task] = None - self.sandbox_preload_error: Optional[str] = None - self.sandbox_preload_cancel_event: Any | None = None self._running_job_ids: set[str] = set() # HF job IDs currently executing - self.notification_gateway = notification_gateway - self.notification_destinations = list(notification_destinations or []) - self.defer_turn_complete_notification = defer_turn_complete_notification - self.auto_approval_enabled: bool = False - self.auto_approval_cost_cap_usd: float | None = None - self.auto_approval_estimated_spend_usd: float = 0.0 # Session trajectory logging self.logged_events: list[dict] = [] self.session_start_time = datetime.now().isoformat() self.turn_count: int = 0 self.last_auto_save_turn: int = 0 - # Stable local save path so heartbeat saves overwrite one file instead - # of spamming session_logs/. ``_last_heartbeat_ts`` is owned by - # ``agent.core.telemetry.HeartbeatSaver`` and lazily initialised there. - self._local_save_path: Optional[str] = None - self._last_heartbeat_ts: Optional[float] = None - - # Per-model probed reasoning-effort cache. Populated by the probe - # on /model switch, read by ``effective_effort_for`` below. Keys are - # raw model ids (including any ``:tag``). Values: - # str β†’ the effort level to send (may be a downgrade from the - # preference, e.g. "high" when user asked for "max") - # None β†’ model rejected all efforts in the cascade; send no - # thinking params at all - # Key absent β†’ not probed yet; fall back to the raw preference. - self.model_effective_effort: dict[str, str | None] = {} - self.context_manager.on_message_added = self._schedule_trace_message async def send_event(self, event: Event) -> None: """Send event back to client and log to trajectory""" + await self.event_queue.put(event) + # Log event to trajectory self.logged_events.append( { @@ -161,147 +127,6 @@ class Session: "data": event.data, } ) - if self.persistence_store is not None: - try: - event.seq = await self.persistence_store.append_event( - self.session_id, event.event_type, event.data - ) - except Exception as e: - logger.debug("Event persistence failed for %s: %s", self.session_id, e) - - await self.event_queue.put(event) - await self._enqueue_auto_notification_requests(event) - - # Mid-turn heartbeat flush (owned by telemetry module). 
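``HeartbeatSaver`` itself is not part of this diff; a plausible sketch of a throttled ``maybe_fire``, assuming a fixed interval (the real implementation, including the actual interval, lives in ``agent.core.telemetry``):

    import time

    class HeartbeatSaver:
        INTERVAL_S = 60.0  # assumed interval, not taken from the source

        @staticmethod
        def maybe_fire(session) -> None:
            # Throttle on monotonic time so wall-clock jumps can't
            # suppress or double-fire saves; the save itself reuses the
            # session's stable local path via save_trajectory_local.
            now = time.monotonic()
            last = session._last_heartbeat_ts
            if last is not None and now - last < HeartbeatSaver.INTERVAL_S:
                return
            session._last_heartbeat_ts = now
            session.save_trajectory_local(upload_status="pending")

The flush call that consumes this hook follows just below.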
- from agent.core.telemetry import HeartbeatSaver - - HeartbeatSaver.maybe_fire(self) - - def _schedule_trace_message(self, message: Any) -> None: - """Best-effort append-only trace save for SFT/KPI export.""" - if self.persistence_store is None: - return - try: - payload = message.model_dump(mode="json") - except Exception: - return - try: - loop = asyncio.get_running_loop() - except RuntimeError: - return - source = str(payload.get("role") or "message") - loop.create_task( - self.persistence_store.append_trace_message( - self.session_id, payload, source=source - ) - ) - - def set_notification_destinations(self, destinations: list[str]) -> None: - """Replace the session's opted-in auto-notification destinations.""" - deduped: list[str] = [] - seen: set[str] = set() - for destination in destinations: - if destination not in seen: - deduped.append(destination) - seen.add(destination) - self.notification_destinations = deduped - - async def send_deferred_turn_complete_notification(self, event: Event) -> None: - if event.event_type != "turn_complete": - return - await self._enqueue_auto_notification_requests( - event, - include_deferred_turn_complete=True, - ) - - async def _enqueue_auto_notification_requests( - self, - event: Event, - include_deferred_turn_complete: bool = False, - ) -> None: - if self.notification_gateway is None: - return - if not self.notification_destinations: - return - auto_events = set(self.config.messaging.auto_event_types) - if event.event_type not in auto_events: - return - if ( - self.defer_turn_complete_notification - and event.event_type == "turn_complete" - and not include_deferred_turn_complete - ): - return - - requests = self._build_auto_notification_requests(event) - for request in requests: - await self.notification_gateway.enqueue(request) - - def _build_auto_notification_requests( - self, event: Event - ) -> list[NotificationRequest]: - metadata = { - "session_id": self.session_id, - "model": self.config.model_name, - "event_type": event.event_type, - } - - title: str | None = None - message: str | None = None - severity = "info" - data = event.data or {} - if event.event_type == "approval_required": - tools = data.get("tools", []) - tool_names = [] - for tool in tools if isinstance(tools, list) else []: - if isinstance(tool, dict): - tool_name = str(tool.get("tool") or "").strip() - if tool_name and tool_name not in tool_names: - tool_names.append(tool_name) - count = len(tools) if isinstance(tools, list) else 0 - title = "Agent approval required" - message = ( - f"Session {self.session_id} is waiting for approval " - f"for {count} tool call(s)." - ) - if tool_names: - message += " Tools: " + ", ".join(tool_names) - severity = "warning" - elif event.event_type == "error": - title = "Agent error" - error = str(data.get("error") or "Unknown error") - message = f"Session {self.session_id} hit an error.\n{error[:500]}" - severity = "error" - elif event.event_type == "turn_complete": - title = "Agent task complete" - summary = str(data.get("final_response") or "").strip() - if summary: - summary = summary[:_TURN_COMPLETE_NOTIFICATION_CHARS] - message = ( - f"Session {self.session_id} completed successfully.\n{summary}" - ) - else: - message = f"Session {self.session_id} completed successfully." 
- severity = "success" - - if message is None: - return [] - - requests: list[NotificationRequest] = [] - for destination in self.notification_destinations: - if not self.config.messaging.can_auto_send(destination): - continue - requests.append( - NotificationRequest( - destination=destination, - title=title, - message=message, - severity=severity, - metadata=metadata, - event_type=event.event_type, - ) - ) - return requests def cancel(self) -> None: """Signal cancellation to the running agent loop.""" @@ -318,54 +143,7 @@ class Session: def update_model(self, model_name: str) -> None: """Switch the active model and update the context window limit.""" self.config.model_name = model_name - self.context_manager.model_max_tokens = _get_max_tokens_safe(model_name) - - def set_auto_approval_policy( - self, *, enabled: bool, cost_cap_usd: float | None - ) -> None: - self.auto_approval_enabled = bool(enabled) - self.auto_approval_cost_cap_usd = cost_cap_usd - - def add_auto_approval_estimated_spend(self, amount_usd: float | None) -> None: - if amount_usd is None or amount_usd <= 0: - return - self.auto_approval_estimated_spend_usd = round( - self.auto_approval_estimated_spend_usd + float(amount_usd), 4 - ) - - @property - def auto_approval_remaining_usd(self) -> float | None: - if self.auto_approval_cost_cap_usd is None: - return None - return round( - max( - 0.0, - self.auto_approval_cost_cap_usd - - self.auto_approval_estimated_spend_usd, - ), - 4, - ) - - def auto_approval_policy_summary(self) -> dict[str, Any]: - return { - "enabled": self.auto_approval_enabled, - "cost_cap_usd": self.auto_approval_cost_cap_usd, - "estimated_spend_usd": round(self.auto_approval_estimated_spend_usd, 4), - "remaining_usd": self.auto_approval_remaining_usd, - } - - def effective_effort_for(self, model_name: str) -> str | None: - """Resolve the effort level to actually send for ``model_name``. - - Returns the probed result when we have one (may be ``None`` meaning - "model doesn't do thinking, strip it"), else the raw preference. - Unknown-model case falls back to the preference so a stale cache - from a prior ``/model`` can't poison research sub-calls that use a - different model id. - """ - if model_name in self.model_effective_effort: - return self.model_effective_effort[model_name] - return self.config.reasoning_effort + self.context_manager.max_context = _get_max_tokens_safe(model_name) def increment_turn(self) -> None: """Increment turn counter (called after each user interaction)""" @@ -389,31 +167,13 @@ class Session: def get_trajectory(self) -> dict: """Serialize complete session trajectory for logging""" - tools: list = [] - if self.tool_router is not None: - try: - tools = self.tool_router.get_tool_specs_for_llm() or [] - except Exception: - tools = [] - # Sum per-call cost from llm_call events so analyzers don't have to - # walk the events array themselves. Each `llm_call` event already - # carries cost_usd from `agent.core.telemetry.record_llm_call`. 
- total_cost_usd = sum( - float((e.get("data") or {}).get("cost_usd") or 0.0) - for e in self.logged_events - if e.get("event_type") == "llm_call" - ) return { "session_id": self.session_id, - "user_id": self.user_id, - "hf_username": self.hf_username, "session_start_time": self.session_start_time, "session_end_time": datetime.now().isoformat(), "model_name": self.config.model_name, - "total_cost_usd": total_cost_usd, "messages": [msg.model_dump() for msg in self.context_manager.items], "events": self.logged_events, - "tools": tools, } def save_trajectory_local( @@ -439,43 +199,16 @@ class Session: trajectory = self.get_trajectory() - # Scrub secrets at save time so session_logs/ never holds raw - # tokens on disk β€” a log aggregator, crash dump, or filesystem - # snapshot between heartbeats would otherwise leak them. - try: - from agent.core.redact import scrub - - for key in ("messages", "events", "tools"): - if key in trajectory: - trajectory[key] = scrub(trajectory[key]) - except Exception as _e: - logger.debug("Redact-on-save failed (non-fatal): %s", _e) - # Add upload metadata trajectory["upload_status"] = upload_status trajectory["upload_url"] = dataset_url trajectory["last_save_time"] = datetime.now().isoformat() - # Reuse one stable path per session so heartbeat saves overwrite - # the same file instead of creating a new timestamped file every - # minute. The timestamp in the filename is kept for first-save - # ordering; subsequent saves just rewrite that file. - if self._local_save_path and Path(self._local_save_path).parent == log_dir: - filepath = Path(self._local_save_path) - else: - filename = ( - f"session_{self.session_id}_" - f"{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" - ) - filepath = log_dir / filename - self._local_save_path = str(filepath) - - # Atomic-ish write: stage to .tmp then rename so a crash mid-write - # doesn't leave a truncated JSON that breaks the retry scanner. - tmp_path = filepath.with_suffix(filepath.suffix + ".tmp") - with open(tmp_path, "w") as f: + filename = f"session_{self.session_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" + filepath = log_dir / filename + + with open(filepath, "w") as f: json.dump(trajectory, f, indent=2) - tmp_path.replace(filepath) return str(filepath) except Exception as e: @@ -502,174 +235,62 @@ class Session: logger.error(f"Failed to update local save status: {e}") return False - def _personal_trace_repo_id(self) -> Optional[str]: - """Resolve the per-user trace repo id from config + HF username. + def save_and_upload_detached(self, repo_id: str) -> Optional[str]: + """ + Save session locally and spawn detached subprocess for upload (fire-and-forget) + + Args: + repo_id: HuggingFace dataset repo ID - Returns ``None`` when sharing is disabled, the user is anonymous, - or the template is missing β€” caller skips the personal upload in - those cases. 
+ Returns: + Path to local save file """ - if not getattr(self.config, "share_traces", False): - return None - hf_user = self.hf_username or self.user_id - if not hf_user: - return None - template = getattr(self.config, "personal_trace_repo_template", None) - if not template: - return None - try: - return template.format(hf_user=hf_user) - except (KeyError, IndexError): - logger.debug("personal_trace_repo_template format failed: %r", template) + # Save locally first (fast, synchronous) + local_path = self.save_trajectory_local(upload_status="pending") + if not local_path: return None - def _spawn_uploader( - self, - action: str, - target: str, - repo_id: str, - *, - format: str, - token_env: Optional[str], - private: bool, - token_value: Optional[str] = None, - ) -> None: - """Fire-and-forget spawn of ``session_uploader.py`` with the given args.""" + # Spawn detached subprocess for upload (fire-and-forget) try: uploader_script = Path(__file__).parent / "session_uploader.py" - cmd = [ - sys.executable, - str(uploader_script), - action, - target, - repo_id, - "--format", - format, - "--private", - "true" if private else "false", - ] - if token_env: - cmd.extend(["--token-env", token_env]) - - env = os.environ.copy() - if token_value: - env["_ML_INTERN_PERSONAL_TOKEN"] = token_value + # Use Popen with detached process subprocess.Popen( - cmd, + [sys.executable, str(uploader_script), "upload", local_path, repo_id], stdin=subprocess.DEVNULL, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, - env=env, start_new_session=True, # Detach from parent ) except Exception as e: logger.warning(f"Failed to spawn upload subprocess: {e}") - def save_and_upload_detached(self, repo_id: str) -> Optional[str]: - """ - Save session locally and spawn detached subprocess(es) for upload - (fire-and-forget). - - Always uploads to the shared org dataset (``repo_id``) in the - single-row format used by the KPI scheduler. When - ``config.share_traces`` is enabled and a username is known, also - uploads to the user's personal private dataset in Claude Code JSONL - format so the HF Agent Trace Viewer auto-renders it. - - Args: - repo_id: HuggingFace dataset repo ID for the org/KPI upload. - - Returns: - Path to local save file - """ - local_path = self.save_trajectory_local(upload_status="pending") - if not local_path: - return None - - self._spawn_uploader( - "upload", - local_path, - repo_id, - format="row", - token_env=None, # default org token chain - private=False, - ) - - personal_repo = self._personal_trace_repo_id() - if personal_repo: - # User's own HF_TOKEN write-scoped to their namespace. - self._spawn_uploader( - "upload", - local_path, - personal_repo, - format="claude_code", - token_env="HF_TOKEN", - token_value=self.hf_token, - private=True, - ) - return local_path @staticmethod def retry_failed_uploads_detached( - directory: str = "session_logs", - repo_id: Optional[str] = None, - *, - personal_repo_id: Optional[str] = None, + directory: str = "session_logs", repo_id: Optional[str] = None ) -> None: """ - Spawn detached subprocess(es) to retry failed/pending uploads - (fire-and-forget). + Spawn detached subprocess to retry failed/pending uploads (fire-and-forget) Args: directory: Directory containing session logs - repo_id: Target dataset repo ID for the shared org/KPI upload. - personal_repo_id: Per-user dataset for Claude-Code-format - retries. ``None`` skips the personal retry pass. 
+ repo_id: Target dataset repo ID """ - if not repo_id and not personal_repo_id: + if not repo_id: return try: uploader_script = Path(__file__).parent / "session_uploader.py" - if repo_id: - subprocess.Popen( - [ - sys.executable, - str(uploader_script), - "retry", - directory, - repo_id, - "--format", - "row", - ], - stdin=subprocess.DEVNULL, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - start_new_session=True, - ) - - if personal_repo_id: - subprocess.Popen( - [ - sys.executable, - str(uploader_script), - "retry", - directory, - personal_repo_id, - "--format", - "claude_code", - "--token-env", - "HF_TOKEN", - "--private", - "true", - ], - stdin=subprocess.DEVNULL, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - start_new_session=True, - ) + # Spawn detached subprocess for retry + subprocess.Popen( + [sys.executable, str(uploader_script), "retry", directory, repo_id], + stdin=subprocess.DEVNULL, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + start_new_session=True, # Detach from parent + ) except Exception as e: logger.warning(f"Failed to spawn retry subprocess: {e}") diff --git a/agent/core/session_persistence.py b/agent/core/session_persistence.py deleted file mode 100644 index e12467211b16fe12ec75fbf5b60edb5ee54f4072..0000000000000000000000000000000000000000 --- a/agent/core/session_persistence.py +++ /dev/null @@ -1,509 +0,0 @@ -"""Optional durable session persistence for the hosted backend. - -The public CLI must keep working without MongoDB. This module therefore -exposes one small async store interface and returns a no-op implementation -unless ``MONGODB_URI`` is configured and reachable. -""" - -from __future__ import annotations - -import logging -import os -from datetime import UTC, datetime -from typing import Any - -from bson import BSON -from pymongo import AsyncMongoClient, DeleteMany, ReturnDocument, UpdateOne -from pymongo.errors import DuplicateKeyError, InvalidDocument, PyMongoError - -logger = logging.getLogger(__name__) - -SCHEMA_VERSION = 1 -MAX_BSON_BYTES = 15 * 1024 * 1024 - - -def _now() -> datetime: - return datetime.now(UTC) - - -def _doc_id(session_id: str, idx: int) -> str: - return f"{session_id}:{idx}" - - -def _safe_message_doc(message: dict[str, Any]) -> dict[str, Any]: - """Return a Mongo-safe message document payload. - - Mongo's hard document limit is 16 MB. We stay below that and store an - explicit marker rather than failing the whole snapshot for one huge tool log. 
- """ - try: - if len(BSON.encode({"message": message})) <= MAX_BSON_BYTES: - return message - except (InvalidDocument, OverflowError): - pass - return { - "role": "tool", - "content": ( - "[SYSTEM: A single persisted message exceeded MongoDB's document " - "size/encoding limit and was replaced by this marker.]" - ), - "ml_intern_persistence_error": "message_too_large_or_invalid", - } - - -class NoopSessionStore: - """Async no-op store used when Mongo is not configured.""" - - enabled = False - - async def init(self) -> None: - return None - - async def close(self) -> None: - return None - - async def upsert_session(self, **_: Any) -> None: - return None - - async def save_snapshot(self, **_: Any) -> None: - return None - - async def load_session(self, *_: Any, **__: Any) -> dict[str, Any] | None: - return None - - async def list_sessions(self, *_: Any, **__: Any) -> list[dict[str, Any]]: - return [] - - async def soft_delete_session(self, *_: Any, **__: Any) -> None: - return None - - async def update_session_fields(self, *_: Any, **__: Any) -> None: - return None - - async def append_event(self, *_: Any, **__: Any) -> int | None: - return None - - async def load_events_after(self, *_: Any, **__: Any) -> list[dict[str, Any]]: - return [] - - async def append_trace_message(self, *_: Any, **__: Any) -> int | None: - return None - - async def get_quota(self, *_: Any, **__: Any) -> int | None: - return None - - async def try_increment_quota(self, *_: Any, **__: Any) -> int | None: - return None - - async def refund_quota(self, *_: Any, **__: Any) -> None: - return None - - async def mark_pro_seen(self, *_: Any, **__: Any) -> dict[str, Any] | None: - return None - - -class MongoSessionStore(NoopSessionStore): - """MongoDB-backed session store.""" - - enabled = True - - def __init__(self, uri: str, db_name: str) -> None: - self.uri = uri - self.db_name = db_name - self.enabled = False - self.client: AsyncMongoClient | None = None - self.db = None - - async def init(self) -> None: - try: - self.client = AsyncMongoClient(self.uri, serverSelectionTimeoutMS=3000) - self.db = self.client[self.db_name] - await self.client.admin.command("ping") - await self._create_indexes() - self.enabled = True - logger.info("Mongo session persistence enabled (db=%s)", self.db_name) - except Exception as e: - logger.warning("Mongo session persistence disabled: %s", e) - self.enabled = False - if self.client is not None: - await self.client.close() - self.client = None - self.db = None - - async def close(self) -> None: - if self.client is not None: - await self.client.close() - self.client = None - self.db = None - - async def _create_indexes(self) -> None: - if self.db is None: - return - await self.db.sessions.create_index( - [("user_id", 1), ("visibility", 1), ("updated_at", -1)] - ) - await self.db.sessions.create_index( - [("visibility", 1), ("status", 1), ("last_active_at", -1)] - ) - await self.db.session_messages.create_index( - [("session_id", 1), ("idx", 1)], unique=True - ) - await self.db.session_events.create_index( - [("session_id", 1), ("seq", 1)], unique=True - ) - await self.db.session_trace_messages.create_index( - [("session_id", 1), ("seq", 1)], unique=True - ) - await self.db.session_trace_messages.create_index([("created_at", -1)]) - await self.db.pro_users.create_index([("first_seen_pro_at", -1)]) - - def _ready(self) -> bool: - return bool(self.enabled and self.db is not None) - - async def upsert_session( - self, - *, - session_id: str, - user_id: str, - model: str, - title: str | None = 
None, - surface: str = "frontend", - created_at: datetime | None = None, - runtime_state: str = "idle", - status: str = "active", - message_count: int = 0, - turn_count: int = 0, - pending_approval: list[dict[str, Any]] | None = None, - claude_counted: bool = False, - notification_destinations: list[str] | None = None, - auto_approval_enabled: bool = False, - auto_approval_cost_cap_usd: float | None = None, - auto_approval_estimated_spend_usd: float = 0.0, - ) -> None: - if not self._ready(): - return - now = _now() - await self.db.sessions.update_one( - {"_id": session_id}, - { - "$setOnInsert": { - "_id": session_id, - "session_id": session_id, - "user_id": user_id, - "surface": surface, - "created_at": created_at or now, - "schema_version": SCHEMA_VERSION, - "visibility": "live", - }, - "$set": { - "title": title, - "model": model, - "status": status, - "runtime_state": runtime_state, - "updated_at": now, - "last_active_at": now, - "message_count": message_count, - "turn_count": turn_count, - "pending_approval": pending_approval or [], - "claude_counted": claude_counted, - "notification_destinations": notification_destinations or [], - "auto_approval_enabled": auto_approval_enabled, - "auto_approval_cost_cap_usd": auto_approval_cost_cap_usd, - "auto_approval_estimated_spend_usd": auto_approval_estimated_spend_usd, - }, - }, - upsert=True, - ) - - async def save_snapshot( - self, - *, - session_id: str, - user_id: str, - model: str, - messages: list[dict[str, Any]], - title: str | None = None, - runtime_state: str = "idle", - status: str = "active", - turn_count: int = 0, - pending_approval: list[dict[str, Any]] | None = None, - claude_counted: bool = False, - created_at: datetime | None = None, - notification_destinations: list[str] | None = None, - auto_approval_enabled: bool = False, - auto_approval_cost_cap_usd: float | None = None, - auto_approval_estimated_spend_usd: float = 0.0, - ) -> None: - if not self._ready(): - return - now = _now() - await self.upsert_session( - session_id=session_id, - user_id=user_id, - model=model, - title=title, - created_at=created_at, - runtime_state=runtime_state, - status=status, - message_count=len(messages), - turn_count=turn_count, - pending_approval=pending_approval, - claude_counted=claude_counted, - notification_destinations=notification_destinations, - auto_approval_enabled=auto_approval_enabled, - auto_approval_cost_cap_usd=auto_approval_cost_cap_usd, - auto_approval_estimated_spend_usd=auto_approval_estimated_spend_usd, - ) - ops: list[Any] = [] - for idx, raw in enumerate(messages): - ops.append( - UpdateOne( - {"_id": _doc_id(session_id, idx)}, - { - "$set": { - "session_id": session_id, - "idx": idx, - "message": _safe_message_doc(raw), - "updated_at": now, - }, - "$setOnInsert": {"created_at": now}, - }, - upsert=True, - ) - ) - ops.append( - DeleteMany({"session_id": session_id, "idx": {"$gte": len(messages)}}) - ) - try: - if ops: - await self.db.session_messages.bulk_write(ops, ordered=False) - except PyMongoError as e: - logger.warning("Failed to persist session %s snapshot: %s", session_id, e) - - async def load_session( - self, session_id: str, *, include_deleted: bool = False - ) -> dict[str, Any] | None: - if not self._ready(): - return None - meta = await self.db.sessions.find_one({"_id": session_id}) - if not meta: - return None - if meta.get("visibility") == "deleted" and not include_deleted: - return None - cursor = self.db.session_messages.find({"session_id": session_id}).sort( - "idx", 1 - ) - messages = 
[row.get("message") async for row in cursor] - return {"metadata": meta, "messages": messages} - - async def list_sessions( - self, user_id: str, *, include_deleted: bool = False - ) -> list[dict[str, Any]]: - if not self._ready(): - return [] - query: dict[str, Any] = {"user_id": user_id} - if user_id == "dev": - query = {} - if not include_deleted: - query["visibility"] = {"$ne": "deleted"} - cursor = self.db.sessions.find(query).sort("updated_at", -1) - return [row async for row in cursor] - - async def soft_delete_session(self, session_id: str) -> None: - if not self._ready(): - return - await self.db.sessions.update_one( - {"_id": session_id}, - { - "$set": { - "visibility": "deleted", - "runtime_state": "idle", - "updated_at": _now(), - } - }, - ) - - async def update_session_fields(self, session_id: str, **fields: Any) -> None: - if not self._ready() or not fields: - return - fields["updated_at"] = _now() - await self.db.sessions.update_one({"_id": session_id}, {"$set": fields}) - - async def _next_seq(self, counter_id: str) -> int: - doc = await self.db.counters.find_one_and_update( - {"_id": counter_id}, - {"$inc": {"seq": 1}}, - upsert=True, - return_document=ReturnDocument.AFTER, - ) - return int(doc["seq"]) - - async def append_event( - self, session_id: str, event_type: str, data: dict[str, Any] | None - ) -> int | None: - if not self._ready(): - return None - try: - seq = await self._next_seq(f"event:{session_id}") - await self.db.session_events.insert_one( - { - "_id": _doc_id(session_id, seq), - "session_id": session_id, - "seq": seq, - "event_type": event_type, - "data": data or {}, - "created_at": _now(), - } - ) - return seq - except PyMongoError as e: - logger.debug("Failed to append event for %s: %s", session_id, e) - return None - - async def load_events_after( - self, session_id: str, after_seq: int = 0 - ) -> list[dict[str, Any]]: - if not self._ready(): - return [] - cursor = self.db.session_events.find( - {"session_id": session_id, "seq": {"$gt": int(after_seq or 0)}} - ).sort("seq", 1) - return [row async for row in cursor] - - async def append_trace_message( - self, session_id: str, message: dict[str, Any], source: str = "message" - ) -> int | None: - if not self._ready(): - return None - try: - seq = await self._next_seq(f"trace:{session_id}") - await self.db.session_trace_messages.insert_one( - { - "_id": _doc_id(session_id, seq), - "session_id": session_id, - "seq": seq, - "role": message.get("role"), - "message": _safe_message_doc(message), - "source": source, - "created_at": _now(), - } - ) - return seq - except PyMongoError as e: - logger.debug("Failed to append trace message for %s: %s", session_id, e) - return None - - async def get_quota(self, user_id: str, day: str) -> int | None: - if not self._ready(): - return None - doc = await self.db.claude_quotas.find_one({"_id": f"{user_id}:{day}"}) - return int(doc.get("count", 0)) if doc else 0 - - async def try_increment_quota(self, user_id: str, day: str, cap: int) -> int | None: - if not self._ready(): - return None - key = f"{user_id}:{day}" - now = _now() - try: - await self.db.claude_quotas.insert_one( - { - "_id": key, - "user_id": user_id, - "day": day, - "count": 1, - "updated_at": now, - } - ) - return 1 - except DuplicateKeyError: - pass - doc = await self.db.claude_quotas.find_one_and_update( - {"_id": key, "count": {"$lt": cap}}, - {"$inc": {"count": 1}, "$set": {"updated_at": now}}, - return_document=ReturnDocument.AFTER, - ) - return int(doc["count"]) if doc else None - - async def 
refund_quota(self, user_id: str, day: str) -> None:
-        if not self._ready():
-            return
-        await self.db.claude_quotas.update_one(
-            {"_id": f"{user_id}:{day}", "count": {"$gt": 0}},
-            {"$inc": {"count": -1}, "$set": {"updated_at": _now()}},
-        )
-
-    async def mark_pro_seen(
-        self, user_id: str, *, is_pro: bool
-    ) -> dict[str, Any] | None:
-        """Track per-user Pro state and detect free→Pro conversions.
-
-        Returns ``{"converted": True, "first_seen_at": "..."}`` exactly once
-        per user — the first time we see them as Pro after having recorded
-        them as non-Pro at least once. Otherwise returns ``None``.
-
-        Storing ``ever_non_pro`` lets us distinguish "user joined as Pro"
-        (no conversion) from "user upgraded" (conversion). The atomic
-        ``find_one_and_update`` on a guarded filter makes the conversion
-        emit at-most-once even under concurrent requests.
-        """
-        if not self._ready() or not user_id:
-            return None
-        now = _now()
-        set_fields: dict[str, Any] = {"last_seen_at": now, "is_pro": bool(is_pro)}
-        if not is_pro:
-            set_fields["ever_non_pro"] = True
-        try:
-            await self.db.pro_users.update_one(
-                {"_id": user_id},
-                {
-                    "$setOnInsert": {"_id": user_id, "first_seen_at": now},
-                    "$set": set_fields,
-                },
-                upsert=True,
-            )
-        except PyMongoError as e:
-            logger.debug("mark_pro_seen upsert failed for %s: %s", user_id, e)
-            return None
-
-        if not is_pro:
-            return None
-
-        try:
-            doc = await self.db.pro_users.find_one_and_update(
-                {
-                    "_id": user_id,
-                    "ever_non_pro": True,
-                    "first_seen_pro_at": {"$exists": False},
-                },
-                {"$set": {"first_seen_pro_at": now}},
-                return_document=ReturnDocument.AFTER,
-            )
-        except PyMongoError as e:
-            logger.debug("mark_pro_seen conversion check failed for %s: %s", user_id, e)
-            return None
-
-        if not doc:
-            return None
-        return {
-            "converted": True,
-            "first_seen_at": (doc.get("first_seen_at") or now).isoformat(),
-        }
-
-
-_store: NoopSessionStore | MongoSessionStore | None = None
-
-
-def get_session_store() -> NoopSessionStore | MongoSessionStore:
-    global _store
-    if _store is None:
-        uri = os.environ.get("MONGODB_URI")
-        db_name = os.environ.get("MONGODB_DB", "ml-intern")
-        _store = MongoSessionStore(uri, db_name) if uri else NoopSessionStore()
-    return _store
-
-
-def _reset_store_for_tests(
-    store: NoopSessionStore | MongoSessionStore | None = None,
-) -> None:
-    global _store
-    _store = store
diff --git a/agent/core/session_uploader.py b/agent/core/session_uploader.py
index 404fd224563cdae3d91c2b93e05e8306ee91fb7e..ef2f9496d87f832489010f9a9529c538d939bedb 100644
--- a/agent/core/session_uploader.py
+++ b/agent/core/session_uploader.py
@@ -3,454 +3,32 @@ Standalone script for uploading session trajectories to HuggingFace.
 This runs as a separate process to avoid blocking the main agent.
 Uses individual file uploads to avoid race conditions.
-
-Two formats are supported:
-
-* ``row`` — single-line JSONL row used by the existing org telemetry/KPI
-  pipeline (``smolagents/ml-intern-sessions``). Compatible with
-  ``backend/kpis_scheduler.py``.
-* ``claude_code`` — one event per line in the Claude Code JSONL schema,
-  auto-detected by the HF Agent Trace Viewer
-  (https://huggingface.co/changelog/agent-trace-viewer). Used for the
-  per-user private dataset (default ``{hf_user}/ml-intern-sessions``).
 """

-import argparse
-import hashlib
 import json
 import os
 import sys
 from datetime import datetime
 from pathlib import Path
-from typing import Any

 from dotenv import load_dotenv

 load_dotenv()

-# Token resolution for the org KPI dataset.
Fallback chain (least-privilege -# first) β€” matches backend/kpis_scheduler.py so one write-scoped token on the -# Space covers every telemetry dataset. Never hardcode tokens in source. -_ORG_TOKEN_FALLBACK_CHAIN = ( - "HF_SESSION_UPLOAD_TOKEN", - "HF_TOKEN", - "HF_ADMIN_TOKEN", -) -_PERSONAL_TOKEN_ENV = "_ML_INTERN_PERSONAL_TOKEN" - - -def _resolve_token(token_env: str | None) -> str: - """Resolve an HF token from env. ``token_env`` overrides the fallback chain.""" - if token_env == "HF_TOKEN": - try: - from agent.core.hf_tokens import resolve_hf_token - - return ( - resolve_hf_token( - os.environ.get(_PERSONAL_TOKEN_ENV), - os.environ.get("HF_TOKEN"), - ) - or "" - ) - except Exception: - token = os.environ.get(_PERSONAL_TOKEN_ENV) or os.environ.get("HF_TOKEN") - return token or "" - - if token_env: - return os.environ.get(token_env, "") or "" - for var in _ORG_TOKEN_FALLBACK_CHAIN: - val = os.environ.get(var) - if val: - return val - return "" - - -def _scrub(obj: Any) -> Any: - """Best-effort regex scrub for HF tokens / API keys before upload.""" - try: - from agent.core.redact import scrub # type: ignore - except Exception: - # Fallback for environments where the agent package isn't importable - # (shouldn't happen in our subprocess, but be defensive). - import importlib.util - - _spec = importlib.util.spec_from_file_location( - "_redact", - Path(__file__).parent / "redact.py", - ) - _mod = importlib.util.module_from_spec(_spec) - _spec.loader.exec_module(_mod) # type: ignore - scrub = _mod.scrub - return scrub(obj) - - -def _msg_uuid(session_id: str, role: str, idx: int) -> str: - """Deterministic UUID-shaped id for a Claude Code message. - - Uses sha1 of ``session_id::role::idx`` so re-uploads/heartbeats keep the - parent/child chain stable. Same convention as the example dataset - https://huggingface.co/datasets/clem/hf-coding-tools-traces. - """ - digest = hashlib.sha1(f"{session_id}::{role}::{idx}".encode("utf-8")).hexdigest() - # Format like a UUID for visual familiarity (32 hex chars w/ dashes). - return ( - f"{digest[0:8]}-{digest[8:12]}-{digest[12:16]}-{digest[16:20]}-{digest[20:32]}" - ) - - -def _content_to_text(content: Any) -> str: - """Best-effort flatten of a litellm/openai content field to plain text.""" - if content is None: - return "" - if isinstance(content, str): - return content - if isinstance(content, list): - parts: list[str] = [] - for block in content: - if isinstance(block, dict): - text = block.get("text") - if isinstance(text, str): - parts.append(text) - else: - # Unknown content block β€” keep round-trippable representation. - parts.append(json.dumps(block, default=str)) - else: - parts.append(str(block)) - return "\n".join(parts) - return str(content) - - -def _parse_tool_args(raw: Any) -> Any: - """Tool call arguments arrive as a JSON-encoded string from LLMs.""" - if isinstance(raw, dict): - return raw - if isinstance(raw, str): - try: - return json.loads(raw) - except (json.JSONDecodeError, TypeError): - return {"_raw": raw} - return raw - - -def to_claude_code_jsonl(trajectory: dict) -> list[dict]: - """Convert an internal trajectory dict to Claude Code JSONL events. 
-
-    Schema reference (per the HF Agent Trace Viewer auto-detector):
-
-        {"type":"user","message":{"role":"user","content":"..."},
-         "uuid":"...","parentUuid":null,"sessionId":"...","timestamp":"..."}
-        {"type":"assistant",
-         "message":{"role":"assistant","model":"...",
-                    "content":[{"type":"text","text":"..."},
-                               {"type":"tool_use","id":"...","name":"...","input":{...}}]},
-         "uuid":"...","parentUuid":"<parent uuid>","sessionId":"...","timestamp":"..."}
-        {"type":"user","message":{"role":"user",
-                                  "content":[{"type":"tool_result",
-                                              "tool_use_id":"...","content":"..."}]},
-         "uuid":"...","parentUuid":"<parent uuid>","sessionId":"...","timestamp":"..."}
-
-    System messages are skipped (they're not part of the viewer schema and
-    contain large prompts that pollute the trace viewer UI).
-    """
-    session_id = trajectory["session_id"]
-    model_name = trajectory.get("model_name") or ""
-    fallback_timestamp = (
-        trajectory.get("session_start_time") or datetime.now().isoformat()
-    )
-    messages: list[dict] = trajectory.get("messages") or []
-
-    out: list[dict] = []
-    parent_uuid: str | None = None
-
-    for idx, msg in enumerate(messages):
-        if not isinstance(msg, dict):
-            continue
-        role = msg.get("role")
-        if role == "system":
-            continue
-        timestamp = msg.get("timestamp") or fallback_timestamp
-
-        if role == "user":
-            content = _content_to_text(msg.get("content"))
-            event_uuid = _msg_uuid(session_id, "user", idx)
-            out.append(
-                {
-                    "type": "user",
-                    "message": {"role": "user", "content": content},
-                    "uuid": event_uuid,
-                    "parentUuid": parent_uuid,
-                    "sessionId": session_id,
-                    "timestamp": timestamp,
-                }
-            )
-            parent_uuid = event_uuid
-
-        elif role == "assistant":
-            content_text = _content_to_text(msg.get("content"))
-            content_blocks: list[dict] = []
-            if content_text:
-                content_blocks.append({"type": "text", "text": content_text})
-            for tc in msg.get("tool_calls") or []:
-                if not isinstance(tc, dict):
-                    continue
-                fn = tc.get("function") or {}
-                content_blocks.append(
-                    {
-                        "type": "tool_use",
-                        "id": tc.get("id") or "",
-                        "name": fn.get("name") or "",
-                        "input": _parse_tool_args(fn.get("arguments")),
-                    }
-                )
-            if not content_blocks:
-                # Edge case: empty assistant turn (shouldn't normally happen,
-                # but skip rather than emit an empty content array which
-                # confuses the viewer).
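To make the conversion concrete, here is a hand-written sketch of the converter's output for one user/assistant exchange (the session id, uuids, content, and timestamps are invented for illustration; real ids come from `_msg_uuid`). Each event's `parentUuid` is the `uuid` of the event before it:

    # Illustrative output only, not captured from a real session.
    converted = [
        {
            "type": "user",
            "message": {"role": "user", "content": "Fine-tune bert-base on imdb"},
            "uuid": "11111111-1111-1111-1111-111111111111",
            "parentUuid": None,
            "sessionId": "sess-1",
            "timestamp": "2026-01-01T00:00:00",
        },
        {
            "type": "assistant",
            "message": {
                "role": "assistant",
                "model": "anthropic/claude-opus-4-6",
                "content": [{"type": "text", "text": "Submitting a training job."}],
            },
            "uuid": "22222222-2222-2222-2222-222222222222",
            "parentUuid": "11111111-1111-1111-1111-111111111111",
            "sessionId": "sess-1",
            "timestamp": "2026-01-01T00:00:05",
        },
    ]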
- continue - event_uuid = _msg_uuid(session_id, "assistant", idx) - out.append( - { - "type": "assistant", - "message": { - "role": "assistant", - "model": model_name, - "content": content_blocks, - }, - "uuid": event_uuid, - "parentUuid": parent_uuid, - "sessionId": session_id, - "timestamp": timestamp, - } - ) - parent_uuid = event_uuid - - elif role == "tool": - tool_call_id = msg.get("tool_call_id") or "" - content_text = _content_to_text(msg.get("content")) - event_uuid = _msg_uuid(session_id, "tool", idx) - out.append( - { - "type": "user", - "message": { - "role": "user", - "content": [ - { - "type": "tool_result", - "tool_use_id": tool_call_id, - "content": content_text, - } - ], - }, - "uuid": event_uuid, - "parentUuid": parent_uuid, - "sessionId": session_id, - "timestamp": timestamp, - } - ) - parent_uuid = event_uuid - - return out - - -def _scrub_session_for_upload(data: dict) -> dict: - """Best-effort scrub of transcript fields before any upload temp file.""" - scrubbed = dict(data) - scrubbed["messages"] = _scrub(data.get("messages") or []) - scrubbed["events"] = _scrub(data.get("events") or []) - scrubbed["tools"] = _scrub(data.get("tools") or []) - return scrubbed - - -def _write_row_payload(data: dict, tmp_path: str) -> None: - """Single-row JSONL (existing format) β€” used by KPI scheduler.""" - scrubbed = _scrub_session_for_upload(data) - session_row = { - "session_id": data["session_id"], - "user_id": data.get("user_id"), - "session_start_time": data["session_start_time"], - "session_end_time": data["session_end_time"], - "model_name": data["model_name"], - "total_cost_usd": data.get("total_cost_usd"), - "messages": json.dumps(scrubbed["messages"]), - "events": json.dumps(scrubbed["events"]), - "tools": json.dumps(scrubbed["tools"]), - } - - with open(tmp_path, "w") as tmp: - json.dump(session_row, tmp) - - -def _write_claude_code_payload(data: dict, tmp_path: str) -> None: - """Multi-line JSONL in Claude Code schema for the HF trace viewer.""" - # Scrub before conversion so secrets never reach the upload temp file. - scrubbed = _scrub_session_for_upload(data) - events = to_claude_code_jsonl(scrubbed) - with open(tmp_path, "w") as tmp: - for event in events: - tmp.write(json.dumps(event)) - tmp.write("\n") - - -def _status_field(format: str) -> str: - """Per-format upload status field on the local trajectory file.""" - return "personal_upload_status" if format == "claude_code" else "upload_status" - - -def _url_field(format: str) -> str: - return "personal_upload_url" if format == "claude_code" else "upload_url" - - -def _read_session_file(session_file: str) -> dict: - """Read a local session file while respecting uploader file locks.""" - import fcntl - - with open(session_file, "r") as f: - fcntl.flock(f, fcntl.LOCK_SH) - try: - return json.load(f) - finally: - fcntl.flock(f, fcntl.LOCK_UN) - - -def _update_upload_status( - session_file: str, - status_key: str, - url_key: str, - status: str, - dataset_url: str | None = None, -) -> None: - """Atomically update only this uploader's status fields. - - The org and personal uploaders run as separate processes against the same - local session JSON file. Re-read under an exclusive lock so one uploader - cannot clobber fields written by the other. 
- """ - import fcntl - - with open(session_file, "r+") as f: - fcntl.flock(f, fcntl.LOCK_EX) - try: - data = json.load(f) - data[status_key] = status - if dataset_url is not None: - data[url_key] = dataset_url - data["last_save_time"] = datetime.now().isoformat() - f.seek(0) - json.dump(data, f, indent=2) - f.truncate() - f.flush() - os.fsync(f.fileno()) - finally: - fcntl.flock(f, fcntl.LOCK_UN) - - -def dataset_card_readme(repo_id: str) -> str: - """Dataset card for personal ML Intern session trace repos.""" - return """--- -pretty_name: "ML Intern Session Traces" -language: -- en -license: other -task_categories: -- text-generation -tags: -- agent-traces -- coding-agent -- ml-intern -- session-traces -- claude-code -- hf-agent-trace-viewer -configs: -- config_name: default - data_files: - - split: train - path: "sessions/**/*.jsonl" ---- - -# ML Intern session traces - -This dataset contains ML Intern coding agent session traces uploaded from local -ML Intern runs. The traces are stored as JSON Lines files under `sessions/`, -with one file per session. - -## Links - -- ML Intern demo: https://smolagents-ml-intern.hf.space -- ML Intern CLI: https://github.com/huggingface/ml-intern - -## Data description - -Each `*.jsonl` file contains a single ML Intern session converted to a -Claude-Code-style event stream for the Hugging Face Agent Trace Viewer. Entries -can include user messages, assistant messages, tool calls, tool results, model -metadata, and timestamps. - -Session files are written to paths of the form: - -```text -sessions/YYYY-MM-DD/.jsonl -``` - -## Redaction and review - -**WARNING: no comprehensive redaction or human review has been performed for this dataset.** - -ML Intern applies automated best-effort scrubbing for common secret patterns -such as Hugging Face, Anthropic, OpenAI, GitHub, and AWS tokens before upload. -This is not a privacy guarantee. - -These traces may contain sensitive information, including prompts, code, -terminal output, file paths, repository names, private task context, tool -outputs, or other data from the local development environment. Treat every -session as potentially sensitive. - -Do not make this dataset public unless you have manually inspected the uploaded -sessions and are comfortable sharing their full contents. - -## Limitations - -Coding agent transcripts can include private or off-topic content, failed -experiments, credentials accidentally pasted by a user, and outputs copied from -local files or services. Use with appropriate caution, especially before -changing repository visibility. -""" - - -def _upload_dataset_card(api: Any, repo_id: str, token: str, format: str) -> None: - """Create/update a README for personal trace datasets.""" - if format != "claude_code": - return - - api.upload_file( - path_or_fileobj=dataset_card_readme(repo_id).encode("utf-8"), - path_in_repo="README.md", - repo_id=repo_id, - repo_type="dataset", - token=token, - commit_message="Update dataset card", - ) +# Token for session uploads β€” loaded from env var (never hardcode tokens in source) +_SESSION_TOKEN = os.environ.get("HF_SESSION_UPLOAD_TOKEN", "") def upload_session_as_file( - session_file: str, - repo_id: str, - max_retries: int = 3, - format: str = "row", - token_env: str | None = None, - private: bool = False, + session_file: str, repo_id: str, max_retries: int = 3 ) -> bool: - """Upload a single session as an individual JSONL file (no race conditions). 
+ """ + Upload a single session as an individual JSONL file (no race conditions) Args: session_file: Path to local session JSON file repo_id: HuggingFace dataset repo ID max_retries: Number of retry attempts - format: ``row`` (default, KPI-compatible) or ``claude_code`` (HF - Agent Trace Viewer compatible). - token_env: Name of the env var holding the HF token. ``None`` falls - back to the org-token chain (``HF_SESSION_UPLOAD_TOKEN`` β†’ - ``HF_TOKEN`` β†’ ``HF_ADMIN_TOKEN``). - private: When creating the repo for the first time, mark it private. Returns: True if successful, False otherwise @@ -461,60 +39,72 @@ def upload_session_as_file( print("Error: huggingface_hub library not available", file=sys.stderr) return False - status_key = _status_field(format) - url_key = _url_field(format) - try: - data = _read_session_file(session_file) + # Load session data + with open(session_file, "r") as f: + data = json.load(f) - # Skip if already uploaded for this format. - if data.get(status_key) == "success": + # Check if already uploaded + upload_status = data.get("upload_status") + if upload_status == "success": return True - hf_token = _resolve_token(token_env) + # Use dedicated session upload token (write-only access to session dataset) + hf_token = _SESSION_TOKEN if not hf_token: - _update_upload_status(session_file, status_key, url_key, "failed") + # Update status to failed + data["upload_status"] = "failed" + with open(session_file, "w") as f: + json.dump(data, f, indent=2) return False - # Build temp upload payload in the requested format. + # Prepare JSONL content (single line) + # Store messages and events as JSON strings to avoid schema conflicts + session_row = { + "session_id": data["session_id"], + "session_start_time": data["session_start_time"], + "session_end_time": data["session_end_time"], + "model_name": data["model_name"], + "messages": json.dumps(data["messages"]), + "events": json.dumps(data["events"]), + } + + # Create temporary JSONL file import tempfile with tempfile.NamedTemporaryFile( mode="w", suffix=".jsonl", delete=False ) as tmp: + json.dump(session_row, tmp) # Single line JSON tmp_path = tmp.name try: - if format == "claude_code": - _write_claude_code_payload(data, tmp_path) - else: - _write_row_payload(data, tmp_path) - + # Generate unique path in repo: sessions/YYYY-MM-DD/session_id.jsonl session_id = data["session_id"] date_str = datetime.fromisoformat(data["session_start_time"]).strftime( "%Y-%m-%d" ) repo_path = f"sessions/{date_str}/{session_id}.jsonl" + # Upload with retries api = HfApi() for attempt in range(max_retries): try: - # Idempotent create β€” visibility is set on first creation - # only. Existing repos keep whatever the user picked via - # /share-traces. 
+ # Try to create repo if it doesn't exist (idempotent) try: api.create_repo( repo_id=repo_id, repo_type="dataset", - private=private, + private=False, token=hf_token, - exist_ok=True, + exist_ok=True, # Don't fail if already exists ) + except Exception: + # Repo might already exist, continue pass - _upload_dataset_card(api, repo_id, hf_token, format) - + # Upload the session file api.upload_file( path_or_fileobj=tmp_path, path_in_repo=repo_path, @@ -524,13 +114,12 @@ def upload_session_as_file( commit_message=f"Add session {session_id}", ) - _update_upload_status( - session_file, - status_key, - url_key, - "success", - f"https://huggingface.co/datasets/{repo_id}", - ) + # Update local status to success + data["upload_status"] = "success" + data["upload_url"] = f"https://huggingface.co/datasets/{repo_id}" + with open(session_file, "w") as f: + json.dump(data, f, indent=2) + return True except Exception: @@ -540,12 +129,14 @@ def upload_session_as_file( wait_time = 2**attempt time.sleep(wait_time) else: - _update_upload_status( - session_file, status_key, url_key, "failed" - ) + # Final attempt failed + data["upload_status"] = "failed" + with open(session_file, "w") as f: + json.dump(data, f, indent=2) return False finally: + # Clean up temp file try: os.unlink(tmp_path) except Exception: @@ -556,102 +147,56 @@ def upload_session_as_file( return False -def retry_failed_uploads( - directory: str, - repo_id: str, - format: str = "row", - token_env: str | None = None, - private: bool = False, -): - """Retry all failed/pending uploads in a directory for the given format.""" +def retry_failed_uploads(directory: str, repo_id: str): + """Retry all failed/pending uploads in a directory""" log_dir = Path(directory) if not log_dir.exists(): return - status_key = _status_field(format) session_files = list(log_dir.glob("session_*.json")) for filepath in session_files: try: - data = _read_session_file(str(filepath)) - - # Only retry pending or failed uploads. Files predating this - # field don't have it; treat unknown as "not yet attempted" for - # the row format (legacy behavior) and "skip" for claude_code - # so we don't suddenly re-upload pre-existing sessions to a - # newly-introduced personal repo. 
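The per-format gating that the deleted branch below implements can be restated as a tiny predicate (a hypothetical helper, written here only to make the three cases explicit):

    def should_retry(data: dict, status_key: str, format: str) -> bool:
        status = data.get(status_key, "unknown")
        if format == "claude_code" and status_key not in data:
            # Legacy file from before the personal-repo feature: skip it.
            return False
        return status in ("pending", "failed", "unknown")

    assert should_retry({}, "personal_upload_status", "claude_code") is False
    assert should_retry({}, "upload_status", "row") is True  # legacy row file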
-            status = data.get(status_key, "unknown")
-            if format == "claude_code" and status_key not in data:
-                continue
-
-            if status in ("pending", "failed", "unknown"):
-                upload_session_as_file(
-                    str(filepath),
-                    repo_id,
-                    format=format,
-                    token_env=token_env,
-                    private=private,
-                )
+            with open(filepath, "r") as f:
+                data = json.load(f)

-        except Exception:
-            pass

+            upload_status = data.get("upload_status", "unknown")
+            # Only retry pending or failed uploads
+            if upload_status in ["pending", "failed"]:
+                upload_session_as_file(str(filepath), repo_id)

-def _str2bool(v: str) -> bool:
-    return str(v).strip().lower() in {"1", "true", "yes", "on"}

+        except Exception:
+            pass


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(prog="session_uploader.py")
-    sub = parser.add_subparsers(dest="command", required=True)
-
-    p_upload = sub.add_parser("upload")
-    p_upload.add_argument("session_file")
-    p_upload.add_argument("repo_id")
-    p_upload.add_argument(
-        "--format",
-        choices=["row", "claude_code"],
-        default="row",
-    )
-    p_upload.add_argument(
-        "--token-env",
-        default=None,
-        help="Env var name holding the HF token (default: org fallback chain).",
-    )
-    p_upload.add_argument("--private", default="false")
-
-    p_retry = sub.add_parser("retry")
-    p_retry.add_argument("directory")
-    p_retry.add_argument("repo_id")
-    p_retry.add_argument(
-        "--format",
-        choices=["row", "claude_code"],
-        default="row",
-    )
-    p_retry.add_argument("--token-env", default=None)
-    p_retry.add_argument("--private", default="false")
-
-    args = parser.parse_args()
-
-    if args.command == "upload":
-        ok = upload_session_as_file(
-            args.session_file,
-            args.repo_id,
-            format=args.format,
-            token_env=args.token_env,
-            private=_str2bool(args.private),
-        )
-        sys.exit(0 if ok else 1)
-
-    if args.command == "retry":
-        retry_failed_uploads(
-            args.directory,
-            args.repo_id,
-            format=args.format,
-            token_env=args.token_env,
-            private=_str2bool(args.private),
-        )
+    if len(sys.argv) < 3:
+        print("Usage: session_uploader.py <command> <args>")
+        sys.exit(1)
+
+    command = sys.argv[1]
+
+    if command == "upload":
+        # python session_uploader.py upload <session_file> <repo_id>
+        if len(sys.argv) < 4:
+            print("Usage: session_uploader.py upload <session_file> <repo_id>")
+            sys.exit(1)
+        session_file = sys.argv[2]
+        repo_id = sys.argv[3]
+        success = upload_session_as_file(session_file, repo_id)
+        sys.exit(0 if success else 1)
+
+    elif command == "retry":
+        # python session_uploader.py retry <directory> <repo_id>
+        if len(sys.argv) < 4:
+            print("Usage: session_uploader.py retry <directory> <repo_id>")
+            sys.exit(1)
+        directory = sys.argv[2]
+        repo_id = sys.argv[3]
+        retry_failed_uploads(directory, repo_id)
         sys.exit(0)

-    parser.print_help()
-    sys.exit(1)
+    else:
+        print(f"Unknown command: {command}")
+        sys.exit(1)
diff --git a/agent/core/telemetry.py b/agent/core/telemetry.py
deleted file mode 100644
index 38d2bbe761fee99d7c8051d6788fc849df8a8fae..0000000000000000000000000000000000000000
--- a/agent/core/telemetry.py
+++ /dev/null
@@ -1,422 +0,0 @@
-"""All agent observability in one module.
-
-Every telemetry signal the agent emits — LLM-call usage / cost, hf_jobs
-lifecycle, sandbox lifecycle, user feedback, mid-turn heartbeat saves — is
-defined here so business-logic files stay free of instrumentation noise.
-
-Callsites are one-liners::
-
-    await telemetry.record_llm_call(session, model=..., response=r, ...)
-    await telemetry.record_hf_job_submit(session, job, args, image=..., job_type="Python")
-    HeartbeatSaver.maybe_fire(session)
-
-All ``record_*`` functions emit a single ``Event`` via ``session.send_event``
-and never raise — telemetry is best-effort and must not break the agent.
-"""
-
-from __future__ import annotations
-
-import asyncio
-import logging
-import time
-from typing import Any
-
-logger = logging.getLogger(__name__)
-
-
-# ── usage extraction ────────────────────────────────────────────────────────
-
-
-def extract_usage(response_or_chunk: Any) -> dict:
-    """Flat usage dict from a litellm response or final-chunk usage object.
-
-    Normalizes across providers: Anthropic exposes cache tokens as
-    ``cache_read_input_tokens`` / ``cache_creation_input_tokens``; OpenAI uses
-    ``prompt_tokens_details.cached_tokens``. Exposed under the stable keys
-    ``cache_read_tokens`` / ``cache_creation_tokens``.
-    """
-    u = getattr(response_or_chunk, "usage", None)
-    if u is None and isinstance(response_or_chunk, dict):
-        u = response_or_chunk.get("usage")
-    if u is None:
-        return {}
-
-    def _g(name, default=0):
-        if isinstance(u, dict):
-            return u.get(name, default) or default
-        return getattr(u, name, default) or default
-
-    prompt = _g("prompt_tokens")
-    completion = _g("completion_tokens")
-    total = _g("total_tokens") or (prompt + completion)
-
-    cache_read = _g("cache_read_input_tokens")
-    cache_creation = _g("cache_creation_input_tokens")
-
-    if not cache_read:
-        details = _g("prompt_tokens_details", None)
-        if details is not None:
-            if isinstance(details, dict):
-                cache_read = details.get("cached_tokens", 0) or 0
-            else:
-                cache_read = getattr(details, "cached_tokens", 0) or 0
-
-    return {
-        "prompt_tokens": int(prompt),
-        "completion_tokens": int(completion),
-        "total_tokens": int(total),
-        "cache_read_tokens": int(cache_read),
-        "cache_creation_tokens": int(cache_creation),
-    }
-
-
-# ── llm_call ────────────────────────────────────────────────────────────────
-
-
-async def record_llm_call(
-    session: Any,
-    *,
-    model: str,
-    response: Any = None,
-    latency_ms: int,
-    finish_reason: str | None,
-    kind: str = "main",
-) -> dict:
-    """Emit an ``llm_call`` event and return the extracted usage dict so
-    callers can stash it on their result object if they want.
-
-    ``kind`` tags the call site so downstream analytics can break spend
-    down by category. Values currently emitted by the codebase:
-
-    * ``main``         — agent loop turn (user-facing reply or tool follow-up)
-    * ``research``     — research sub-agent inner loop (3 call sites)
-    * ``compaction``   — context-window summary on overflow
-    * ``effort_probe`` — effort cascade walk on rejection / model switch
-    * ``restore``      — session re-seed summary after a Space restart
-
-    Pre-2026-04-29, only ``main`` calls were instrumented; the observed gap on
-    Cost Explorer was ~67%, with the other 5 call sites accounting for
-    the rest. Tagging lets us split the dataset's ``total_cost_usd`` by
-    category and validate against AWS billing.
-
-    The ``/title`` (HF Router, not Bedrock) and ``/health/llm`` (diagnostic
-    endpoint, no session context) call sites are intentionally not
-    instrumented — together they're <1% of spend.
- """ - usage = extract_usage(response) if response is not None else {} - cost_usd = 0.0 - if response is not None: - try: - from litellm import completion_cost - - cost_usd = float(completion_cost(completion_response=response) or 0.0) - except Exception: - cost_usd = 0.0 - from agent.core.session import Event # local import to avoid cycle - - try: - await session.send_event( - Event( - event_type="llm_call", - data={ - "model": model, - "latency_ms": latency_ms, - "finish_reason": finish_reason, - "cost_usd": cost_usd, - "kind": kind, - **usage, - }, - ) - ) - except Exception as e: - logger.debug("record_llm_call failed (non-fatal): %s", e) - return usage - - -# ── hf_jobs ──────────────────────────────────────────────────────────────── - - -def _infer_push_to_hub(script_or_cmd: Any) -> bool: - if not isinstance(script_or_cmd, str): - return False - return ( - "push_to_hub=True" in script_or_cmd - or "push_to_hub=true" in script_or_cmd - or "hub_model_id" in script_or_cmd - ) - - -async def record_hf_job_submit( - session: Any, - job: Any, - args: dict, - *, - image: str, - job_type: str, -) -> float: - """Emit ``hf_job_submit``. Returns the monotonic start timestamp so the - caller can pass it back into :func:`record_hf_job_complete`.""" - from agent.core.session import Event - - t_start = time.monotonic() - try: - script_text = args.get("script") or args.get("command") or "" - await session.send_event( - Event( - event_type="hf_job_submit", - data={ - "job_id": getattr(job, "id", None), - "job_url": getattr(job, "url", None), - "flavor": args.get("hardware_flavor", "cpu-basic"), - "timeout": args.get("timeout", "30m"), - "job_type": job_type, - "image": image, - "namespace": args.get("namespace"), - "push_to_hub": _infer_push_to_hub(script_text), - }, - ) - ) - except Exception as e: - logger.debug("record_hf_job_submit failed (non-fatal): %s", e) - return t_start - - -async def record_hf_job_complete( - session: Any, - job: Any, - *, - flavor: str, - final_status: str, - submit_ts: float, -) -> None: - from agent.core.session import Event - - try: - wall_time_s = int(time.monotonic() - submit_ts) - await session.send_event( - Event( - event_type="hf_job_complete", - data={ - "job_id": getattr(job, "id", None), - "flavor": flavor, - "final_status": final_status, - "wall_time_s": wall_time_s, - }, - ) - ) - except Exception as e: - logger.debug("record_hf_job_complete failed (non-fatal): %s", e) - - -# ── sandbox ───────────────────────────────────────────────────────────────── - - -async def record_sandbox_create( - session: Any, - sandbox: Any, - *, - hardware: str, - create_latency_s: int, -) -> None: - from agent.core.session import Event - - try: - # Pin created-at on the session so record_sandbox_destroy can diff. 
- session._sandbox_created_at = time.monotonic() - create_latency_s - await session.send_event( - Event( - event_type="sandbox_create", - data={ - "sandbox_id": getattr(sandbox, "space_id", None), - "hardware": hardware, - "create_latency_s": int(create_latency_s), - }, - ) - ) - except Exception as e: - logger.debug("record_sandbox_create failed (non-fatal): %s", e) - - -async def record_sandbox_destroy(session: Any, sandbox: Any) -> None: - from agent.core.session import Event - - try: - created = getattr(session, "_sandbox_created_at", None) - lifetime_s = int(time.monotonic() - created) if created else None - await session.send_event( - Event( - event_type="sandbox_destroy", - data={ - "sandbox_id": getattr(sandbox, "space_id", None), - "lifetime_s": lifetime_s, - }, - ) - ) - except Exception as e: - logger.debug("record_sandbox_destroy failed (non-fatal): %s", e) - - -# ── feedback ─────────────────────────────────────────────────────────────── - - -async def record_feedback( - session: Any, - *, - rating: str, - turn_index: int | None = None, - message_id: str | None = None, - comment: str | None = None, -) -> None: - from agent.core.session import Event - - try: - await session.send_event( - Event( - event_type="feedback", - data={ - "rating": rating, - "turn_index": turn_index, - "message_id": message_id, - "comment": (comment or "")[:500], - }, - ) - ) - except Exception as e: - logger.debug("record_feedback failed (non-fatal): %s", e) - - -async def record_jobs_access_blocked( - session: Any, - *, - tool_call_ids: list[str], - plan: str, - eligible_namespaces: list[str], -) -> None: - from agent.core.session import Event - - try: - await session.send_event( - Event( - event_type="jobs_access_blocked", - data={ - "tool_call_ids": tool_call_ids, - "plan": plan, - "eligible_namespaces": eligible_namespaces, - }, - ) - ) - except Exception as e: - logger.debug("record_jobs_access_blocked failed (non-fatal): %s", e) - - -async def record_pro_cta_click( - session: Any, - *, - source: str, - target: str = "pro_pricing", -) -> None: - from agent.core.session import Event - - try: - await session.send_event( - Event( - event_type="pro_cta_click", - data={"source": source, "target": target}, - ) - ) - except Exception as e: - logger.debug("record_pro_cta_click failed (non-fatal): %s", e) - - -async def record_pro_conversion( - session: Any, - *, - first_seen_at: str | None = None, -) -> None: - """Emit a ``pro_conversion`` event for a user we've previously observed - as non-Pro and now see as Pro for the first time. Detected upstream in - ``MongoSessionStore.mark_pro_seen``; fired into the user's first Pro - session so the rollup picks it up alongside other event-driven KPIs.""" - from agent.core.session import Event - - try: - await session.send_event( - Event( - event_type="pro_conversion", - data={"first_seen_at": first_seen_at}, - ) - ) - except Exception as e: - logger.debug("record_pro_conversion failed (non-fatal): %s", e) - - -async def record_credits_topped_up( - session: Any, - *, - namespace: str | None = None, -) -> None: - """Emit a ``credits_topped_up`` event when an hf_job submits successfully - in a session that previously hit ``jobs_access_blocked`` β€” i.e. the user - came back from the HF billing top-up flow and unblocked themselves. 
- Caller is responsible for firing this at most once per session.""" - from agent.core.session import Event - - try: - await session.send_event( - Event( - event_type="credits_topped_up", - data={"namespace": namespace}, - ) - ) - except Exception as e: - logger.debug("record_credits_topped_up failed (non-fatal): %s", e) - - -# ── heartbeat ────────────────────────────────────────────────────────────── - -# Module-level reference set for fire-and-forget heartbeat tasks. asyncio only -# keeps *weak* references to tasks, so the returned Task would otherwise be -# eligible for GC before running β€” the task gets discarded and the upload -# silently never happens. Hold strong refs until the task completes. -_heartbeat_tasks: set[asyncio.Task] = set() - - -class HeartbeatSaver: - """Time-gated mid-turn flush. - - Called from ``Session.send_event`` after every event. Fires - ``save_and_upload_detached`` in a worker thread at most once per - ``heartbeat_interval_s`` (default 60s). Guards against losing trace data - on long-running turns that crash before ``turn_complete``. - """ - - @staticmethod - def maybe_fire(session: Any) -> None: - if not getattr(session.config, "save_sessions", False): - return - interval = getattr(session.config, "heartbeat_interval_s", 0) or 0 - if interval <= 0: - return - now = time.monotonic() - last = getattr(session, "_last_heartbeat_ts", None) - if last is None: - # Initialise on first event; no save yet. - session._last_heartbeat_ts = now - return - if now - last < interval: - return - session._last_heartbeat_ts = now - repo_id = session.config.session_dataset_repo - try: - task = asyncio.get_running_loop().create_task( - asyncio.to_thread(session.save_and_upload_detached, repo_id) - ) - # Hold a strong reference until the task finishes so asyncio can't - # GC it. ``set.discard`` is a no-op on missing keys β†’ safe callback. 
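This is the standard keep-a-strong-reference idiom recommended by the asyncio documentation for fire-and-forget tasks, shown in isolation with generic names:

    import asyncio

    _background: set[asyncio.Task] = set()

    def fire_and_forget(coro) -> asyncio.Task:
        task = asyncio.get_running_loop().create_task(coro)
        _background.add(task)  # strong ref: the loop alone holds only weak refs
        task.add_done_callback(_background.discard)  # drop the ref once done
        return task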
- _heartbeat_tasks.add(task) - task.add_done_callback(_heartbeat_tasks.discard) - except RuntimeError: - try: - session.save_and_upload_detached(repo_id) - except Exception as e: - logger.debug("Heartbeat save failed (non-fatal): %s", e) diff --git a/agent/core/tools.py b/agent/core/tools.py index 1b750671605143958f1193c38ef7c1ee083a3cdc..9bbf91d798514fddbbae1b7c68d9f1826e82d824 100644 --- a/agent/core/tools.py +++ b/agent/core/tools.py @@ -8,6 +8,8 @@ import warnings from dataclasses import dataclass from typing import Any, Awaitable, Callable, Optional +logger = logging.getLogger(__name__) + from fastmcp import Client from fastmcp.exceptions import ToolError from mcp.types import EmbeddedResource, ImageContent, TextContent @@ -44,12 +46,10 @@ from agent.tools.hf_repo_git_tool import ( hf_repo_git_handler, ) from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, hf_jobs_handler -from agent.tools.notify_tool import NOTIFY_TOOL_SPEC, notify_handler from agent.tools.papers_tool import HF_PAPERS_TOOL_SPEC, hf_papers_handler from agent.tools.plan_tool import PLAN_TOOL_SPEC, plan_tool_handler from agent.tools.research_tool import RESEARCH_TOOL_SPEC, research_handler from agent.tools.sandbox_tool import get_sandbox_tools -from agent.tools.web_search_tool import WEB_SEARCH_TOOL_SPEC, web_search_handler # NOTE: Private HF repo tool disabled - replaced by hf_repo_files and hf_repo_git # from agent.tools.private_hf_repo_tools import ( @@ -62,8 +62,6 @@ warnings.filterwarnings( "ignore", category=DeprecationWarning, module="aiohttp.connector" ) -logger = logging.getLogger(__name__) - NOT_ALLOWED_TOOL_NAMES = ["hf_jobs", "hf_doc_search", "hf_doc_fetch", "hf_whoami"] @@ -131,12 +129,7 @@ class ToolRouter: Based on codex-rs/core/src/tools/router.rs """ - def __init__( - self, - mcp_servers: dict[str, MCPServerConfig], - hf_token: str | None = None, - local_mode: bool = False, - ): + def __init__(self, mcp_servers: dict[str, MCPServerConfig], hf_token: str | None = None, local_mode: bool = False): self.tools: dict[str, ToolSpec] = {} self.mcp_servers: dict[str, dict[str, Any]] = {} @@ -149,9 +142,7 @@ class ToolRouter: for name, server in mcp_servers.items(): data = server.model_dump() if hf_token: - data.setdefault("headers", {})["Authorization"] = ( - f"Bearer {hf_token}" - ) + data.setdefault("headers", {})["Authorization"] = f"Bearer {hf_token}" mcp_servers_payload[name] = data self.mcp_client = Client({"mcpServers": mcp_servers_payload}) self._mcp_initialized = False @@ -225,9 +216,7 @@ class ToolRouter: await self.register_mcp_tools() self._mcp_initialized = True except Exception as e: - logger.warning( - "MCP connection failed, continuing without MCP tools: %s", e - ) + logger.warning("MCP connection failed, continuing without MCP tools: %s", e) self.mcp_client = None await self.register_openapi_tool() @@ -321,12 +310,6 @@ def create_builtin_tools(local_mode: bool = False) -> list[ToolSpec]: parameters=HF_PAPERS_TOOL_SPEC["parameters"], handler=hf_papers_handler, ), - ToolSpec( - name=WEB_SEARCH_TOOL_SPEC["name"], - description=WEB_SEARCH_TOOL_SPEC["description"], - parameters=WEB_SEARCH_TOOL_SPEC["parameters"], - handler=web_search_handler, - ), # Dataset inspection tool (unified) ToolSpec( name=HF_INSPECT_DATASET_TOOL_SPEC["name"], @@ -341,12 +324,6 @@ def create_builtin_tools(local_mode: bool = False) -> list[ToolSpec]: parameters=PLAN_TOOL_SPEC["parameters"], handler=plan_tool_handler, ), - ToolSpec( - name=NOTIFY_TOOL_SPEC["name"], - description=NOTIFY_TOOL_SPEC["description"], - 
parameters=NOTIFY_TOOL_SPEC["parameters"], - handler=notify_handler, - ), ToolSpec( name=HF_JOBS_TOOL_SPEC["name"], description=HF_JOBS_TOOL_SPEC["description"], @@ -389,7 +366,6 @@ def create_builtin_tools(local_mode: bool = False) -> list[ToolSpec]: # Sandbox or local tools (highest priority) if local_mode: from agent.tools.local_tools import get_local_tools - tools = get_local_tools() + tools else: tools = get_sandbox_tools() + tools diff --git a/agent/main.py b/agent/main.py index a7262707a6bd0a9946233ed4e5d54a42ce27eeb0..5d64cc8867e9b3a3309836a2cb2b55a9914649ab 100644 --- a/agent/main.py +++ b/agent/main.py @@ -10,7 +10,6 @@ import argparse import asyncio import json import os -import signal import sys import time from dataclasses import dataclass @@ -21,14 +20,9 @@ import litellm from prompt_toolkit import PromptSession from agent.config import load_config -from agent.core.approval_policy import is_scheduled_operation from agent.core.agent_loop import submission_loop -from agent.core import model_switcher -from agent.core.hf_tokens import resolve_hf_token -from agent.core.local_models import is_local_model_id from agent.core.session import OpType from agent.core.tools import ToolRouter -from agent.messaging.gateway import NotificationGateway from agent.utils.reliability_checks import check_training_script_save_pattern from agent.utils.terminal_display import ( get_console, @@ -50,33 +44,15 @@ from agent.utils.terminal_display import ( ) litellm.drop_params = True -# Suppress the "Give Feedback / Get Help" banner LiteLLM prints to stderr -# on every error β€” users don't need it, and our friendly errors cover the case. -litellm.suppress_debug_info = True -CLI_CONFIG_PATH = Path(__file__).parent.parent / "configs" / "cli_agent_config.json" - - -def _is_scheduled_hf_job_tool(tool_info: dict[str, Any]) -> bool: - if tool_info.get("tool") != "hf_jobs": - return False - arguments = tool_info.get("arguments") or {} - if isinstance(arguments, str): - try: - arguments = json.loads(arguments) - except json.JSONDecodeError: - return False - if not isinstance(arguments, dict): - return False - return is_scheduled_operation(arguments.get("operation")) - - -def _configure_runtime_logging() -> None: - """Keep third-party warning spam from punching through the interactive UI.""" - import logging - - logging.getLogger("LiteLLM").setLevel(logging.ERROR) - logging.getLogger("litellm").setLevel(logging.ERROR) +# ── Available models (mirrors backend/routes/agent.py) ────────────────── +AVAILABLE_MODELS = [ + {"id": "anthropic/claude-opus-4-6", "label": "Claude Opus 4.6"}, + {"id": "huggingface/fireworks-ai/MiniMaxAI/MiniMax-M2.5", "label": "MiniMax M2.5"}, + {"id": "huggingface/novita/moonshotai/kimi-k2.5", "label": "Kimi K2.5"}, + {"id": "huggingface/novita/zai-org/glm-5", "label": "GLM 5"}, +] +VALID_MODEL_IDS = {m["id"] for m in AVAILABLE_MODELS} def _safe_get_args(arguments: dict) -> dict: @@ -88,16 +64,26 @@ def _safe_get_args(arguments: dict) -> dict: return args if isinstance(args, dict) else {} -def _get_hf_user(token: str | None) -> str | None: - """Resolve the HF username for a token, if available.""" - if not token: - return None +def _get_hf_token() -> str | None: + """Get HF token from environment, huggingface_hub API, or cached token file.""" + token = os.environ.get("HF_TOKEN") + if token: + return token try: from huggingface_hub import HfApi - - return HfApi(token=token).whoami().get("name") + api = HfApi() + token = api.token + if token: + return token except Exception: - return None + 
pass + # Fallback: read the cached token file directly + token_path = Path.home() / ".cache" / "huggingface" / "token" + if token_path.exists(): + token = token_path.read_text().strip() + if token: + return token + return None async def _prompt_and_save_hf_token(prompt_session: PromptSession) -> str: @@ -137,13 +123,10 @@ async def _prompt_and_save_hf_token(prompt_session: PromptSession) -> str: login(token=token, add_to_git_credential=False) print("Token saved to ~/.cache/huggingface/token") except Exception as e: - print( - f"Warning: could not persist token ({e}), using for this session only." - ) + print(f"Warning: could not persist token ({e}), using for this session only.") return token - @dataclass class Operation: """Operation to be executed by the agent""" @@ -168,9 +151,9 @@ def _create_rich_console(): class _ThinkingShimmer: """Animated shiny/shimmer thinking indicator β€” a bright gradient sweeps across the text.""" - _BASE = (90, 90, 110) # dim base color - _HIGHLIGHT = (255, 200, 80) # bright shimmer highlight (warm gold) - _WIDTH = 5 # shimmer width in characters + _BASE = (90, 90, 110) # dim base color + _HIGHLIGHT = (255, 200, 80) # bright shimmer highlight (warm gold) + _WIDTH = 5 # shimmer width in characters _FPS = 24 def __init__(self, console): @@ -185,8 +168,6 @@ class _ThinkingShimmer: self._task = asyncio.ensure_future(self._animate()) def stop(self): - if not self._running: - return # no-op when never started (e.g. headless mode) self._running = False if self._task: self._task.cancel() @@ -231,10 +212,7 @@ class _ThinkingShimmer: class _StreamBuffer: - """Accumulates streamed tokens, renders markdown block-by-block as complete - blocks appear. A "block" is everything up to a paragraph break (\\n\\n). - Unclosed code fences (odd count of ```) hold back flushing until closed so - a code block is always rendered as one unit.""" + """Accumulates streamed tokens, renders full markdown on finish.""" def __init__(self, console): self._console = console @@ -243,43 +221,10 @@ class _StreamBuffer: def add_chunk(self, text: str): self._buffer += text - def _pop_block(self) -> str | None: - """Extract the next complete block, or return None if nothing complete.""" - if self._buffer.count("```") % 2 == 1: - return None # inside an open code fence β€” wait for close - idx = self._buffer.find("\n\n") - if idx == -1: - return None - block = self._buffer[:idx] - self._buffer = self._buffer[idx + 2 :] - return block - - async def flush_ready( - self, - cancel_event: "asyncio.Event | None" = None, - instant: bool = False, - ): - """Render any complete blocks that have accumulated; leave the tail.""" - while True: - if cancel_event is not None and cancel_event.is_set(): - return - block = self._pop_block() - if block is None: - return - if block.strip(): - await print_markdown(block, cancel_event=cancel_event, instant=instant) - - async def finish( - self, - cancel_event: "asyncio.Event | None" = None, - instant: bool = False, - ): - """Flush complete blocks, then render whatever incomplete tail remains.""" - await self.flush_ready(cancel_event=cancel_event, instant=instant) + def finish(self): + """Render the accumulated text as markdown, then reset.""" if self._buffer.strip(): - await print_markdown( - self._buffer, cancel_event=cancel_event, instant=instant - ) + print_markdown(self._buffer) self._buffer = "" def discard(self): @@ -293,7 +238,6 @@ async def event_listener( ready_event: asyncio.Event, prompt_session: PromptSession, config=None, - session_holder=None, ) -> None: 
"""Background task that listens for events and displays them""" submission_id = [1000] @@ -302,37 +246,25 @@ async def event_listener( shimmer = _ThinkingShimmer(console) stream_buf = _StreamBuffer(console) - def _cancel_event(): - """Return the session's cancellation Event so print_markdown can abort - its typewriter loop mid-stream when Ctrl+C fires.""" - s = session_holder[0] if session_holder else None - return s._cancelled if s is not None else None - while True: try: event = await event_queue.get() if event.event_type == "ready": - tool_count = event.data.get("tool_count", 0) if event.data else 0 - print_init_done(tool_count=tool_count) + print_init_done() ready_event.set() elif event.event_type == "assistant_message": shimmer.stop() content = event.data.get("content", "") if event.data else "" if content: - await print_markdown(content, cancel_event=_cancel_event()) + print_markdown(content) elif event.event_type == "assistant_chunk": content = event.data.get("content", "") if event.data else "" if content: stream_buf.add_chunk(content) - # Flush any complete markdown blocks progressively so the - # user sees paragraphs appear as they're produced, not just - # at the end of the whole response. - shimmer.stop() - await stream_buf.flush_ready(cancel_event=_cancel_event()) elif event.event_type == "assistant_stream_end": shimmer.stop() - await stream_buf.finish(cancel_event=_cancel_event()) + stream_buf.finish() elif event.event_type == "tool_call": shimmer.stop() stream_buf.discard() @@ -356,9 +288,6 @@ async def event_listener( stream_buf.discard() print_turn_complete() print_plan() - session = session_holder[0] if session_holder else None - if session is not None: - await session.send_deferred_turn_complete_notification(event) turn_complete_event.set() elif event.event_type == "interrupted": shimmer.stop() @@ -372,19 +301,13 @@ async def event_listener( tool = event.data.get("tool", "") if event.data else "" log = event.data.get("log", "") if event.data else "" if log: - agent_id = event.data.get("agent_id", "") if event.data else "" - label = event.data.get("label", "") if event.data else "" - print_tool_log(tool, log, agent_id=agent_id, label=label) + print_tool_log(tool, log) elif event.event_type == "tool_state_change": pass # visual noise β€” approval flow handles this elif event.event_type == "error": shimmer.stop() stream_buf.discard() - error = ( - event.data.get("error", "Unknown error") - if event.data - else "Unknown error" - ) + error = event.data.get("error", "Unknown error") if event.data else "Unknown error" print_error(error) turn_complete_event.set() elif event.event_type == "shutdown": @@ -402,13 +325,8 @@ async def event_listener( tools_data = event.data.get("tools", []) if event.data else [] count = event.data.get("count", 0) if event.data else 0 - # If yolo mode is active, auto-approve everything except - # scheduled HF jobs, whose recurring cost stays manual. - if ( - config - and config.yolo_mode - and not any(_is_scheduled_hf_job_tool(t) for t in tools_data) - ): + # If yolo mode is active, auto-approve everything + if config and config.yolo_mode: approvals = [ { "tool_call_id": t.get("tool_call_id", ""), @@ -641,35 +559,10 @@ async def event_listener( if gated is not None: print(f"Gated: {gated}") - # Get user decision for this item. Ctrl+C / EOF here is - # treated as "reject remaining" (matches Codex's modal - # priority and Forgecode's approval-cancel path). 
Without - # this, KeyboardInterrupt kills the event listener and - # the main loop deadlocks waiting for turn_complete. - try: - response = await prompt_session.prompt_async( - f"Approve item {i}? (y=yes, yolo=approve all, n=no, or provide feedback): " - ) - except (KeyboardInterrupt, EOFError): - get_console().print( - "[dim]Approval cancelled β€” rejecting remaining items[/dim]" - ) - approvals.append( - { - "tool_call_id": tool_call_id, - "approved": False, - "feedback": "User cancelled approval", - } - ) - for remaining in tools_data[i:]: - approvals.append( - { - "tool_call_id": remaining.get("tool_call_id", ""), - "approved": False, - "feedback": None, - } - ) - break + # Get user decision for this item + response = await prompt_session.prompt_async( + f"Approve item {i}? (y=yes, yolo=approve all, n=no, or provide feedback): " + ) response = response.strip().lower() @@ -739,7 +632,7 @@ async def get_user_input(prompt_session: PromptSession) -> str: # Slash commands are defined in terminal_display -async def _handle_slash_command( +def _handle_slash_command( cmd: str, config, session_holder: list, @@ -749,9 +642,6 @@ async def _handle_slash_command( """ Handle a slash command. Returns a Submission to enqueue, or None if the command was handled locally (caller should set turn_complete_event). - - Async because ``/model`` fires a probe ping to validate the model+effort - combo before committing the switch. """ parts = cmd.strip().split(None, 1) command = parts[0].lower() @@ -776,22 +666,25 @@ async def _handle_slash_command( ) if command == "/model": - console = get_console() if not arg: - model_switcher.print_model_listing(config, console) + print("Available models:") + session = session_holder[0] if session_holder else None + current = config.model_name if config else "" + for m in AVAILABLE_MODELS: + marker = " <-- current" if m["id"] == current else "" + print(f" {m['id']} ({m['label']}){marker}") return None - if not model_switcher.is_valid_model_id(arg): - model_switcher.print_invalid_id(arg, console) + if arg not in VALID_MODEL_IDS: + print(f"Unknown model: {arg}") + print(f"Valid: {', '.join(VALID_MODEL_IDS)}") return None - normalized = arg.removeprefix("huggingface/") session = session_holder[0] if session_holder else None - await model_switcher.probe_and_switch_model( - normalized, - config, - session, - console, - resolve_hf_token(), - ) + if session: + session.update_model(arg) + print(f"Model switched to {arg}") + else: + config.model_name = arg + print(f"Model set to {arg} (session not started yet)") return None if command == "/yolo": @@ -800,194 +693,34 @@ async def _handle_slash_command( print(f"YOLO mode: {state}") return None - if command == "/effort": - console = get_console() - valid = {"minimal", "low", "medium", "high", "xhigh", "max", "off"} - session = session_holder[0] if session_holder else None - if not arg: - current = config.reasoning_effort or "off" - console.print(f"[bold]Reasoning effort preference:[/bold] {current}") - if session and session.model_effective_effort: - console.print("[dim]Probed per model:[/dim]") - for m, eff in session.model_effective_effort.items(): - console.print(f" [dim]{m}: {eff or 'off'}[/dim]") - console.print( - "[dim]Set with '/effort minimal|low|medium|high|xhigh|max|off'. " - "'max' is Anthropic-only; 'xhigh' is also supported by current " - "OpenAI GPT-5 models. 
The cascade falls back to whatever the " - "model actually accepts.[/dim]" - ) - return None - level = arg.lower() - if level not in valid: - console.print(f"[bold red]Invalid level:[/bold red] {arg}") - console.print(f"[dim]Expected one of: {', '.join(sorted(valid))}[/dim]") - return None - config.reasoning_effort = None if level == "off" else level - # Drop the per-model probe cache β€” the new preference may resolve - # differently. Next ``/model`` (or the retry safety net) reprobes. - if session is not None: - session.model_effective_effort.clear() - console.print(f"[green]Reasoning effort: {level}[/green]") - if session is not None: - console.print( - "[dim]run /model to re-probe, or send a message β€” " - "the agent adjusts automatically if the new level isn't supported.[/dim]" - ) - return None - if command == "/status": session = session_holder[0] if session_holder else None print(f"Model: {config.model_name}") - print(f"Reasoning effort: {config.reasoning_effort or 'off'}") if session: print(f"Turns: {session.turn_count}") print(f"Context items: {len(session.context_manager.items)}") return None - if command == "/share-traces": - session = session_holder[0] if session_holder else None - await _handle_share_traces_command(arg, config, session) - return None - print(f"Unknown command: {command}. Type /help for available commands.") return None -async def _handle_share_traces_command(arg: str, config, session) -> None: - """Show or flip visibility of the user's personal trace dataset. - - Uses the user's own HF_TOKEN (write-scoped to their namespace). Only - operates on the personal trace repo configured via - ``personal_trace_repo_template`` β€” never touches the shared org dataset. - """ - from huggingface_hub import HfApi - from huggingface_hub.utils import HfHubHTTPError - - console = get_console() - if session is None: - console.print("[bold red]No active session.[/bold red]") - return - - repo_id = session._personal_trace_repo_id() if session is not None else None - if not repo_id: - if not getattr(config, "share_traces", False): - console.print( - "[yellow]share_traces is disabled in config. " - "Set it to true to publish per-session traces to your HF dataset." - "[/yellow]" - ) - return - if not session.user_id: - console.print( - "[yellow]No HF username resolved \u2014 cannot pick a personal " - "trace repo. Set HF_TOKEN to a token tied to your account.[/yellow]" - ) - return - console.print( - "[yellow]personal_trace_repo_template is unset \u2014 nothing to do.[/yellow]" - ) - return - - token = session.hf_token or resolve_hf_token() - if not token: - console.print( - "[bold red]No HF_TOKEN available.[/bold red] Cannot read or change " - "dataset visibility." 
- ) - return - - api = HfApi(token=token) - url = f"https://huggingface.co/datasets/{repo_id}" - target = arg.strip().lower() - - if not target: - try: - info = await asyncio.to_thread( - api.repo_info, repo_id=repo_id, repo_type="dataset" - ) - visibility = "private" if getattr(info, "private", False) else "public" - console.print(f"[bold]Trace dataset:[/bold] {url}") - console.print(f"[bold]Visibility:[/bold] {visibility}") - console.print( - "[dim]Use '/share-traces public' to publish, " - "'/share-traces private' to lock it back down.[/dim]" - ) - except HfHubHTTPError as e: - if getattr(e.response, "status_code", None) == 404: - console.print( - f"[dim]Dataset {repo_id} doesn't exist yet \u2014 it'll be " - "created (private) on the next session save.[/dim]" - ) - else: - console.print(f"[bold red]Hub error:[/bold red] {e}") - except Exception as e: - console.print(f"[bold red]Could not fetch dataset info:[/bold red] {e}") - return - - if target not in {"public", "private"}: - console.print( - f"[bold red]Unknown argument:[/bold red] {target}. " - "Expected 'public' or 'private'." - ) - return - - private = target == "private" - try: - # Idempotent β€” create if missing so first-flip works even before any - # session has been saved yet. - await asyncio.to_thread( - api.create_repo, - repo_id=repo_id, - repo_type="dataset", - private=private, - token=token, - exist_ok=True, - ) - await asyncio.to_thread( - api.update_repo_settings, - repo_id=repo_id, - repo_type="dataset", - private=private, - token=token, - ) - except Exception as e: - console.print(f"[bold red]Failed to update visibility:[/bold red] {e}") - return - - label = "PUBLIC" if not private else "private" - console.print(f"[green]Dataset is now {label}.[/green] {url}") - - -async def main(model: str | None = None): +async def main(): """Interactive chat with the agent""" # Clear screen os.system("clear" if os.name != "nt" else "cls") + print_banner() + # Create prompt session for input (needed early for token prompt) prompt_session = PromptSession() - config = load_config(CLI_CONFIG_PATH, include_user_defaults=True) - if model: - config.model_name = model - - # HF token β€” required for Hub-backed models/tools, but not for local LLMs. - hf_token = resolve_hf_token() - if not hf_token and not is_local_model_id(config.model_name): + # HF token β€” required, prompt if missing + hf_token = _get_hf_token() + if not hf_token: hf_token = await _prompt_and_save_hf_token(prompt_session) - # Resolve username for banner - hf_user = _get_hf_user(hf_token) - - print_banner(model=config.model_name, hf_user=hf_user) - - # Pre-warm the HF router catalog in the background so /model switches - # don't block on a network fetch. 
- from agent.core import hf_router_catalog - - asyncio.create_task(asyncio.to_thread(hf_router_catalog.prewarm)) - # Create queues for communication submission_queue = asyncio.Queue() event_queue = asyncio.Queue() @@ -997,8 +730,10 @@ async def main(model: str | None = None): turn_complete_event.set() ready_event = asyncio.Event() - notification_gateway = NotificationGateway(config.messaging) - await notification_gateway.start() + # Start agent loop in background + config_path = Path(__file__).parent.parent / "configs" / "main_agent_config.json" + config = load_config(config_path) + # Create tool router with local mode tool_router = ToolRouter(config.mcpServers, hf_token=hf_token, local_mode=True) @@ -1013,12 +748,8 @@ async def main(model: str | None = None): tool_router=tool_router, session_holder=session_holder, hf_token=hf_token, - user_id=hf_user, local_mode=True, stream=True, - notification_gateway=notification_gateway, - notification_destinations=config.messaging.default_auto_destinations(), - defer_turn_complete_notification=True, ) ) @@ -1031,94 +762,44 @@ async def main(model: str | None = None): ready_event, prompt_session, config, - session_holder=session_holder, ) ) await ready_event.wait() submission_id = [0] - # Mirrors codex-rs/tui/src/bottom_pane/mod.rs:137 - # (`QUIT_SHORTCUT_TIMEOUT = Duration::from_secs(1)`). Two Ctrl+C presses - # within this window quit; a single press cancels the in-flight turn. - CTRL_C_QUIT_WINDOW = 1.0 - # Hint string matches codex-rs/tui/src/bottom_pane/footer.rs:746 - # (`" again to quit"` prefixed with the key binding, rendered dim). - CTRL_C_HINT = "[dim]ctrl + c again to quit[/dim]" - interrupt_state = {"last": 0.0, "exit": False} - - loop = asyncio.get_running_loop() - - def _on_sigint() -> None: - """SIGINT handler β€” fires while the agent is generating (terminal is - in cooked mode between prompts). Mirrors Codex's `on_ctrl_c` in - codex-rs/tui/src/chatwidget.rs: first press cancels active work and - arms the quit hint; second press within the window quits.""" - now = time.monotonic() - session = session_holder[0] - - if now - interrupt_state["last"] < CTRL_C_QUIT_WINDOW: - interrupt_state["exit"] = True - if session: - session.cancel() - # Wake the main loop out of turn_complete_event.wait() - turn_complete_event.set() - return - - interrupt_state["last"] = now - if session and not session.is_cancelled: - session.cancel() - get_console().print(f"\n{CTRL_C_HINT}") - - def _install_sigint() -> bool: - try: - loop.add_signal_handler(signal.SIGINT, _on_sigint) - return True - except (NotImplementedError, RuntimeError): - return False # Windows or non-main thread - - # prompt_toolkit's prompt_async installs its own SIGINT handler and, on - # exit, calls loop.remove_signal_handler(SIGINT) β€” which wipes ours too. - # So we re-arm at the top of every loop iteration, right before the busy - # wait. Without this, Ctrl+C during agent streaming after the first turn - # falls through to the default handler and the terminal just echoes ^C. - sigint_available = _install_sigint() + last_interrupt_time = 0.0 + agent_busy = False # True only while the agent is processing a submission try: while True: - if sigint_available: - _install_sigint() - + # Wait for previous turn to complete, with interrupt support try: await turn_complete_event.wait() except asyncio.CancelledError: break turn_complete_event.clear() + agent_busy = False - if interrupt_state["exit"]: - break - - # Get user input. 
prompt_toolkit puts the terminal in raw mode and - # installs its own SIGINT handling; ^C arrives as \x03 and surfaces - # as KeyboardInterrupt here. On return, prompt_toolkit removes the - # loop's SIGINT handler β€” we re-arm at the top of the next iter. + # Get user input try: user_input = await get_user_input(prompt_session) except EOFError: break except KeyboardInterrupt: now = time.monotonic() - if now - interrupt_state["last"] < CTRL_C_QUIT_WINDOW: + if now - last_interrupt_time < 3.0: break - interrupt_state["last"] = now - get_console().print(CTRL_C_HINT) - turn_complete_event.set() + last_interrupt_time = now + # If agent is actually working, cancel it + session = session_holder[0] + if agent_busy and session: + session.cancel() + else: + get_console().print("[dim]Ctrl+C again to exit[/dim]") + turn_complete_event.set() continue - # A successful read ends the double-press window β€” an unrelated - # Ctrl+C during the next turn should start a fresh arming. - interrupt_state["last"] = 0.0 - # Check for exit commands if user_input.strip().lower() in ["exit", "quit", "/quit", "/exit"]: break @@ -1130,18 +811,15 @@ async def main(model: str | None = None): # Handle slash commands if user_input.strip().startswith("/"): - sub = await _handle_slash_command( - user_input.strip(), - config, - session_holder, - submission_queue, - submission_id, + sub = _handle_slash_command( + user_input.strip(), config, session_holder, submission_queue, submission_id ) if sub is None: # Command handled locally, loop back for input turn_complete_event.set() continue else: + agent_busy = True await submission_queue.put(sub) continue @@ -1153,16 +831,11 @@ async def main(model: str | None = None): op_type=OpType.USER_INPUT, data={"text": user_input} ), ) + agent_busy = True await submission_queue.put(submission) except KeyboardInterrupt: pass - finally: - if sigint_available: - try: - loop.remove_signal_handler(signal.SIGINT) - except (NotImplementedError, RuntimeError): - pass # Shutdown shutdown_submission = Submission( @@ -1178,8 +851,6 @@ async def main(model: str | None = None): agent_task.cancel() # Agent didn't shut down cleanly β€” close MCP explicitly await tool_router.__aexit__(None, None, None) - finally: - await notification_gateway.close() # Now safe to cancel the listener (agent is done emitting events) listener_task.cancel() @@ -1197,29 +868,21 @@ async def headless_main( import logging logging.basicConfig(level=logging.WARNING) - _configure_runtime_logging() - config = load_config(CLI_CONFIG_PATH, include_user_defaults=True) + hf_token = _get_hf_token() + if not hf_token: + print("ERROR: No HF token found. Set HF_TOKEN or run `huggingface-cli login`.", file=sys.stderr) + sys.exit(1) + + print(f"HF token loaded", file=sys.stderr) + + config_path = Path(__file__).parent.parent / "configs" / "main_agent_config.json" + config = load_config(config_path) config.yolo_mode = True # Auto-approve everything in headless mode if model: config.model_name = model - hf_token = resolve_hf_token() - if not hf_token and not is_local_model_id(config.model_name): - print( - "ERROR: No HF token found. 
Set HF_TOKEN or run `huggingface-cli login`.", - file=sys.stderr, - ) - sys.exit(1) - - if hf_token: - print("HF token loaded", file=sys.stderr) - - notification_gateway = NotificationGateway(config.messaging) - await notification_gateway.start() - hf_user = _get_hf_user(hf_token) - if max_iterations is not None: config.max_iterations = max_iterations @@ -1242,12 +905,8 @@ async def headless_main( tool_router=tool_router, session_holder=session_holder, hf_token=hf_token, - user_id=hf_user, local_mode=True, stream=stream, - notification_gateway=notification_gateway, - notification_destinations=config.messaging.default_auto_destinations(), - defer_turn_complete_notification=True, ) ) @@ -1264,17 +923,13 @@ async def headless_main( ) await submission_queue.put(submission) - # Process events until turn completes. Headless mode is for scripts / - # log capture: no shimmer animation, no typewriter, no live-redrawing - # research overlay. Output is plain, append-only text. + # Process events until turn completes console = _create_rich_console() + shimmer = _ThinkingShimmer(console) stream_buf = _StreamBuffer(console) _hl_last_tool = [None] _hl_sub_id = [1] - # Research sub-agent tool calls are buffered per agent_id and dumped as - # a static block once each sub-agent finishes, instead of streaming via - # the live redrawing SubAgentDisplayManager (which is TTY-only). - _hl_research_buffers: dict[str, dict] = {} + shimmer.start() while True: event = await event_queue.get() @@ -1283,14 +938,16 @@ async def headless_main( content = event.data.get("content", "") if event.data else "" if content: stream_buf.add_chunk(content) - await stream_buf.flush_ready(instant=True) elif event.event_type == "assistant_stream_end": - await stream_buf.finish(instant=True) + shimmer.stop() + stream_buf.finish() elif event.event_type == "assistant_message": + shimmer.stop() content = event.data.get("content", "") if event.data else "" if content: - await print_markdown(content, instant=True) + print_markdown(content) elif event.event_type == "tool_call": + shimmer.stop() stream_buf.discard() tool_name = event.data.get("tool", "") if event.data else "" arguments = event.data.get("arguments", {}) if event.data else {} @@ -1304,92 +961,47 @@ async def headless_main( success = event.data.get("success", False) if event.data else False if _hl_last_tool[0] == "plan_tool" and output: print_tool_output(output, success, truncate=False) + shimmer.start() elif event.event_type == "tool_log": tool = event.data.get("tool", "") if event.data else "" log = event.data.get("log", "") if event.data else "" - if not log: - pass - elif tool == "research": - # Headless mode: buffer research sub-agent activity per-agent, - # then dump each as a static block on completion. The live - # SubAgentDisplayManager uses terminal cursor tricks that are - # unfit for non-TTY output, but parallel agents still need - # distinct output so we key buffers by agent_id. 
- agent_id = event.data.get("agent_id", "") if event.data else "" - label = event.data.get("label", "") if event.data else "" - aid = agent_id or "research" - if log == "Starting research sub-agent...": - _hl_research_buffers[aid] = { - "label": label or "research", - "calls": [], - } - elif log == "Research complete.": - buf = _hl_research_buffers.pop(aid, None) - if buf is not None: - f = get_console().file - f.write(f" \033[38;2;255;200;80mβ–Έ {buf['label']}\033[0m\n") - for call in buf["calls"]: - f.write(f" \033[2m{call}\033[0m\n") - f.flush() - elif log.startswith("tokens:") or log.startswith("tools:"): - pass # stats updates β€” only useful for the live display - elif aid in _hl_research_buffers: - _hl_research_buffers[aid]["calls"].append(log) - else: - # Orphan event (Start was missed) β€” fall back to raw print - print_tool_log(tool, log, agent_id=agent_id, label=label) - else: + if log: print_tool_log(tool, log) elif event.event_type == "approval_required": - # Auto-approve in headless mode, except scheduled HF jobs. Those - # are rejected because their recurring cost needs manual approval. + # Auto-approve everything in headless mode (safety net if yolo_mode + # didn't prevent the approval event for some reason) tools_data = event.data.get("tools", []) if event.data else [] approvals = [ { "tool_call_id": t.get("tool_call_id", ""), - "approved": not _is_scheduled_hf_job_tool(t), - "feedback": ( - "Scheduled HF jobs require manual approval." - if _is_scheduled_hf_job_tool(t) - else None - ), + "approved": True, + "feedback": None, } for t in tools_data ] _hl_sub_id[0] += 1 - await submission_queue.put( - Submission( - id=f"hl_approval_{_hl_sub_id[0]}", - operation=Operation( - op_type=OpType.EXEC_APPROVAL, - data={"approvals": approvals}, - ), - ) - ) + await submission_queue.put(Submission( + id=f"hl_approval_{_hl_sub_id[0]}", + operation=Operation( + op_type=OpType.EXEC_APPROVAL, + data={"approvals": approvals}, + ), + )) elif event.event_type == "compacted": old_tokens = event.data.get("old_tokens", 0) if event.data else 0 new_tokens = event.data.get("new_tokens", 0) if event.data else 0 print_compacted(old_tokens, new_tokens) elif event.event_type == "error": + shimmer.stop() stream_buf.discard() - error = ( - event.data.get("error", "Unknown error") - if event.data - else "Unknown error" - ) + error = event.data.get("error", "Unknown error") if event.data else "Unknown error" print_error(error) break elif event.event_type in ("turn_complete", "interrupted"): + shimmer.stop() stream_buf.discard() history_size = event.data.get("history_size", "?") if event.data else "?" 
- print( - f"\n--- Agent {event.event_type} (history_size={history_size}) ---", - file=sys.stderr, - ) - if event.event_type == "turn_complete": - session = session_holder[0] if session_holder else None - if session is not None: - await session.send_deferred_turn_complete_notification(event) + print(f"\n--- Agent {event.event_type} (history_size={history_size}) ---", file=sys.stderr) break # Shutdown @@ -1403,41 +1015,23 @@ async def headless_main( except asyncio.TimeoutError: agent_task.cancel() await tool_router.__aexit__(None, None, None) - finally: - await notification_gateway.close() -def cli(): - """Entry point for the ml-intern CLI command.""" +if __name__ == "__main__": import logging as _logging import warnings - # Suppress aiohttp "Unclosed client session" noise during event loop teardown _logging.getLogger("asyncio").setLevel(_logging.CRITICAL) - _configure_runtime_logging() # Suppress litellm pydantic deprecation warnings warnings.filterwarnings("ignore", category=DeprecationWarning, module="litellm") - # Suppress whoosh invalid escape sequence warnings (third-party, unfixed upstream) - warnings.filterwarnings("ignore", category=SyntaxWarning, module="whoosh") parser = argparse.ArgumentParser(description="Hugging Face Agent CLI") - parser.add_argument( - "prompt", nargs="?", default=None, help="Run headlessly with this prompt" - ) - parser.add_argument( - "--model", "-m", default=None, help="Model to use (default: from config)" - ) - parser.add_argument( - "--max-iterations", - type=int, - default=None, - help="Max LLM requests per turn (default: 50, use -1 for unlimited)", - ) - parser.add_argument( - "--no-stream", - action="store_true", - help="Disable token streaming (use non-streaming LLM calls)", - ) + parser.add_argument("prompt", nargs="?", default=None, help="Run headlessly with this prompt") + parser.add_argument("--model", "-m", default=None, help=f"Model to use (default: from config)") + parser.add_argument("--max-iterations", type=int, default=None, + help="Max LLM requests per turn (default: 50, use -1 for unlimited)") + parser.add_argument("--no-stream", action="store_true", + help="Disable token streaming (use non-streaming LLM calls)") args = parser.parse_args() try: @@ -1445,19 +1039,8 @@ def cli(): max_iter = args.max_iterations if max_iter is not None and max_iter < 0: max_iter = 10_000 # effectively unlimited - asyncio.run( - headless_main( - args.prompt, - model=args.model, - max_iterations=max_iter, - stream=not args.no_stream, - ) - ) + asyncio.run(headless_main(args.prompt, model=args.model, max_iterations=max_iter, stream=not args.no_stream)) else: - asyncio.run(main(model=args.model)) + asyncio.run(main()) except KeyboardInterrupt: print("\n\nGoodbye!") - - -if __name__ == "__main__": - cli() diff --git a/agent/messaging/__init__.py b/agent/messaging/__init__.py deleted file mode 100644 index c399d254e30fcbce555d6f51b810440b1171ec1a..0000000000000000000000000000000000000000 --- a/agent/messaging/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from agent.messaging.gateway import NotificationGateway -from agent.messaging.models import ( - MessagingConfig, - NotificationRequest, - NotificationResult, - SUPPORTED_AUTO_EVENT_TYPES, -) - -__all__ = [ - "MessagingConfig", - "NotificationGateway", - "NotificationRequest", - "NotificationResult", - "SUPPORTED_AUTO_EVENT_TYPES", -] diff --git a/agent/messaging/base.py b/agent/messaging/base.py deleted file mode 100644 index a74f9cf0d1cb2a77328124414b04de9ebbd6b582..0000000000000000000000000000000000000000 --- 
a/agent/messaging/base.py +++ /dev/null @@ -1,31 +0,0 @@ -from abc import ABC, abstractmethod - -import httpx - -from agent.messaging.models import ( - DestinationConfig, - NotificationRequest, - NotificationResult, -) - - -class NotificationError(Exception): - """Delivery failed and should not be retried.""" - - -class RetryableNotificationError(NotificationError): - """Delivery failed transiently and can be retried.""" - - -class NotificationProvider(ABC): - provider_name: str - - @abstractmethod - async def send( - self, - client: httpx.AsyncClient, - destination_name: str, - destination: DestinationConfig, - request: NotificationRequest, - ) -> NotificationResult: - """Deliver a notification to one destination.""" diff --git a/agent/messaging/gateway.py b/agent/messaging/gateway.py deleted file mode 100644 index 1de9438f5c5c8ae2847ef1bf4a398d10e8903048..0000000000000000000000000000000000000000 --- a/agent/messaging/gateway.py +++ /dev/null @@ -1,172 +0,0 @@ -import asyncio -import logging -from collections.abc import Iterable - -import httpx - -from agent.messaging.base import ( - NotificationError, - NotificationProvider, - RetryableNotificationError, -) -from agent.messaging.models import ( - MessagingConfig, - NotificationRequest, - NotificationResult, -) -from agent.messaging.slack import SlackProvider - -logger = logging.getLogger(__name__) - -_RETRY_DELAYS = (1, 2, 4) - - -class NotificationGateway: - def __init__(self, config: MessagingConfig): - self.config = config - self._providers: dict[str, NotificationProvider] = { - "slack": SlackProvider(), - } - self._queue: asyncio.Queue[NotificationRequest] = asyncio.Queue() - self._worker_task: asyncio.Task | None = None - self._client: httpx.AsyncClient | None = None - - @property - def enabled(self) -> bool: - return self.config.enabled - - async def start(self) -> None: - if not self.enabled or self._worker_task is not None: - return - self._client = httpx.AsyncClient(timeout=10.0) - self._worker_task = asyncio.create_task( - self._worker(), name="notification-gateway" - ) - - async def flush(self) -> None: - if not self.enabled: - return - await self._queue.join() - - async def close(self) -> None: - if not self.enabled: - return - await self.flush() - if self._worker_task is not None: - self._worker_task.cancel() - try: - await self._worker_task - except asyncio.CancelledError: - pass - self._worker_task = None - if self._client is not None: - await self._client.aclose() - self._client = None - - async def send(self, request: NotificationRequest) -> NotificationResult: - if not self.enabled: - return NotificationResult( - destination=request.destination, - ok=False, - provider="disabled", - error="Messaging is disabled", - ) - - destination = self.config.get_destination(request.destination) - if destination is None: - return NotificationResult( - destination=request.destination, - ok=False, - provider="unknown", - error=f"Unknown destination '{request.destination}'", - ) - - provider = self._providers.get(destination.provider) - if provider is None: - return NotificationResult( - destination=request.destination, - ok=False, - provider=destination.provider, - error=f"No provider implementation for '{destination.provider}'", - ) - return await self._send_with_retries( - provider, request.destination, destination, request - ) - - async def send_many( - self, requests: Iterable[NotificationRequest] - ) -> list[NotificationResult]: - results: list[NotificationResult] = [] - for request in requests: - results.append(await 
self.send(request)) - return results - - async def enqueue(self, request: NotificationRequest) -> bool: - if not self.enabled or self._worker_task is None: - return False - await self._queue.put(request) - return True - - async def _worker(self) -> None: - while True: - request = await self._queue.get() - try: - result = await self.send(request) - if not result.ok: - logger.warning( - "Notification delivery failed for %s: %s", - request.destination, - result.error, - ) - except Exception: - logger.exception("Unexpected notification worker failure") - finally: - self._queue.task_done() - - async def _send_with_retries( - self, - provider: NotificationProvider, - destination_name: str, - destination, - request: NotificationRequest, - ) -> NotificationResult: - client = self._client or httpx.AsyncClient(timeout=10.0) - owns_client = self._client is None - try: - for attempt in range(len(_RETRY_DELAYS) + 1): - try: - return await provider.send( - client, destination_name, destination, request - ) - except RetryableNotificationError as exc: - if attempt >= len(_RETRY_DELAYS): - return NotificationResult( - destination=destination_name, - ok=False, - provider=provider.provider_name, - error=str(exc), - ) - delay = _RETRY_DELAYS[attempt] - logger.warning( - "Retrying notification to %s in %ss after transient error: %s", - destination_name, - delay, - exc, - ) - await asyncio.sleep(delay) - except NotificationError as exc: - return NotificationResult( - destination=destination_name, - ok=False, - provider=provider.provider_name, - error=str(exc), - ) - return NotificationResult( - destination=destination_name, - ok=False, - provider=provider.provider_name, - error="Notification delivery exhausted retries", - ) - finally: - if owns_client: - await client.aclose() diff --git a/agent/messaging/models.py b/agent/messaging/models.py deleted file mode 100644 index 16148a8179f5de3fa38b36ce76166a48e9f54a83..0000000000000000000000000000000000000000 --- a/agent/messaging/models.py +++ /dev/null @@ -1,117 +0,0 @@ -from typing import Annotated, Literal - -from pydantic import BaseModel, Field, field_validator, model_validator - -_DESTINATION_NAME_CHARS = set("abcdefghijklmnopqrstuvwxyz0123456789._-") -SUPPORTED_AUTO_EVENT_TYPES = {"approval_required", "error", "turn_complete"} - - -class SlackDestinationConfig(BaseModel): - provider: Literal["slack"] = "slack" - token: str - channel: str - allow_agent_tool: bool = False - allow_auto_events: bool = False - username: str | None = None - icon_emoji: str | None = None - - @field_validator("token", "channel") - @classmethod - def _require_non_empty(cls, value: str) -> str: - value = value.strip() - if not value: - raise ValueError("must not be empty") - return value - - -DestinationConfig = Annotated[SlackDestinationConfig, Field(discriminator="provider")] - - -class MessagingConfig(BaseModel): - enabled: bool = False - auto_event_types: list[str] = Field( - default_factory=lambda: ["approval_required", "error", "turn_complete"] - ) - destinations: dict[str, DestinationConfig] = Field(default_factory=dict) - - @field_validator("destinations") - @classmethod - def _validate_destination_names( - cls, destinations: dict[str, DestinationConfig] - ) -> dict[str, DestinationConfig]: - for name in destinations: - if not name or any(char not in _DESTINATION_NAME_CHARS for char in name): - raise ValueError( - "destination names must use lowercase letters, digits, '.', '_' or '-'" - ) - return destinations - - @field_validator("auto_event_types") - @classmethod - def 
_validate_auto_event_types(cls, event_types: list[str]) -> list[str]: - if not event_types: - return [] - normalized: list[str] = [] - seen: set[str] = set() - for event_type in event_types: - if event_type not in SUPPORTED_AUTO_EVENT_TYPES: - raise ValueError(f"unsupported auto event type '{event_type}'") - if event_type not in seen: - normalized.append(event_type) - seen.add(event_type) - return normalized - - @model_validator(mode="after") - def _require_destinations_when_enabled(self) -> "MessagingConfig": - if self.enabled and not self.destinations: - raise ValueError("messaging.enabled requires at least one destination") - return self - - def get_destination(self, name: str) -> DestinationConfig | None: - return self.destinations.get(name) - - def can_agent_tool_send(self, name: str) -> bool: - destination = self.get_destination(name) - return bool(destination and destination.allow_agent_tool) - - def can_auto_send(self, name: str) -> bool: - destination = self.get_destination(name) - return bool(destination and destination.allow_auto_events) - - def default_auto_destinations(self) -> list[str]: - if not self.enabled: - return [] - return [name for name in self.destinations if self.can_auto_send(name)] - - -class NotificationRequest(BaseModel): - destination: str - title: str | None = None - message: str - severity: Literal["info", "success", "warning", "error"] = "info" - metadata: dict[str, str] = Field(default_factory=dict) - event_type: str | None = None - - @field_validator("destination", "message") - @classmethod - def _require_text(cls, value: str) -> str: - value = value.strip() - if not value: - raise ValueError("must not be empty") - return value - - @field_validator("title") - @classmethod - def _normalize_title(cls, value: str | None) -> str | None: - if value is None: - return None - value = value.strip() - return value or None - - -class NotificationResult(BaseModel): - destination: str - ok: bool - provider: str - error: str | None = None - external_id: str | None = None diff --git a/agent/messaging/slack.py b/agent/messaging/slack.py deleted file mode 100644 index 3790e44af790db8579a9a8efb88a2a16283ec71d..0000000000000000000000000000000000000000 --- a/agent/messaging/slack.py +++ /dev/null @@ -1,184 +0,0 @@ -import json -import re - -import httpx - -from agent.messaging.base import ( - NotificationError, - NotificationProvider, - RetryableNotificationError, -) -from agent.messaging.models import ( - NotificationRequest, - NotificationResult, - SlackDestinationConfig, -) - -_SEVERITY_PREFIX = { - "info": "[INFO]", - "success": "[SUCCESS]", - "warning": "[WARNING]", - "error": "[ERROR]", -} - - -def _format_slack_mrkdwn(content: str) -> str: - """Convert common Markdown constructs to Slack's mrkdwn syntax.""" - if not content: - return content - - placeholders: dict[str, str] = {} - placeholder_index = 0 - - def placeholder(value: str) -> str: - nonlocal placeholder_index - key = f"\x00SLACK{placeholder_index}\x00" - placeholder_index += 1 - placeholders[key] = value - return key - - text = content - - # Protect code before any formatting conversion. Slack's mrkdwn ignores - # formatting inside backticks, so these regions should stay byte-for-byte. 
- text = re.sub( - r"(```(?:[^\n]*\n)?[\s\S]*?```)", - lambda match: placeholder(match.group(0)), - text, - ) - text = re.sub(r"(`[^`\n]+`)", lambda match: placeholder(match.group(0)), text) - - def convert_markdown_link(match: re.Match[str]) -> str: - label = match.group(1) - url = match.group(2).strip() - if url.startswith("<") and url.endswith(">"): - url = url[1:-1].strip() - return placeholder(f"<{url}|{label}>") - - text = re.sub( - r"\[([^\]]+)\]\(([^()]*(?:\([^()]*\)[^()]*)*)\)", - convert_markdown_link, - text, - ) - - # Preserve existing Slack entities and manual mrkdwn links before escaping. - text = re.sub( - r"(<(?:[@#!]|(?:https?|mailto|tel):)[^>\n]+>)", - lambda match: placeholder(match.group(1)), - text, - ) - text = re.sub( - r"^(>+\s)", - lambda match: placeholder(match.group(0)), - text, - flags=re.MULTILINE, - ) - - text = text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;") - - def convert_header(match: re.Match[str]) -> str: - header = match.group(1).strip() - header = re.sub(r"\*\*(.+?)\*\*", r"\1", header) - return placeholder(f"*{header}*") - - text = re.sub(r"^#{1,6}\s+(.+)$", convert_header, text, flags=re.MULTILINE) - text = re.sub( - r"\*\*\*(.+?)\*\*\*", - lambda match: placeholder(f"*_{match.group(1)}_*"), - text, - ) - text = re.sub( - r"\*\*(.+?)\*\*", - lambda match: placeholder(f"*{match.group(1)}*"), - text, - ) - text = re.sub( - r"(?<!\*)\*([^*\n]+)\*(?!\*)", - lambda match: placeholder(f"_{match.group(1)}_"), - text, - ) - - for key, value in placeholders.items(): - text = text.replace(key, value) - return text - - -def _format_text(request: NotificationRequest) -> str: - lines: list[str] = [] - prefix = _SEVERITY_PREFIX[request.severity] - if request.title: - lines.append(f"{prefix} {request.title}") - else: - lines.append(prefix) - lines.append(request.message) - for key, value in request.metadata.items(): - lines.append(f"{key}: {value}") - return _format_slack_mrkdwn("\n".join(lines)) - - -class SlackProvider(NotificationProvider): - provider_name = "slack" - - async def send( - self, - client: httpx.AsyncClient, - destination_name: str, - destination: SlackDestinationConfig, - request: NotificationRequest, - ) -> NotificationResult: - payload = { - "channel": destination.channel, - "text": _format_text(request), - "mrkdwn": True, - "unfurl_links": False, - "unfurl_media": False, - } - if destination.username: - payload["username"] = destination.username - if destination.icon_emoji: - payload["icon_emoji"] = destination.icon_emoji - - try: - response = await client.post( - "https://slack.com/api/chat.postMessage", - headers={ - "Authorization": f"Bearer {destination.token}", - "Content-Type": "application/json; charset=utf-8", - }, - content=json.dumps(payload), - ) - except httpx.TimeoutException as exc: - raise RetryableNotificationError("Slack request timed out") from exc - except httpx.TransportError as exc: - raise RetryableNotificationError("Slack transport error") from exc - - if response.status_code == 429 or response.status_code >= 500: - raise RetryableNotificationError(f"Slack HTTP {response.status_code}") - if response.status_code >= 400: - raise NotificationError(f"Slack HTTP {response.status_code}") - - try: - data = response.json() - except ValueError as exc: - raise RetryableNotificationError("Slack returned invalid JSON") from exc - - if not data.get("ok"): - error = str(data.get("error") or "unknown_error") - if error == "ratelimited": - raise RetryableNotificationError(error) - raise NotificationError(error) - - return NotificationResult( - destination=destination_name, - ok=True, - provider=self.provider_name, - external_id=str(data.get("ts") or ""), - error=None, - ) diff --git
a/agent/prompts/system_prompt_v3.yaml b/agent/prompts/system_prompt_v3.yaml index 4543048f1fd6721264b2ca9ff72b96fb9da472ee..194d2cd1364349ea1fcb26f830a487415e34400b 100644 --- a/agent/prompts/system_prompt_v3.yaml +++ b/agent/prompts/system_prompt_v3.yaml @@ -1,5 +1,5 @@ system_prompt: | - You are ML Intern, an ML engineering assistant with {{ num_tools }} tools for training, fine-tuning, data processing, inference, and evaluation on the Hugging Face (HF) ecosystem. + You are Hugging Face Agent, an ML engineering assistant with {{ num_tools }} tools for training, fine-tuning, data processing, inference, and evaluation on the Hugging Face ecosystem. Your goal is to complete what the user requested with zero errors. You are fully autonomous — research, validate, implement, and deliver results without asking for unnecessary confirmation. @@ -7,20 +7,13 @@ system_prompt: | You do not know current APIs for TRL, Transformers, PEFT, Trackio, or other HF libraries. Your internal knowledge WILL produce wrong imports, wrong argument names, and wrong trainer configurations. - Before writing any ML implementation code, start from the literature. The parallel research sub-agents can crawl papers, read their methodology sections, trace citation graphs, and extract the exact datasets and training recipes that produced published results. This is your primary advantage — use it. - - Your default workflow for any ML task: - 1. Find the landmark paper(s) for the task or domain - 2. Crawl their citation graphs to find recent downstream work - 3. Read methodology sections (not abstracts) of the most promising papers — especially recent ones with strong results, lots of citations, and publications in high-impact conferences - 4. Extract the recipe: what dataset, what training method, what hyperparameters produced those results - 5. Validate and use those datasets for training + Before writing any ML implementation code (training, fine-tuning, inference, data processing), use the `research` tool. It spawns a sub-agent that explores docs, reads example code, and returns a concise summary — keeping your context clean. ``` - research({"task": "Literature crawl for [task]. Start from [paper/topic]. Crawl citation graph for recent downstream papers. Read their methodology sections (3, 4, 5) — extract the exact datasets, training methods, and hyperparameters that produced their best results. Attribute every finding to a specific result (e.g. 'Dataset X + method Y → 85.3% on benchmark Z'). Also find working code examples using current TRL/Transformers APIs.", "context": "User wants to [goal]. We need the best training recipe backed by published results."}) + research({"task": "Research current TRL SFTTrainer: find working example scripts, read the implementation, check SFTConfig parameters, and verify trackio setup.", "context": "User wants to SFT fine-tune a model."}) ``` - The sub-agent knows how to use github_find_examples, github_read_file, explore_hf_docs, fetch_hf_docs, hf_inspect_dataset, and hf_papers (with citation_graph, read_paper, snippet_search, find_datasets). Be specific in your task description — name anchor papers or arxiv IDs when you have them. + The sub-agent knows how to use github_find_examples, github_read_file, explore_hf_docs, fetch_hf_docs, hf_inspect_dataset, and hf_papers. Be specific in your task description. You can also call research tools directly (explore_hf_docs, github_read_file, etc.) for quick lookups.
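The two research calls above share one argument shape: a `task` string and a `context` string. For a quick direct lookup the same shape simply narrows; the sketch below is illustrative only (the task wording and the PEFT `LoraConfig` target are assumptions, not text from the prompt file):

```
research({"task": "Check current PEFT LoraConfig: read the config docs via explore_hf_docs + fetch_hf_docs and confirm how target_modules is specified today.", "context": "User wants a LoRA fine-tune; only the current config surface is needed, not a full literature crawl."})
```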
@@ -28,7 +21,7 @@ system_prompt: | # Mistakes you WILL make without research - HALLUCINATED IMPORTS: You will import from modules that were renamed or removed. Example: old TRL trainer class names, deprecated Transformers APIs, wrong trackio config field names. Fix: read a current example script first. + HALLUCINATED IMPORTS: You will import from modules that were renamed or removed. Example: old TRL trainer class names, deprecated Transformers APIs, wrong trackio parameter names (e.g. `run_name` instead of `name`). Fix: read a current example script first. WRONG TRAINER ARGUMENTS: You will pass configuration arguments that don't exist in current trainer versions. Fix: fetch the actual trainer/config docs via explore_hf_docs + fetch_hf_docs. @@ -42,7 +35,7 @@ system_prompt: | SILENT DATASET SUBSTITUTION: When a requested dataset fails to load, you will silently switch to a different one without telling the user. Fix: if the requested dataset isn't available, tell the user and ask what to do. - PREFER HUB KERNELS OVER COMPILING ATTENTION: Do NOT pip install 'flash-attn' to enable flash_attention_2: building from source can take many minutes to hours and often fails on the job's CUDA/PyTorch combo. Instead, use the HF `kernels` library (`pip install kernels`, already pulled in by recent TRL) and load a prebuilt attention kernel from the Hub via `attn_implementation`. Examples: `AutoModelForCausalLM.from_pretrained(..., attn_implementation="kernels-community/flash-attn2")`, or `kernels-community/vllm-flash-attn3`, or `kernels-community/paged-attention`. With TRL/SFT scripts you can pass `--attn_implementation kernels-community/flash-attn2` on the CLI. Search additional kernels at https://huggingface.co/models?other=kernel. Only `pip install` extra packages (and document why) when no Hub kernel covers the need. + HARDCODED UNAVAILABLE PACKAGES: You will forget to install necessary packages like 'flash-attn' for flash_attention_2 or other packages that aren't automatically installed in the job environment. Fix: install necessary packages before running the job. SCOPE-CHANGING FIXES: Avoid at all costs! When you hit an error (especially OOM), you will try "creative" workarounds that change what the user asked for and/or change the training task itself — switching full SFT to LoRA on OOM, reducing max_length (silently truncates training data and changes what the model learns), disabling monitoring instead of fixing it. Do not do this. Fix errors with the minimal change that preserves the user's original request and is grounded in research and examples. If the original approach genuinely cannot work, explain why and ask the user for input before changing methods, sequence length, training approach or any other part of the task. @@ -60,38 +53,6 @@ system_prompt: | DPO: "prompt", "chosen", "rejected" GRPO: "prompt" - - # Trackio - - Trackio is natively integrated with Transformers Trainer and all TRL trainers — the built-in TrackioCallback handles init/log/finish. In TrainingArguments/SFTConfig/DPOConfig/GRPOConfig set: - report_to="trackio" - run_name="<run_name>" # e.g. "sft_qwen3-4b_lr2e-5_bs128" - project="<project>" # keeps related runs grouped so you can compare them - trackio_space_id="<username>/mlintern-<8-char-id>" # creates a public dashboard Space - `project` and `trackio_space_id` can also be set via TRACKIO_PROJECT / TRACKIO_SPACE_ID env vars. - - Alerts are how iterations decide what to change. Use trackio.alert(title, text, level) at every decision point in training.
Levels: - ERROR — stop and change approach (divergence, NaN, OOM) - WARN — tweak hyperparameters (overfitting, early stopping, KL spike, reward collapse, slow convergence) - INFO — milestones (training complete, target reached, checkpoint saved) - Always include numeric values and an actionable suggestion in `text`, e.g. "loss=12.4 at step 200 — lr likely too high, try ×0.1". A future call must be able to parse it and act on it. - - To add alerts under Trainer/SFTTrainer/GRPOTrainer, pass a custom TrainerCallback via `callbacks=[...]` that calls trackio.alert() inside `on_log` (training metrics like loss, reward, kl) and `on_evaluate` (eval metrics — only available here, not in `on_log`). Keep each `if` simple: one metric, one threshold. Conditions stay easy to adjust between runs. A minimal sketch of such a callback appears below. - - Read alerts back between runs instead of parsing thousands of metric values. CLI — always use --json: - trackio get alerts --project <project> --run <run> --json - trackio get alerts --project <project> --since <timestamp> --json # incremental polling - trackio get run --project <project> --run <run> --json - trackio get metric --project <project> --run <run> --metric <metric> --json - trackio list runs --project <project> --json - Python: api = trackio.Api(); api.alerts(<project>, run=<run>, since=<timestamp>); api.runs(<project>) (each run has .name, .config, .alerts()). - - Drive the next config from prior alerts: - diverged → lr × 0.1 - overfitting → weight_decay × 10 or reduce capacity - early stopping → lr × 0.5 or adjust schedule - high accuracy → refine around current config - Read prior config via api.runs(...).config and only mutate keys the alerts justify changing. - # Data audit Before working with any dataset, audit it first. Do not assume you know what the data looks like — inspect it. @@ -107,7 +68,7 @@ system_prompt: | - Dataset format verified: [columns confirmed via hf_inspect_dataset/hub_repo_details] - push_to_hub=True and hub_model_id set - timeout: [value] (based on: [model size] on [hardware]) - - Trackio monitoring included and deploying metrics to a public Space + - Trackio monitoring included and working If you cannot fill in all items, stop and complete the missing steps first. @@ -122,10 +83,8 @@ system_prompt: | # Sandbox-first development - A private cpu-basic sandbox is already available for normal code execution in each session. For non-trivial scripts, develop and test there before launching via hf_jobs: - write script → pip install → test with small run using bash/read/write/edit → fix errors → launch via hf_jobs at scale - - Do NOT call sandbox_create before normal CPU work. Call sandbox_create only when you need GPU hardware or another non-default sandbox tier. + For non-trivial scripts, develop and test in a sandbox before launching via hf_jobs: + sandbox_create → install deps → write script → test with small run → fix errors → launch via hf_jobs at scale Use GPU sandbox (t4-small minimum) when testing code that uses CUDA, bf16, or model loading. CPU sandboxes cannot test GPU code paths. @@ -175,7 +134,7 @@ system_prompt: | HYPERPARAMETER TUNING: Do not tune hyperparameters by hand one-at-a-time. Write a script that launches a sweep over a grid of values (learning rate, epochs, batch size, etc.) and evaluates each run automatically. One well-designed sweep script beats ten manual experiments. - If you run out of ideas: go back to the literature. Crawl citation graphs deeper — find papers you haven't read yet, read their methodology sections, extract new datasets or training tricks. Look for papers that cite your current approach and improve on it. Try combining recipes from different papers. Re-read the task prompt for angles you missed. Re-read the training logs for clues. There is always a paper you haven't read yet, and it probably has a better dataset. + If you run out of ideas: research. Use the research tool to find papers on the task or technique — look for recent methods, ablation results, tricks that worked for similar problems. Re-read the task prompt for angles you missed. Re-read the training logs for clues. Try combining approaches from different papers. Try a fundamentally different strategy from the literature. There is always a paper you haven't read yet. Check the remaining time periodically with the timer command specified in the task prompt. Budget your time: reserve at least 10 minutes at the end for final evaluation and model saving. @@ -190,7 +149,6 @@ system_prompt: | - Always include direct Hub URLs when referencing models, datasets, Spaces, or jobs. - For errors: state what went wrong, why, and what you're doing to fix it. - Do not over-explain or present elaborate option menus for simple tasks. When the user's intent is clear, act on it. Present options only when there's genuine ambiguity.
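The removed Trackio section above describes the alerting callback but the diff drops it without an example. Here is a minimal sketch, assuming the current transformers `TrainerCallback` hook signatures (`on_log`, `on_evaluate`) and the `trackio.alert(title, text, level)` shape quoted above; the thresholds are illustrative placeholders, not tuned values:

```
import math

import trackio
from transformers import TrainerCallback


class AlertCallback(TrainerCallback):
    """One metric, one threshold per `if`, as the prompt advises."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        # Training metrics (loss, reward, kl) surface here on each log step.
        loss = (logs or {}).get("loss")
        if loss is None:
            return
        if math.isnan(loss):
            trackio.alert("NaN loss", f"loss=NaN at step {state.global_step}; diverged, try lr x0.1", "ERROR")
        elif loss > 10.0:
            trackio.alert("High loss", f"loss={loss:.2f} at step {state.global_step}; lr likely too high, try x0.1", "WARN")

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        # Eval metrics are only available here, not in on_log.
        eval_loss = (metrics or {}).get("eval_loss")
        if eval_loss is not None:
            trackio.alert("Eval checkpoint", f"eval_loss={eval_loss:.3f} at step {state.global_step}", "INFO")
```

Passed to a trainer via `callbacks=[AlertCallback()]`, each alert carries a numeric value and an actionable suggestion, which is what lets a later iteration parse it and act on it.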
- - Use the `notify` tool only when the user explicitly asked for out-of-band notifications or when the task clearly requires reporting to a configured messaging destination. Do not use it for routine chat updates. # Tool usage diff --git a/agent/sft/tagger.py b/agent/sft/tagger.py deleted file mode 100644 index 528bc9d0d80b7e63bc63f527e94cabf59b214966..0000000000000000000000000000000000000000 --- a/agent/sft/tagger.py +++ /dev/null @@ -1,353 +0,0 @@ -"""Derive tags for a session trajectory. - -``tag_session(trajectory)`` → ``list[str]``. Pure function. No filtering, no -mutation — tags are purely metadata so downstream pipelines can slice the raw -SFT dataset (``where 'hf_job:succeeded' in tags``) without re-reading trajectories. - -Tag namespaces (all tags are ``"<namespace>:<value>"`` strings): - -* ``tool:`` — every tool called at least once (``tool:hf_jobs``, …) -* ``outcome:`` — ``completed`` / ``errored`` / ``interrupted`` / - ``ongoing`` / ``doom_loop`` / ``context_exceeded`` -* ``hf_job:`` — ``submitted``, ``succeeded``, ``failed``, - ``multi`` (>1), ``oom``, ``push_to_hub`` -* ``gpu:`` — ``none``, ``t4``, ``a10g``, ``a100``, ``l40s``, - ``h100``, plus ``gpu:multi`` for x2/x4/x8 flavors -* ``sandbox:`` — ``created``, ``gpu``, ``cpu``, ``long_lived`` (>30 min) -* ``feedback:`` — ``up``, ``down``, ``mixed``, ``none`` -* ``model:`` — ``opus`` / ``sonnet`` / ``haiku`` / ``kimi`` / - ``gpt`` / ``deepseek`` / ``qwen`` / ``other`` -* ``turns:`` — ``short`` (<5) / ``medium`` (5–20) / ``long`` (>20) -* ``cost:`` — ``low`` (<$0.10) / ``med`` (<$1) / ``high`` -* ``task:`` — ``training`` / ``inference`` / ``data_prep`` / - ``research_only`` (heuristic on tools + scripts) - -Tags are deduplicated before returning. -""" - -from __future__ import annotations - -from typing import Iterable - -# Flavor → GPU-family mapping. Keep conservative; unknown flavors → "none". -_GPU_FAMILY = { - "cpu-basic": "none", - "cpu-upgrade": "none", - "t4-small": "t4", - "t4-medium": "t4", - "l4x1": "l40s", - "l4x4": "l40s", - "l40sx1": "l40s", - "l40sx4": "l40s", - "l40sx8": "l40s", - "a10g-small": "a10g", - "a10g-large": "a10g", - "a10g-largex2": "a10g", - "a10g-largex4": "a10g", - "a100-large": "a100", - "a100x2": "a100", - "a100x4": "a100", - "a100x8": "a100", - "h100": "h100", - "h100x8": "h100", -} - -# Substrings that count a flavor as multi-GPU. -_MULTI_GPU_MARKERS = ("x2", "x4", "x8") - -# Tool names that don't touch training/inference or sandbox/jobs. If a session -# only used these, we tag it research_only. -_RESEARCH_ONLY_TOOLS = { - "research", - "github_find_examples", - "github_read_file", - "github_list_repos", - "hf_papers", - "explore_hf_docs", - "fetch_hf_docs", - "hub_repo_details", - "plan", - "hf_inspect_dataset", - "web_search", -} - -# Tool names that signal data manipulation workflows.
-_DATA_PREP_TOOLS = {"hf_inspect_dataset", "dataset_tools", "hub_repo_details"} - - -def _model_family(model_name: str | None) -> str: - if not model_name: - return "other" - n = model_name.lower() - if "opus" in n: - return "opus" - if "sonnet" in n: - return "sonnet" - if "haiku" in n: - return "haiku" - if "kimi" in n: - return "kimi" - if "gpt" in n: - return "gpt" - if "deepseek" in n: - return "deepseek" - if "qwen" in n: - return "qwen" - if "llama" in n: - return "llama" - return "other" - - -def _turns_bucket(n: int) -> str: - if n < 5: - return "short" - if n <= 20: - return "medium" - return "long" - - -def _cost_bucket(cost_usd: float) -> str: - if cost_usd < 0.10: - return "low" - if cost_usd < 1.0: - return "med" - return "high" - - -def _flavor_to_gpu_tags(flavor: str) -> list[str]: - family = _GPU_FAMILY.get(flavor, "none") - tags = [f"gpu:{family}"] - if any(m in flavor for m in _MULTI_GPU_MARKERS): - tags.append("gpu:multi") - return tags - - -def _has_oom_signal(tool_outputs: Iterable[str]) -> bool: - for out in tool_outputs: - if not isinstance(out, str): - continue - low = out.lower() - if "outofmemoryerror" in low or "cuda out of memory" in low or "oom" in low: - return True - return False - - -def _infer_task_tag( - tool_names: set[str], - hf_job_submit_scripts: list[str], -) -> str | None: - """Return a ``task:*`` tag or None if we can't tell. - - Heuristic order: training > inference > data_prep > research_only. - """ - # training: any hf_jobs script with a Trainer/SFT/training keyword, OR uses - # hf_jobs at all and a script mentions training APIs. - for script in hf_job_submit_scripts: - low = script.lower() - if any( - k in low - for k in ( - "sftconfig", - "sfttrainer", - "trainer(", - "trainingarguments", - "grpo", - "dpo", - ".train(", - "transformers import", - "trainer import", - "fine-tune", - "finetune", - ) - ): - return "training" - - # inference: sessions that use inference tools but never hf_jobs/sandbox - uses_compute = bool(tool_names & {"hf_jobs", "sandbox_create", "sandbox_exec"}) - if not uses_compute and tool_names & {"inference", "generate", "run_inference"}: - return "inference" - - # data_prep: primarily dataset tools and no training/inference - if tool_names & _DATA_PREP_TOOLS and not uses_compute: - return "data_prep" - - # research_only: every tool used is in the research allow-list - if tool_names and tool_names <= _RESEARCH_ONLY_TOOLS: - return "research_only" - - return None - - -def tag_session(trajectory: dict) -> list[str]: - """Derive tags from a session trajectory. Pure function.""" - tags: set[str] = set() - - events: list[dict] = trajectory.get("events") or [] - messages: list[dict] = trajectory.get("messages") or [] - model_name: str | None = trajectory.get("model_name") - - # model - tags.add(f"model:{_model_family(model_name)}") - - # turns - user_turns = sum(1 for m in messages if m.get("role") == "user") - tags.add(f"turns:{_turns_bucket(user_turns)}") - - # cost + tool-name enumeration + outcome detection - cost_usd = 0.0 - tool_names: set[str] = set() - tool_outputs: list[str] = [] - hf_job_submit_count = 0 - hf_job_submit_scripts: list[str] = [] - hf_job_success_count = 0 - hf_job_fail_count = 0 - hf_job_push_to_hub = False - gpu_tags_seen: set[str] = set() - - # Outcome is the *last* terminal signal. Seed with "ongoing" β€” overridden - # if we see a terminal event. 
- outcome = "ongoing" - had_error = False - had_doom_loop = False - had_compact = False - - feedback_up = 0 - feedback_down = 0 - - sandbox_created = False - sandbox_hardware: str | None = None - sandbox_lifetime_s: int | None = None - - for ev in events: - et = ev.get("event_type") - data = ev.get("data") or {} - - if et == "llm_call": - cost_usd += float(data.get("cost_usd") or 0.0) - - elif et == "tool_call": - name = data.get("tool") - if name: - tool_names.add(name) - - elif et == "tool_output": - out = data.get("output") - if isinstance(out, str): - tool_outputs.append(out) - - elif et == "hf_job_submit": - hf_job_submit_count += 1 - if data.get("push_to_hub"): - hf_job_push_to_hub = True - flavor = data.get("flavor") or "cpu-basic" - for t in _flavor_to_gpu_tags(flavor): - gpu_tags_seen.add(t) - - elif et == "hf_job_complete": - final = (data.get("final_status") or "").lower() - if final in ("completed", "succeeded", "success"): - hf_job_success_count += 1 - elif final in ("failed", "error", "timeout", "cancelled"): - hf_job_fail_count += 1 - - elif et == "sandbox_create": - sandbox_created = True - sandbox_hardware = data.get("hardware") - - elif et == "sandbox_destroy": - lt = data.get("lifetime_s") - if isinstance(lt, (int, float)): - sandbox_lifetime_s = int(lt) - - elif et == "feedback": - rating = data.get("rating") - if rating == "up": - feedback_up += 1 - elif rating == "down": - feedback_down += 1 - - elif et == "error": - had_error = True - elif et == "turn_complete": - if not had_error: - outcome = "completed" - elif et == "interrupted": - outcome = "interrupted" - elif et == "compacted": - had_compact = True - elif et == "tool_log": - log_text = (data.get("log") or "").lower() - if "doom loop" in log_text: - had_doom_loop = True - - if had_error and outcome not in ("completed", "interrupted"): - outcome = "errored" - - tags.add(f"outcome:{outcome}") - if had_doom_loop: - tags.add("outcome:doom_loop") - if had_compact: - tags.add("outcome:context_exceeded") - - # tools - for name in tool_names: - tags.add(f"tool:{name}") - - # hf_jobs facets - if hf_job_submit_count >= 1: - tags.add("hf_job:submitted") - if hf_job_submit_count > 1: - tags.add("hf_job:multi") - if hf_job_success_count > 0: - tags.add("hf_job:succeeded") - if hf_job_fail_count > 0: - tags.add("hf_job:failed") - if hf_job_push_to_hub: - tags.add("hf_job:push_to_hub") - if _has_oom_signal(tool_outputs): - tags.add("hf_job:oom") - - # gpu tags (from all submitted jobs) - tags.update(gpu_tags_seen) - if "gpu:none" in tags and len(gpu_tags_seen) > 1: - # If any GPU flavor was used, drop the "none" tag for clarity. - tags.discard("gpu:none") - - # sandbox facets - if sandbox_created: - tags.add("sandbox:created") - if sandbox_hardware: - fam = _GPU_FAMILY.get(sandbox_hardware, "none") - tags.add("sandbox:cpu" if fam == "none" else "sandbox:gpu") - if sandbox_lifetime_s is not None and sandbox_lifetime_s > 1800: - tags.add("sandbox:long_lived") - - # feedback - if feedback_up and feedback_down: - tags.add("feedback:mixed") - elif feedback_up: - tags.add("feedback:up") - elif feedback_down: - tags.add("feedback:down") - else: - tags.add("feedback:none") - - # cost bucket - tags.add(f"cost:{_cost_bucket(cost_usd)}") - - # task heuristic (needs scripts β€” pull from the hf_job_submit events' - # matching tool_call arguments in the event list). 
- for ev in events: - if ev.get("event_type") == "tool_call": - data = ev.get("data") or {} - if data.get("tool") == "hf_jobs": - args = data.get("arguments") or {} - script = args.get("script") or args.get("command") or "" - if isinstance(script, str): - hf_job_submit_scripts.append(script) - - task_tag = _infer_task_tag(tool_names, hf_job_submit_scripts) - if task_tag: - tags.add(f"task:{task_tag}") - - return sorted(tags) diff --git a/agent/tools/__init__.py b/agent/tools/__init__.py index 65c793cbaad3b2f74eacaf1da6038ff0bef893d9..14ef45669bc443c1c005ddde69b4205eb02f46cb 100644 --- a/agent/tools/__init__.py +++ b/agent/tools/__init__.py @@ -20,7 +20,6 @@ from agent.tools.github_read_file import ( ) from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, HfJobsTool, hf_jobs_handler from agent.tools.types import ToolResult -from agent.tools.web_search_tool import WEB_SEARCH_TOOL_SPEC, web_search_handler __all__ = [ "ToolResult", @@ -37,6 +36,4 @@ __all__ = [ "github_search_code_handler", "HF_INSPECT_DATASET_TOOL_SPEC", "hf_inspect_dataset_handler", - "WEB_SEARCH_TOOL_SPEC", - "web_search_handler", ] diff --git a/agent/tools/dataset_tools.py b/agent/tools/dataset_tools.py index 20add683d40c3b0f550daaae046408d64f23ddbd..ef3f3c81b629d0f937a91255622629396e5a2534 100644 --- a/agent/tools/dataset_tools.py +++ b/agent/tools/dataset_tools.py @@ -423,9 +423,7 @@ HF_INSPECT_DATASET_TOOL_SPEC = { } -async def hf_inspect_dataset_handler( - arguments: dict[str, Any], session=None -) -> tuple[str, bool]: +async def hf_inspect_dataset_handler(arguments: dict[str, Any], session=None) -> tuple[str, bool]: """Handler for agent tool router""" try: hf_token = session.hf_token if session else None diff --git a/agent/tools/docs_tools.py b/agent/tools/docs_tools.py index ee40ef353ae05b8d32d4c9a17bd0d9eaa8687532..a1782107ef9c439f03276d314966176c1ce9e4d9 100644 --- a/agent/tools/docs_tools.py +++ b/agent/tools/docs_tools.py @@ -932,7 +932,7 @@ EXPLORE_HF_DOCS_TOOL_SPEC = { "β€’ argilla β€” Data annotation, feedback, and human-in-the-loop workflows.\n" "β€’ distilabel β€” Synthetic data generation and distillation pipelines.\n" "β€’ microsoft-azure β€” Azure deployment and integration guides.\n" - "β€’ kernels β€” Load prebuilt compute kernels (E.g. 
flash-attn2) from the Hub via `attn_implementation`; avoids compiling flash-attn from source.\n" + "β€’ kernels β€” Lightweight execution environments and notebook-style workflows.\n" "β€’ google-cloud β€” GCP deployment and serving workflows.\n" ), }, diff --git a/agent/tools/edit_utils.py b/agent/tools/edit_utils.py index 1c6b958192ad8a90c9b3268f6fdb688787d97ea6..6a9a3295e2e25313758d633e0b733f57c373cd5a 100644 --- a/agent/tools/edit_utils.py +++ b/agent/tools/edit_utils.py @@ -10,18 +10,18 @@ from __future__ import annotations # ── Unicode normalization map ──────────────────────────────────────────── UNICODE_MAP = { - "\u2013": "-", # en-dash - "\u2014": "-", # em-dash - "\u2212": "-", # minus sign - "\u2018": "'", # left single quote - "\u2019": "'", # right single quote - "\u201c": '"', # left double quote - "\u201d": '"', # right double quote - "\u00a0": " ", # non-breaking space - "\u2003": " ", # em space - "\u2002": " ", # en space - "\u200b": "", # zero-width space - "\ufeff": "", # BOM + "\u2013": "-", # en-dash + "\u2014": "-", # em-dash + "\u2212": "-", # minus sign + "\u2018": "'", # left single quote + "\u2019": "'", # right single quote + "\u201c": '"', # left double quote + "\u201d": '"', # right double quote + "\u00a0": " ", # non-breaking space + "\u2003": " ", # em space + "\u2002": " ", # en space + "\u200b": "", # zero-width space + "\ufeff": "", # BOM } @@ -59,12 +59,12 @@ def fuzzy_find(content: str, pattern: str) -> tuple[int | None, str | None]: line_start_map[i] = original byte offset of the start of line i. """ orig_lines = text.split("\n") - stripped_lines = [strip_fn(line) for line in orig_lines] + stripped_lines = [strip_fn(l) for l in orig_lines] return "\n".join(stripped_lines), orig_lines, stripped_lines # Pass 2 β€” right-trim c_rt, c_orig_lines, c_rt_lines = _build_stripped(content, str.rstrip) - p_rt = "\n".join(line.rstrip() for line in pattern.split("\n")) + p_rt = "\n".join(l.rstrip() for l in pattern.split("\n")) idx = c_rt.find(p_rt) if idx != -1: orig_idx = _map_back(idx, c_orig_lines, c_rt_lines) @@ -72,7 +72,7 @@ def fuzzy_find(content: str, pattern: str) -> tuple[int | None, str | None]: # Pass 3 β€” both-sides trim c_st, _, c_st_lines = _build_stripped(content, str.strip) - p_st = "\n".join(line.strip() for line in pattern.split("\n")) + p_st = "\n".join(l.strip() for l in pattern.split("\n")) idx = c_st.find(p_st) if idx != -1: orig_idx = _map_back(idx, c_orig_lines, c_st_lines) @@ -114,9 +114,7 @@ def _map_back( return 0 -def fuzzy_find_original_match( - content: str, pattern: str -) -> tuple[str | None, str | None]: +def fuzzy_find_original_match(content: str, pattern: str) -> tuple[str | None, str | None]: """Find the *original* text in content that matches pattern fuzzily. Returns (original_matched_text, match_note) or (None, None). @@ -226,9 +224,7 @@ def apply_edit( return new_content, 1, fuzzy_note else: - raise ValueError( - f"Unknown edit mode: {mode}. Use replace, append_after, or prepend_before." - ) + raise ValueError(f"Unknown edit mode: {mode}. Use replace, append_after, or prepend_before.") # ── Syntax validation (Python) ─────────────────────────────────────────── @@ -259,15 +255,14 @@ def validate_python(content: str, path: str = "") -> list[str]: return warnings # 2. 
Training script heuristics - if any( - kw in content - for kw in ("TrainingArguments", "SFTConfig", "DPOConfig", "GRPOConfig") - ): + if any(kw in content for kw in ("TrainingArguments", "SFTConfig", "DPOConfig", "GRPOConfig")): if "push_to_hub" not in content: warnings.append( "Training script warning: no 'push_to_hub' found β€” model may be lost when job ends" ) if "hub_model_id" not in content: - warnings.append("Training script warning: no 'hub_model_id' found") + warnings.append( + "Training script warning: no 'hub_model_id' found" + ) return warnings diff --git a/agent/tools/hf_repo_files_tool.py b/agent/tools/hf_repo_files_tool.py index aee00b741662838769d25711602b5afefcb623e8..fd39a488fc5610b665d2e0ddb7584d892104644a 100644 --- a/agent/tools/hf_repo_files_tool.py +++ b/agent/tools/hf_repo_files_tool.py @@ -10,7 +10,6 @@ from typing import Any, Dict, Literal, Optional from huggingface_hub import HfApi, hf_hub_download from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError -from agent.core.hub_artifacts import is_known_hub_artifact, register_hub_artifact from agent.tools.types import ToolResult OperationType = Literal["list", "read", "upload", "delete"] @@ -40,9 +39,8 @@ def _format_size(size_bytes: int) -> str: class HfRepoFilesTool: """Tool for file operations on HF repos.""" - def __init__(self, hf_token: Optional[str] = None, session: Any = None): + def __init__(self, hf_token: Optional[str] = None): self.api = HfApi(token=hf_token) - self.session = session async def execute(self, args: Dict[str, Any]) -> ToolResult: """Execute the specified operation.""" @@ -63,9 +61,7 @@ class HfRepoFilesTool: if handler: return await handler(args) else: - return self._error( - f"Unknown operation: {operation}. Valid: list, read, upload, delete" - ) + return self._error(f"Unknown operation: {operation}. 
Valid: list, read, upload, delete") except RepositoryNotFoundError: return self._error(f"Repository not found: {args.get('repo_id')}") @@ -100,23 +96,17 @@ class HfRepoFilesTool: revision = args.get("revision", "main") path = args.get("path", "") - items = list( - await _async_call( - self.api.list_repo_tree, - repo_id=repo_id, - repo_type=repo_type, - revision=revision, - path_in_repo=path, - recursive=True, - ) - ) + items = list(await _async_call( + self.api.list_repo_tree, + repo_id=repo_id, + repo_type=repo_type, + revision=revision, + path_in_repo=path, + recursive=True, + )) if not items: - return { - "formatted": f"No files in {repo_id}", - "totalResults": 0, - "resultsShared": 0, - } + return {"formatted": f"No files in {repo_id}", "totalResults": 0, "resultsShared": 0} lines = [] total_size = 0 @@ -128,16 +118,9 @@ class HfRepoFilesTool: lines.append(f"{item.path}/") url = _build_repo_url(repo_id, repo_type) - response = ( - f"**{repo_id}** ({len(items)} files, {_format_size(total_size)})\n{url}/tree/{revision}\n\n" - + "\n".join(lines) - ) + response = f"**{repo_id}** ({len(items)} files, {_format_size(total_size)})\n{url}/tree/{revision}\n\n" + "\n".join(lines) - return { - "formatted": response, - "totalResults": len(items), - "resultsShared": len(items), - } + return {"formatted": response, "totalResults": len(items), "resultsShared": len(items)} async def _read(self, args: Dict[str, Any]) -> ToolResult: """Read file content from a repository.""" @@ -177,13 +160,8 @@ class HfRepoFilesTool: except UnicodeDecodeError: import os - size = os.path.getsize(file_path) - return { - "formatted": f"Binary file ({_format_size(size)})", - "totalResults": 1, - "resultsShared": 1, - } + return {"formatted": f"Binary file ({_format_size(size)})", "totalResults": 1, "resultsShared": 1} async def _upload(self, args: Dict[str, Any]) -> ToolResult: """Upload content to a repository.""" @@ -216,16 +194,6 @@ class HfRepoFilesTool: create_pr=create_pr, ) - if not create_pr and is_known_hub_artifact(self.session, repo_id, repo_type): - await _async_call( - register_hub_artifact, - self.api, - repo_id, - repo_type, - session=self.session, - force=path == "README.md", - ) - url = _build_repo_url(repo_id, repo_type) if create_pr and hasattr(result, "pr_url"): response = f"**Uploaded as PR**\n{result.pr_url}" @@ -267,12 +235,7 @@ class HfRepoFilesTool: def _error(self, message: str) -> ToolResult: """Return an error result.""" - return { - "formatted": message, - "totalResults": 0, - "resultsShared": 0, - "isError": True, - } + return {"formatted": message, "totalResults": 0, "resultsShared": 0, "isError": True} # Tool specification @@ -349,13 +312,11 @@ HF_REPO_FILES_TOOL_SPEC = { } -async def hf_repo_files_handler( - arguments: Dict[str, Any], session=None -) -> tuple[str, bool]: +async def hf_repo_files_handler(arguments: Dict[str, Any], session=None) -> tuple[str, bool]: """Handler for agent tool router.""" try: hf_token = session.hf_token if session else None - tool = HfRepoFilesTool(hf_token=hf_token, session=session) + tool = HfRepoFilesTool(hf_token=hf_token) result = await tool.execute(arguments) return result["formatted"], not result.get("isError", False) except Exception as e: diff --git a/agent/tools/hf_repo_git_tool.py b/agent/tools/hf_repo_git_tool.py index cfff4120b089aa7923c2a46c5c3da22cf201457f..d7e2323a361efd578cf97d78be2593b71f4d8ac8 100644 --- a/agent/tools/hf_repo_git_tool.py +++ b/agent/tools/hf_repo_git_tool.py @@ -10,24 +10,14 @@ from typing import Any, Dict, Literal, Optional 
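For orientation, a usage sketch of the handler defined above. The repo id is just an illustrative public model, and session=None exercises the anonymous-client fallback the handler itself implies; this is a sketch, not test code from the repo:

import asyncio

from agent.tools.hf_repo_files_tool import hf_repo_files_handler

async def demo() -> None:
    # "list" walks the repo tree recursively and returns a formatted summary.
    text, ok = await hf_repo_files_handler(
        {"operation": "list", "repo_id": "openai-community/gpt2"},
        session=None,
    )
    print(ok)
    print(text.splitlines()[0])

asyncio.run(demo())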
from huggingface_hub import HfApi from huggingface_hub.utils import RepositoryNotFoundError -from agent.core.hub_artifacts import register_hub_artifact from agent.tools.types import ToolResult OperationType = Literal[ - "create_branch", - "delete_branch", - "create_tag", - "delete_tag", + "create_branch", "delete_branch", + "create_tag", "delete_tag", "list_refs", - "create_pr", - "list_prs", - "get_pr", - "merge_pr", - "close_pr", - "comment_pr", - "change_pr_status", - "create_repo", - "update_repo", + "create_pr", "list_prs", "get_pr", "merge_pr", "close_pr", "comment_pr", "change_pr_status", + "create_repo", "update_repo", ] @@ -46,9 +36,8 @@ def _build_repo_url(repo_id: str, repo_type: str = "model") -> str: class HfRepoGitTool: """Tool for git-like operations on HF repos.""" - def __init__(self, hf_token: Optional[str] = None, session: Any = None): + def __init__(self, hf_token: Optional[str] = None): self.api = HfApi(token=hf_token) - self.session = session async def execute(self, args: Dict[str, Any]) -> ToolResult: """Execute the specified operation.""" @@ -142,11 +131,7 @@ class HfRepoGitTool: ) url = f"{_build_repo_url(repo_id, repo_type)}/tree/{branch}" - return { - "formatted": f"**Branch created:** {branch}\n{url}", - "totalResults": 1, - "resultsShared": 1, - } + return {"formatted": f"**Branch created:** {branch}\n{url}", "totalResults": 1, "resultsShared": 1} async def _delete_branch(self, args: Dict[str, Any]) -> ToolResult: """Delete a branch.""" @@ -167,11 +152,7 @@ class HfRepoGitTool: repo_type=repo_type, ) - return { - "formatted": f"**Branch deleted:** {branch}", - "totalResults": 1, - "resultsShared": 1, - } + return {"formatted": f"**Branch deleted:** {branch}", "totalResults": 1, "resultsShared": 1} # ========================================================================= # TAG OPERATIONS @@ -202,11 +183,7 @@ class HfRepoGitTool: ) url = f"{_build_repo_url(repo_id, repo_type)}/tree/{tag}" - return { - "formatted": f"**Tag created:** {tag}\n{url}", - "totalResults": 1, - "resultsShared": 1, - } + return {"formatted": f"**Tag created:** {tag}\n{url}", "totalResults": 1, "resultsShared": 1} async def _delete_tag(self, args: Dict[str, Any]) -> ToolResult: """Delete a tag.""" @@ -227,11 +204,7 @@ class HfRepoGitTool: repo_type=repo_type, ) - return { - "formatted": f"**Tag deleted:** {tag}", - "totalResults": 1, - "resultsShared": 1, - } + return {"formatted": f"**Tag deleted:** {tag}", "totalResults": 1, "resultsShared": 1} # ========================================================================= # LIST REFS @@ -253,9 +226,7 @@ class HfRepoGitTool: ) branches = [b.name for b in refs.branches] if refs.branches else [] - tags = ( - [t.name for t in refs.tags] if hasattr(refs, "tags") and refs.tags else [] - ) + tags = [t.name for t in refs.tags] if hasattr(refs, 'tags') and refs.tags else [] url = _build_repo_url(repo_id, repo_type) lines = [f"**{repo_id}**", url, ""] @@ -270,11 +241,7 @@ class HfRepoGitTool: else: lines.append("**Tags:** none") - return { - "formatted": "\n".join(lines), - "totalResults": len(branches) + len(tags), - "resultsShared": len(branches) + len(tags), - } + return {"formatted": "\n".join(lines), "totalResults": len(branches) + len(tags), "resultsShared": len(branches) + len(tags)} # ========================================================================= # PR OPERATIONS @@ -303,7 +270,7 @@ class HfRepoGitTool: url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{result.num}" return { - "formatted": f'**Draft PR #{result.num} 
created:** {title}\n{url}\n\nAdd commits via upload with revision="refs/pr/{result.num}"', + "formatted": f"**Draft PR #{result.num} created:** {title}\n{url}\n\nAdd commits via upload with revision=\"refs/pr/{result.num}\"", "totalResults": 1, "resultsShared": 1, } @@ -318,27 +285,17 @@ class HfRepoGitTool: repo_type = args.get("repo_type", "model") status = args.get("status", "all") # open, closed, all - discussions = list( - self.api.get_repo_discussions( - repo_id=repo_id, - repo_type=repo_type, - discussion_status=status if status != "all" else None, - ) - ) + discussions = list(self.api.get_repo_discussions( + repo_id=repo_id, + repo_type=repo_type, + discussion_status=status if status != "all" else None, + )) if not discussions: - return { - "formatted": f"No discussions in {repo_id}", - "totalResults": 0, - "resultsShared": 0, - } + return {"formatted": f"No discussions in {repo_id}", "totalResults": 0, "resultsShared": 0} url = _build_repo_url(repo_id, repo_type) - lines = [ - f"**{repo_id}** - {len(discussions)} discussions", - f"{url}/discussions", - "", - ] + lines = [f"**{repo_id}** - {len(discussions)} discussions", f"{url}/discussions", ""] for d in discussions[:20]: if d.status == "draft": @@ -352,11 +309,7 @@ class HfRepoGitTool: type_label = "PR" if d.is_pull_request else "D" lines.append(f"{status_label} #{d.num} [{type_label}] {d.title}") - return { - "formatted": "\n".join(lines), - "totalResults": len(discussions), - "resultsShared": min(20, len(discussions)), - } + return {"formatted": "\n".join(lines), "totalResults": len(discussions), "resultsShared": min(20, len(discussions))} async def _get_pr(self, args: Dict[str, Any]) -> ToolResult: """Get PR details.""" @@ -382,7 +335,7 @@ class HfRepoGitTool: "draft": "Draft", "open": "Open", "merged": "Merged", - "closed": "Closed", + "closed": "Closed" } status = status_map.get(pr.status, pr.status.capitalize()) type_label = "Pull Request" if pr.is_pull_request else "Discussion" @@ -396,13 +349,9 @@ class HfRepoGitTool: if pr.is_pull_request: if pr.status == "draft": - lines.append( - f'\nTo add commits: upload with revision="refs/pr/{pr_num}"' - ) + lines.append(f"\nTo add commits: upload with revision=\"refs/pr/{pr_num}\"") elif pr.status == "open": - lines.append( - f'\nTo add commits: upload with revision="refs/pr/{pr_num}"' - ) + lines.append(f"\nTo add commits: upload with revision=\"refs/pr/{pr_num}\"") return {"formatted": "\n".join(lines), "totalResults": 1, "resultsShared": 1} @@ -428,11 +377,7 @@ class HfRepoGitTool: ) url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}" - return { - "formatted": f"**PR #{pr_num} merged**\n{url}", - "totalResults": 1, - "resultsShared": 1, - } + return {"formatted": f"**PR #{pr_num} merged**\n{url}", "totalResults": 1, "resultsShared": 1} async def _close_pr(self, args: Dict[str, Any]) -> ToolResult: """Close a PR/discussion.""" @@ -456,11 +401,7 @@ class HfRepoGitTool: repo_type=repo_type, ) - return { - "formatted": f"**Discussion #{pr_num} closed**", - "totalResults": 1, - "resultsShared": 1, - } + return {"formatted": f"**Discussion #{pr_num} closed**", "totalResults": 1, "resultsShared": 1} async def _comment_pr(self, args: Dict[str, Any]) -> ToolResult: """Add a comment to a PR/discussion.""" @@ -486,11 +427,7 @@ class HfRepoGitTool: ) url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}" - return { - "formatted": f"**Comment added to #{pr_num}**\n{url}", - "totalResults": 1, - "resultsShared": 1, - } + return {"formatted": f"**Comment added to 
#{pr_num}**\n{url}", "totalResults": 1, "resultsShared": 1} async def _change_pr_status(self, args: Dict[str, Any]) -> ToolResult: """Change PR/discussion status (mainly to convert draft to open).""" @@ -518,11 +455,7 @@ class HfRepoGitTool: ) url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}" - return { - "formatted": f"**PR #{pr_num} status changed to {new_status}**\n{url}", - "totalResults": 1, - "resultsShared": 1, - } + return {"formatted": f"**PR #{pr_num} status changed to {new_status}**\n{url}", "totalResults": 1, "resultsShared": 1} # ========================================================================= # REPO MANAGEMENT @@ -540,9 +473,7 @@ class HfRepoGitTool: space_sdk = args.get("space_sdk") if repo_type == "space" and not space_sdk: - return self._error( - "space_sdk required for spaces (gradio/streamlit/docker/static)" - ) + return self._error("space_sdk required for spaces (gradio/streamlit/docker/static)") kwargs = { "repo_id": repo_id, @@ -554,17 +485,6 @@ class HfRepoGitTool: kwargs["space_sdk"] = space_sdk result = await _async_call(self.api.create_repo, **kwargs) - extra_metadata = None - if repo_type == "space" and space_sdk: - extra_metadata = {"sdk": space_sdk} - await _async_call( - register_hub_artifact, - self.api, - repo_id, - repo_type, - session=self.session, - extra_metadata=extra_metadata, - ) return { "formatted": f"**Repository created:** {repo_id}\n**Private:** {private}\n{result}", @@ -584,9 +504,7 @@ class HfRepoGitTool: gated = args.get("gated") if private is None and gated is None: - return self._error( - "Specify private (bool) or gated ('auto'/'manual'/false)" - ) + return self._error("Specify private (bool) or gated ('auto'/'manual'/false)") kwargs = {"repo_id": repo_id, "repo_type": repo_type} if private is not None: @@ -603,20 +521,11 @@ class HfRepoGitTool: changes.append(f"gated={gated}") url = f"{_build_repo_url(repo_id, repo_type)}/settings" - return { - "formatted": f"**Settings updated:** {', '.join(changes)}\n{url}", - "totalResults": 1, - "resultsShared": 1, - } + return {"formatted": f"**Settings updated:** {', '.join(changes)}\n{url}", "totalResults": 1, "resultsShared": 1} def _error(self, message: str) -> ToolResult: """Return an error result.""" - return { - "formatted": message, - "totalResults": 0, - "resultsShared": 0, - "isError": True, - } + return {"formatted": message, "totalResults": 0, "resultsShared": 0, "isError": True} # Tool specification @@ -662,20 +571,10 @@ HF_REPO_GIT_TOOL_SPEC = { "operation": { "type": "string", "enum": [ - "create_branch", - "delete_branch", - "create_tag", - "delete_tag", - "list_refs", - "create_pr", - "list_prs", - "get_pr", - "merge_pr", - "close_pr", - "comment_pr", - "change_pr_status", - "create_repo", - "update_repo", + "create_branch", "delete_branch", + "create_tag", "delete_tag", "list_refs", + "create_pr", "list_prs", "get_pr", "merge_pr", "close_pr", "comment_pr", "change_pr_status", + "create_repo", "update_repo", ], "description": "Operation to execute", }, @@ -754,13 +653,11 @@ HF_REPO_GIT_TOOL_SPEC = { } -async def hf_repo_git_handler( - arguments: Dict[str, Any], session=None -) -> tuple[str, bool]: +async def hf_repo_git_handler(arguments: Dict[str, Any], session=None) -> tuple[str, bool]: """Handler for agent tool router.""" try: hf_token = session.hf_token if session else None - tool = HfRepoGitTool(hf_token=hf_token, session=session) + tool = HfRepoGitTool(hf_token=hf_token) result = await tool.execute(arguments) return result["formatted"], not 
result.get("isError", False) except Exception as e: diff --git a/agent/tools/jobs_tool.py b/agent/tools/jobs_tool.py index 29d6b3017ab3c5641b01d3573ed560d98c103c6b..2c6ebf6c7ed7c73f2dfb049b9924825386336760 100644 --- a/agent/tools/jobs_tool.py +++ b/agent/tools/jobs_tool.py @@ -7,24 +7,20 @@ Refactored to use official huggingface-hub library instead of custom HTTP client import asyncio import base64 import http.client -import logging +import os import re -import shlex -from typing import Any, Awaitable, Callable, Dict, Literal, Optional +from typing import Any, Dict, Literal, Optional, Callable, Awaitable + +import logging import httpx from huggingface_hub import HfApi from huggingface_hub.utils import HfHubHTTPError -from agent.core.hf_access import ( - JobsAccessError, - is_billing_error, - resolve_jobs_namespace, -) -from agent.core.hub_artifacts import build_hub_artifact_sitecustomize from agent.core.session import Event -from agent.tools.trackio_seed import ensure_trackio_dashboard from agent.tools.types import ToolResult + +logger = logging.getLogger(__name__) from agent.tools.utilities import ( format_job_details, format_jobs_table, @@ -32,8 +28,6 @@ from agent.tools.utilities import ( format_scheduled_jobs_table, ) -logger = logging.getLogger(__name__) - # Hardware flavors CPU_FLAVORS = ["cpu-basic", "cpu-upgrade"] GPU_FLAVORS = [ @@ -123,11 +117,11 @@ def _filter_uv_install_output(logs: list[str]) -> list[str]: return logs -_ANSI_RE = re.compile(r"\x1b\[[0-9;]*[a-zA-Z]|\x1b\].*?\x07") +_ANSI_RE = re.compile(r'\x1b\[[0-9;]*[a-zA-Z]|\x1b\].*?\x07') def _strip_ansi(text: str) -> str: - return _ANSI_RE.sub("", text) + return _ANSI_RE.sub('', text) _DEFAULT_ENV = { @@ -239,26 +233,6 @@ def _resolve_uv_command( return _build_uv_command(script, with_deps, python, script_args) -def _wrap_command_with_artifact_bootstrap( - command: list[str], session: Any = None -) -> list[str]: - """Install sitecustomize hooks before the user command runs in HF Jobs.""" - sitecustomize = build_hub_artifact_sitecustomize(session) - if not sitecustomize: - return command - - encoded = base64.b64encode(sitecustomize.encode("utf-8")).decode("ascii") - original_command = shlex.join(command) - shell = ( - 'set -e; _ml_intern_artifacts_dir="$(mktemp -d)"; ' - f"printf %s {shlex.quote(encoded)} | base64 -d " - '> "$_ml_intern_artifacts_dir/sitecustomize.py"; ' - 'export PYTHONPATH="$_ml_intern_artifacts_dir${PYTHONPATH:+:$PYTHONPATH}"; ' - f"exec {original_command}" - ) - return ["/bin/sh", "-lc", shell] - - async def _async_call(func, *args, **kwargs): """Wrap synchronous HfApi calls for async context""" return await asyncio.to_thread(func, *args, **kwargs) @@ -324,7 +298,6 @@ class HfJobsTool: self, hf_token: Optional[str] = None, namespace: Optional[str] = None, - jobs_access: Any = None, log_callback: Optional[Callable[[str], Awaitable[None]]] = None, session: Any = None, tool_call_id: Optional[str] = None, @@ -332,7 +305,6 @@ class HfJobsTool: self.hf_token = hf_token self.api = HfApi(token=hf_token) self.namespace = namespace - self.jobs_access = jobs_access self.log_callback = log_callback self.session = session self.tool_call_id = tool_call_id @@ -407,31 +379,6 @@ class HfJobsTool: "isError": True, } - async def _seed_trackio_dashboard(self, space_id: str) -> None: - """Idempotently install trackio dashboard files into *space_id* before - the job runs. 
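The _wrap_command_with_artifact_bootstrap helper deleted above is an instance of a general trick: shipping a sitecustomize.py hook to a remote worker without a file upload, by base64-encoding it into the shell command itself. A standalone sketch of the same idea (function and variable names are mine, not the codebase's):

import base64
import shlex

def wrap_with_pythonpath_hook(command: list[str], hook_source: str) -> list[str]:
    # Base64 keeps the hook intact through shell quoting; Python imports
    # sitecustomize automatically at interpreter startup, so prepending the
    # temp dir to PYTHONPATH is enough to activate it before `command` runs.
    encoded = base64.b64encode(hook_source.encode("utf-8")).decode("ascii")
    shell = (
        'set -e; d="$(mktemp -d)"; '
        f'printf %s {shlex.quote(encoded)} | base64 -d > "$d/sitecustomize.py"; '
        'export PYTHONPATH="$d${PYTHONPATH:+:$PYTHONPATH}"; '
        f"exec {shlex.join(command)}"
    )
    return ["/bin/sh", "-lc", shell]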
Surfaces seed progress as tool_log events but never - raises β€” a seed failure should not block job submission, since trackio - often still works when the Space already has dashboard code from a - previous run. - """ - loop = asyncio.get_running_loop() - - def _log(msg: str) -> None: - if self.session is None: - return - loop.call_soon_threadsafe( - self.session.event_queue.put_nowait, - Event(event_type="tool_log", data={"tool": "hf_jobs", "log": msg}), - ) - - try: - await asyncio.to_thread( - ensure_trackio_dashboard, space_id, self.hf_token, _log - ) - except Exception as e: - logger.warning(f"trackio dashboard seed failed for {space_id}: {e}") - _log(f"trackio dashboard seed failed: {e}") - async def _wait_for_job_completion( self, job_id: str, namespace: Optional[str] = None ) -> tuple[str, list[str]]: @@ -456,9 +403,7 @@ class HfJobsTool: def log_producer(): try: # fetch_job_logs is a blocking sync generator - logs_gen = self.api.fetch_job_logs( - job_id=job_id, namespace=namespace - ) + logs_gen = self.api.fetch_job_logs(job_id=job_id, namespace=namespace) for line in logs_gen: # Push line to queue thread-safely loop.call_soon_threadsafe(queue.put_nowait, line) @@ -582,66 +527,17 @@ class HfJobsTool: image = args.get("image", "python:3.12") job_type = "Docker" - command = _wrap_command_with_artifact_bootstrap(command, self.session) - # Run the job - flavor = args.get("hardware_flavor", "cpu-basic") - timeout_str = args.get("timeout", "30m") - - # Trackio: agent-declared space + project become env vars on the job - # so trackio.init() picks them up automatically. We also surface them - # in tool_state_change so the frontend can embed the dashboard. - env_dict = _add_default_env(args.get("env")) - trackio_space_id = args.get("trackio_space_id") - trackio_project = args.get("trackio_project") - if trackio_space_id: - env_dict["TRACKIO_SPACE_ID"] = trackio_space_id - await self._seed_trackio_dashboard(trackio_space_id) - if trackio_project: - env_dict["TRACKIO_PROJECT"] = trackio_project - - try: - job = await _async_call( - self.api.run_job, - image=image, - command=command, - env=env_dict, - secrets=_add_environment_variables( - args.get("secrets"), self.hf_token - ), - flavor=flavor, - timeout=timeout_str, - namespace=self.namespace, - ) - except HfHubHTTPError as e: - if is_billing_error(str(e)): - if self.session and self.tool_call_id: - await self.session.send_event( - Event( - event_type="tool_state_change", - data={ - "tool_call_id": self.tool_call_id, - "tool": "hf_jobs", - "state": "billing_required", - "namespace": self.namespace, - }, - ) - ) - return { - "formatted": ( - f"Hugging Face Jobs rejected this run because the " - f"namespace `{self.namespace}` has no available credits. " - "HF Jobs are billed with namespace credits, which are " - "separate from HF Pro membership. Tell the user to add " - "credits at https://huggingface.co/settings/billing β€” " - "once topped up, re-run this same job. 
(Switching " - "namespaces is fine if another wallet has credits.)" - ), - "totalResults": 0, - "resultsShared": 0, - "isError": True, - } - raise + job = await _async_call( + self.api.run_job, + image=image, + command=command, + env=_add_default_env(args.get("env")), + secrets=_add_environment_variables(args.get("secrets"), self.hf_token), + flavor=args.get("hardware_flavor", "cpu-basic"), + timeout=args.get("timeout", "30m"), + namespace=self.namespace, + ) # Track job ID for cancellation on interrupt if self.session: @@ -649,55 +545,17 @@ class HfJobsTool: # Send job URL immediately after job creation (before waiting for completion) if self.session and self.tool_call_id: - state_data: Dict[str, Any] = { - "tool_call_id": self.tool_call_id, - "tool": "hf_jobs", - "state": "running", - "jobUrl": job.url, - } - if trackio_space_id: - state_data["trackioSpaceId"] = trackio_space_id - if trackio_project: - state_data["trackioProject"] = trackio_project await self.session.send_event( - Event(event_type="tool_state_change", data=state_data) - ) - - # Telemetry: job submission + completion (infra consumption signal). - submit_ts = None - if self.session: - from agent.core import telemetry - - submit_ts = await telemetry.record_hf_job_submit( - self.session, - job, - { - **args, - "hardware_flavor": flavor, - "timeout": timeout_str, - "namespace": self.namespace, - }, - image=image, - job_type=job_type, - ) - # Top-up signal: this submit succeeded after a prior billing - # block in the same session, and we haven't fired the event - # yet β€” the user came back from the HF billing flow. - events = self.session.logged_events - already_fired = any( - e.get("event_type") == "credits_topped_up" for e in events - ) - if not already_fired: - blocked = any( - e.get("event_type") == "tool_state_change" - and (e.get("data") or {}).get("state") == "billing_required" - for e in events + Event( + event_type="tool_state_change", + data={ + "tool_call_id": self.tool_call_id, + "tool": "hf_jobs", + "state": "running", + "jobUrl": job.url, + }, ) - if blocked: - await telemetry.record_credits_topped_up( - self.session, - namespace=self.namespace, - ) + ) # Wait for completion and stream logs logger.info(f"{job_type} job started: {job.url}") @@ -708,44 +566,29 @@ class HfJobsTool: namespace=self.namespace, ) - if self.session and submit_ts is not None: - from agent.core import telemetry - - await telemetry.record_hf_job_complete( - self.session, - job, - flavor=flavor, - final_status=final_status, - submit_ts=submit_ts, - ) - # Untrack job ID (completed or failed, no longer needs cancellation) if self.session: self.session._running_job_ids.discard(job.id) # Notify frontend of final status if self.session and self.tool_call_id: - final_data: Dict[str, Any] = { - "tool_call_id": self.tool_call_id, - "tool": "hf_jobs", - "state": final_status.lower(), - "jobUrl": job.url, - } - if trackio_space_id: - final_data["trackioSpaceId"] = trackio_space_id - if trackio_project: - final_data["trackioProject"] = trackio_project await self.session.send_event( - Event(event_type="tool_state_change", data=final_data) + Event( + event_type="tool_state_change", + data={ + "tool_call_id": self.tool_call_id, + "tool": "hf_jobs", + "state": final_status.lower(), + "jobUrl": job.url, + }, + ) ) # Filter out UV package installation output filtered_logs = _filter_uv_install_output(all_logs) # Format all logs for the agent - log_text = ( - _strip_ansi("\n".join(filtered_logs)) if filtered_logs else "(no logs)" - ) + log_text = 
_strip_ansi("\n".join(filtered_logs)) if filtered_logs else "(no logs)" response = f"""{job_type} job completed! @@ -937,8 +780,6 @@ To verify, call this tool with `{{"operation": "inspect", "job_id": "{job_id}"}} image = args.get("image", "python:3.12") job_type = "Docker" - command = _wrap_command_with_artifact_bootstrap(command, self.session) - # Create scheduled job scheduled_job = await _async_call( self.api.create_scheduled_job, @@ -1114,10 +955,7 @@ HF_JOBS_TOOL_SPEC = { "- You MUST have validated dataset format via hf_inspect_dataset or hub_repo_details.\n" "- Training config MUST include push_to_hub=True and hub_model_id. " "Job storage is EPHEMERAL β€” all files are deleted when the job ends. Without push_to_hub, trained models are lost permanently.\n" - "- Include trackio monitoring and provide the dashboard URL to the user. " - "When the script uses report_to='trackio', also pass `trackio_space_id` " - "(e.g. '/mlintern-<8char>') and `trackio_project` as tool args β€” " - "they are injected as TRACKIO_SPACE_ID/TRACKIO_PROJECT env vars and let the UI embed the live dashboard.\n\n" + "- Include trackio monitoring and provide the dashboard URL to the user.\n\n" "BATCH/ABLATION JOBS: Submit ONE job first. Check logs to confirm it starts training successfully. " "Only then submit the remaining jobs. Never submit all at once β€” if there's a bug, all jobs fail.\n\n" "Operations: run, ps, logs, inspect, cancel, scheduled run/ps/inspect/delete/suspend/resume.\n\n" @@ -1200,34 +1038,6 @@ HF_JOBS_TOOL_SPEC = { "type": "object", "description": "Environment variables {'KEY': 'VALUE'}. HF_TOKEN is auto-included.", }, - "trackio_space_id": { - "type": "string", - "description": ( - "Optional. The HF Space hosting the trackio dashboard for this run " - "(e.g. '/mlintern-<8char>', under YOUR HF namespace). " - "Injected as TRACKIO_SPACE_ID env var and used by the UI to embed " - "the live dashboard. Set this whenever the script uses " - "report_to='trackio'. The Space is auto-created and seeded with the " - "trackio dashboard before the job starts β€” DO NOT pre-create it via " - "hf_repo_git, that produces an empty Space that breaks the embed." - ), - }, - "trackio_project": { - "type": "string", - "description": ( - "Optional. The trackio project name to log this run under. " - "Injected as TRACKIO_PROJECT env var and used by the UI to filter " - "the embedded dashboard to this project." - ), - }, - "namespace": { - "type": "string", - "description": ( - "Optional namespace to run the job under. Must be the caller's own " - "account or an org they belong to. If omitted, defaults to the " - "caller's personal account. Credits are billed against this namespace." - ), - }, "job_id": { "type": "string", "description": "Job ID. 
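The spec text above ("push_to_hub=True and hub_model_id") is the same contract the validate_python heuristics in edit_utils.py check for. A minimal TRL config that satisfies both, assuming TRL's SFTConfig (one of the class names those heuristics look for); the hub_model_id value is a placeholder, not a real repo:

from trl import SFTConfig

# Job storage is ephemeral, so the trained model must be pushed before exit.
config = SFTConfig(
    output_dir="out",
    push_to_hub=True,                           # satisfies the push_to_hub check
    hub_model_id="your-username/my-sft-model",  # placeholder repo id
)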
Required for: logs, inspect, cancel.", @@ -1263,7 +1073,6 @@ async def hf_jobs_handler( sandbox = getattr(session, "sandbox", None) if session else None if sandbox and script: from agent.tools.sandbox_tool import resolve_sandbox_script - content, error = await resolve_sandbox_script(sandbox, script) if error: return error, False @@ -1271,18 +1080,11 @@ async def hf_jobs_handler( arguments = {**arguments, "script": content} hf_token = session.hf_token if session else None - try: - namespace, jobs_access = await resolve_jobs_namespace( - hf_token or "", - arguments.get("namespace"), - ) - except JobsAccessError as e: - return str(e), False + namespace = os.environ.get("HF_NAMESPACE") or (HfApi(token=hf_token).whoami().get("name") if hf_token else None) tool = HfJobsTool( namespace=namespace, hf_token=hf_token, - jobs_access=jobs_access, log_callback=log_callback if session else None, session=session, tool_call_id=tool_call_id, diff --git a/agent/tools/local_tools.py b/agent/tools/local_tools.py index 50cd5bd65b517f8855ceeb87ffade52a04e25a15..fc456f682eb54fec8a2ee29d5fba07a7d6a4a324 100644 --- a/agent/tools/local_tools.py +++ b/agent/tools/local_tools.py @@ -15,8 +15,6 @@ import tempfile from pathlib import Path from typing import Any -from agent.core.hub_artifacts import wrap_shell_command_with_hub_artifact_bootstrap - MAX_OUTPUT_CHARS = 25_000 MAX_LINE_LENGTH = 4000 @@ -24,7 +22,7 @@ DEFAULT_READ_LINES = 2000 DEFAULT_TIMEOUT = 120 MAX_TIMEOUT = 36000 # 10 hours β€” needed for long training runs (e.g. PostTrainBench) -_ANSI_RE = re.compile(r"\x1b\[[0-9;]*[a-zA-Z]|\x1b\].*?\x07") +_ANSI_RE = re.compile(r'\x1b\[[0-9;]*[a-zA-Z]|\x1b\].*?\x07') # Track files that have been read this session (enforces read-before-write/edit) _files_read: set[str] = set() @@ -65,21 +63,17 @@ def _atomic_write(path: Path, content: str) -> None: def _strip_ansi(text: str) -> str: - return _ANSI_RE.sub("", text) + return _ANSI_RE.sub('', text) -def _truncate_output( - output: str, max_chars: int = MAX_OUTPUT_CHARS, head_ratio: float = 0.25 -) -> str: +def _truncate_output(output: str, max_chars: int = MAX_OUTPUT_CHARS, head_ratio: float = 0.25) -> str: """Tail-biased truncation with temp file spillover for full output access.""" if len(output) <= max_chars: return output # Write full output to temp file so LLM can read specific sections spill_path = None try: - with tempfile.NamedTemporaryFile( - mode="w", suffix=".txt", prefix="bash_output_", delete=False - ) as f: + with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', prefix='bash_output_', delete=False) as f: f.write(output) spill_path = f.name except Exception: @@ -99,14 +93,10 @@ def _truncate_output( # ── Handlers ──────────────────────────────────────────────────────────── - -async def _bash_handler( - args: dict[str, Any], session: Any = None, **_kw -) -> tuple[str, bool]: +async def _bash_handler(args: dict[str, Any], **_kw) -> tuple[str, bool]: command = args.get("command", "") if not command: return "No command provided.", False - command = wrap_shell_command_with_hub_artifact_bootstrap(command, session) work_dir = args.get("work_dir", ".") timeout = min(args.get("timeout") or DEFAULT_TIMEOUT, MAX_TIMEOUT) try: @@ -184,12 +174,9 @@ async def _write_handler(args: dict[str, Any], **_kw) -> tuple[str, bool]: # Syntax validation for Python files if p.suffix == ".py": from agent.tools.edit_utils import validate_python - warnings = validate_python(content, file_path) if warnings: - msg += "\n\nValidation warnings:\n" + "\n".join( - f" ⚠ {w}" for w in 
warnings - ) + msg += "\n\nValidation warnings:\n" + "\n".join(f" ⚠ {w}" for w in warnings) return msg, True except Exception as e: return f"write error: {e}", False @@ -242,9 +229,7 @@ async def _edit_handler(args: dict[str, Any], **_kw) -> tuple[str, bool]: if p.suffix == ".py": warnings = validate_python(new_text, file_path) if warnings: - msg += "\n\nValidation warnings:\n" + "\n".join( - f" ⚠ {w}" for w in warnings - ) + msg += "\n\nValidation warnings:\n" + "\n".join(f" ⚠ {w}" for w in warnings) return msg, True diff --git a/agent/tools/notify_tool.py b/agent/tools/notify_tool.py deleted file mode 100644 index f926d5a58d5f3c4b877cb8792f812f6e4fa322a7..0000000000000000000000000000000000000000 --- a/agent/tools/notify_tool.py +++ /dev/null @@ -1,108 +0,0 @@ -from typing import Any - -from agent.messaging.models import NotificationRequest - -NOTIFY_TOOL_SPEC = { - "name": "notify", - "description": ( - "Send an out-of-band notification to configured messaging destinations. " - "Use this only when the user explicitly asked for proactive notifications " - "or when the task requires reporting progress outside the chat. " - "Destinations must be named server-side configs such as 'slack.ops'." - ), - "parameters": { - "type": "object", - "properties": { - "destinations": { - "type": "array", - "description": "Named messaging destinations to notify.", - "items": {"type": "string"}, - "minItems": 1, - }, - "message": { - "type": "string", - "description": "Main notification body.", - }, - "title": { - "type": "string", - "description": "Optional short title line.", - }, - "severity": { - "type": "string", - "enum": ["info", "success", "warning", "error"], - "description": "Notification severity label.", - }, - }, - "required": ["destinations", "message"], - }, -} - - -async def notify_handler( - arguments: dict[str, Any], session=None, **_kwargs -) -> tuple[str, bool]: - if session is None or session.notification_gateway is None: - return "Messaging is not configured for this session.", False - - raw_destinations = arguments.get("destinations", []) - if not isinstance(raw_destinations, list) or not raw_destinations: - return "destinations must be a non-empty array of destination names.", False - - destinations: list[str] = [] - seen: set[str] = set() - for raw_name in raw_destinations: - if not isinstance(raw_name, str): - return "Each destination must be a string.", False - name = raw_name.strip() - if not name: - return "Destination names must not be empty.", False - if name not in seen: - destinations.append(name) - seen.add(name) - - disallowed = [ - name - for name in destinations - if not session.config.messaging.can_agent_tool_send(name) - ] - if disallowed: - return ( - "These destinations are unavailable for the notify tool: " - + ", ".join(disallowed) - ), False - - message = arguments.get("message", "") - if not isinstance(message, str) or not message.strip(): - return "message must be a non-empty string.", False - - title = arguments.get("title") - severity = arguments.get("severity", "info") - if title is not None and not isinstance(title, str): - return "title must be a string when provided.", False - if severity not in {"info", "success", "warning", "error"}: - return "severity must be one of: info, success, warning, error.", False - - requests = [ - NotificationRequest( - destination=name, - title=title, - message=message, - severity=severity, - metadata={ - "session_id": session.session_id, - "model": session.config.model_name, - }, - ) - for name in destinations - ] - results 
= await session.notification_gateway.send_many(requests) - - lines = [] - all_ok = True - for result in results: - if result.ok: - lines.append(f"{result.destination}: sent") - else: - all_ok = False - lines.append(f"{result.destination}: failed ({result.error})") - return "\n".join(lines), all_ok diff --git a/agent/tools/papers_tool.py b/agent/tools/papers_tool.py index dea63d7d327999303e76c7e3e155d90107a2fd4f..f6c52ae8a9c5f0e6edb4c3182bb9eea5cf3d23a7 100644 --- a/agent/tools/papers_tool.py +++ b/agent/tools/papers_tool.py @@ -2,14 +2,11 @@ HF Papers Tool β€” Discover papers, read their contents, and find linked resources. Operations: trending, search, paper_details, read_paper, - find_datasets, find_models, find_collections, find_all_resources, - citation_graph, snippet_search, recommend + find_datasets, find_models, find_collections, find_all_resources """ import asyncio -import os import re -import time from typing import Any import httpx @@ -33,105 +30,6 @@ SORT_MAP = { "trending": "trendingScore", } -# --------------------------------------------------------------------------- -# Semantic Scholar API -# --------------------------------------------------------------------------- - -S2_API = "https://api.semanticscholar.org" -S2_API_KEY = os.environ.get("S2_API_KEY") -S2_HEADERS: dict[str, str] = {"x-api-key": S2_API_KEY} if S2_API_KEY else {} -S2_TIMEOUT = 12 -_s2_last_request: float = 0.0 - -# Shared response cache (survives across sessions, keyed by (path, params_tuple)) -_s2_cache: dict[str, Any] = {} -_S2_CACHE_MAX = 500 - - -def _s2_paper_id(arxiv_id: str) -> str: - """Convert bare arxiv ID to S2 format.""" - return f"ARXIV:{arxiv_id}" - - -def _s2_cache_key(path: str, params: dict | None) -> str: - """Build a hashable cache key from path + sorted params.""" - p = tuple(sorted((params or {}).items())) - return f"{path}:{p}" - - -async def _s2_request( - client: httpx.AsyncClient, - method: str, - path: str, - **kwargs: Any, -) -> httpx.Response | None: - """S2 request with 2 retries on 429/5xx. 
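The removed _s2_cache_key leans on one small invariant: sorting the param items makes the key independent of dict insertion order, so logically identical queries share a cache entry. A self-contained re-sketch with a quick check:

def s2_cache_key(path: str, params: dict | None) -> str:
    # Sorted items -> the same key no matter how the caller built the dict.
    return f"{path}:{tuple(sorted((params or {}).items()))}"

a = s2_cache_key("/graph/v1/paper/X", {"fields": "title", "limit": 10})
b = s2_cache_key("/graph/v1/paper/X", {"limit": 10, "fields": "title"})
assert a == b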
Rate-limited only when using API key.""" - global _s2_last_request - url = f"{S2_API}{path}" - kwargs.setdefault("headers", {}).update(S2_HEADERS) - kwargs.setdefault("timeout", S2_TIMEOUT) - - for attempt in range(3): - # Rate limit only when authenticated (1 req/s for search, 10 req/s for others) - if S2_API_KEY: - min_interval = 1.0 if "search" in path else 0.1 - elapsed = time.monotonic() - _s2_last_request - if elapsed < min_interval: - await asyncio.sleep(min_interval - elapsed) - _s2_last_request = time.monotonic() - - try: - resp = await client.request(method, url, **kwargs) - if resp.status_code == 429: - if attempt < 2: - await asyncio.sleep(60) - continue - return None - if resp.status_code >= 500: - if attempt < 2: - await asyncio.sleep(3) - continue - return None - return resp - except (httpx.RequestError, httpx.HTTPStatusError): - if attempt < 2: - await asyncio.sleep(3) - continue - return None - return None - - -async def _s2_get_json( - client: httpx.AsyncClient, - path: str, - params: dict | None = None, -) -> dict | None: - """Cached S2 GET returning parsed JSON or None.""" - key = _s2_cache_key(path, params) - if key in _s2_cache: - return _s2_cache[key] - - resp = await _s2_request(client, "GET", path, params=params or {}) - if resp and resp.status_code == 200: - data = resp.json() - if len(_s2_cache) < _S2_CACHE_MAX: - _s2_cache[key] = data - return data - return None - - -async def _s2_get_paper( - client: httpx.AsyncClient, - arxiv_id: str, - fields: str, -) -> dict | None: - """Fetch a single paper from S2 by arxiv ID. Returns None on failure.""" - return await _s2_get_json( - client, - f"/graph/v1/paper/{_s2_paper_id(arxiv_id)}", - {"fields": fields}, - ) - # --------------------------------------------------------------------------- # HTML paper parsing @@ -295,7 +193,7 @@ def _format_paper_list( return "\n".join(lines) -def _format_paper_detail(paper: dict, s2_data: dict | None = None) -> str: +def _format_paper_detail(paper: dict) -> str: arxiv_id = paper.get("id", "") title = paper.get("title", "Unknown") upvotes = paper.get("upvotes", 0) @@ -307,12 +205,7 @@ def _format_paper_detail(paper: dict, s2_data: dict | None = None) -> str: authors = paper.get("authors") or [] lines = [f"# {title}"] - meta_parts = [f"**arxiv_id:** {arxiv_id}", f"**upvotes:** {upvotes}"] - if s2_data: - cites = s2_data.get("citationCount", 0) - influential = s2_data.get("influentialCitationCount", 0) - meta_parts.append(f"**citations:** {cites} ({influential} influential)") - lines.append(" | ".join(meta_parts)) + lines.append(f"**arxiv_id:** {arxiv_id} | **upvotes:** {upvotes}") lines.append(f"https://huggingface.co/papers/{arxiv_id}") lines.append(f"https://arxiv.org/abs/{arxiv_id}") @@ -325,29 +218,16 @@ def _format_paper_detail(paper: dict, s2_data: dict | None = None) -> str: if keywords: lines.append(f"**Keywords:** {', '.join(keywords)}") - if s2_data and s2_data.get("s2FieldsOfStudy"): - fields = [ - f["category"] for f in s2_data["s2FieldsOfStudy"] if f.get("category") - ] - if fields: - lines.append(f"**Fields:** {', '.join(fields)}") - if s2_data and s2_data.get("venue"): - lines.append(f"**Venue:** {s2_data['venue']}") if github: lines.append(f"**GitHub:** {github} ({stars} stars)") - if s2_data and s2_data.get("tldr"): - tldr_text = s2_data["tldr"].get("text", "") - if tldr_text: - lines.append(f"\n## TL;DR\n{tldr_text}") if ai_summary: lines.append(f"\n## AI Summary\n{ai_summary}") if summary: lines.append(f"\n## Abstract\n{_truncate(summary, 500)}") lines.append( - 
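The retry loop in _s2_request above folds three behaviours together: bounded attempts, a long back-off on 429 rate limits, and a short one on 5xx or transport errors, returning None instead of raising. Distilled to its skeleton (a sketch of the pattern, not the exact control flow):

import asyncio

import httpx

async def get_with_retry(client: httpx.AsyncClient, url: str, **kw) -> httpx.Response | None:
    for attempt in range(3):
        try:
            resp = await client.get(url, **kw)
        except httpx.RequestError:
            resp = None  # transport error: treat like a transient 5xx
        if resp is not None and resp.status_code == 429:
            if attempt < 2:
                await asyncio.sleep(60)  # rate limited: wait out the window
                continue
            return None
        if resp is None or resp.status_code >= 500:
            if attempt < 2:
                await asyncio.sleep(3)  # transient failure: brief pause
                continue
            return None
        return resp
    return None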
"\n**Next:** Use read_paper to read specific sections, find_all_resources for linked datasets/models, " - "or citation_graph to trace references and citations." + "\n**Next:** Use read_paper to read specific sections, or find_all_resources to discover linked datasets/models." ) return "\n".join(lines) @@ -399,9 +279,7 @@ def _format_datasets(datasets: list, arxiv_id: str, sort: str) -> str: ds_id = ds.get("id", "unknown") downloads = ds.get("downloads", 0) likes = ds.get("likes", 0) - desc = _truncate( - _clean_description(ds.get("description") or ""), MAX_SUMMARY_LEN - ) + desc = _truncate(_clean_description(ds.get("description") or ""), MAX_SUMMARY_LEN) tags = ds.get("tags") or [] interesting = [t for t in tags if not t.startswith(("arxiv:", "region:"))][:5] @@ -563,112 +441,11 @@ async def _op_trending(args: dict[str, Any], limit: int) -> ToolResult: } -def _format_s2_paper_list(papers: list[dict], title: str) -> str: - """Format a list of S2 paper results.""" - lines = [f"# {title}"] - lines.append(f"Showing {len(papers)} result(s)\n") - - for i, paper in enumerate(papers, 1): - ptitle = paper.get("title") or "(untitled)" - year = paper.get("year") or "?" - cites = paper.get("citationCount", 0) - venue = paper.get("venue") or "" - ext_ids = paper.get("externalIds") or {} - aid = ext_ids.get("ArXiv", "") - tldr = (paper.get("tldr") or {}).get("text", "") - - lines.append(f"### {i}. {ptitle}") - meta = [f"Year: {year}", f"Citations: {cites}"] - if venue: - meta.append(f"Venue: {venue}") - if aid: - meta.append(f"arxiv_id: {aid}") - lines.append(" | ".join(meta)) - if aid: - lines.append(f"https://arxiv.org/abs/{aid}") - if tldr: - lines.append(f"**TL;DR:** {tldr}") - lines.append("") - - lines.append( - "Use paper_details with arxiv_id for full info, or read_paper to read sections." - ) - return "\n".join(lines) - - -async def _s2_bulk_search( - query: str, args: dict[str, Any], limit: int -) -> ToolResult | None: - """Search via S2 bulk endpoint with filters. 
Returns None on failure.""" - params: dict[str, Any] = { - "query": query, - "limit": limit, - "fields": "title,externalIds,year,citationCount,tldr,venue,publicationDate", - } - - # Date filter - date_from = args.get("date_from", "") - date_to = args.get("date_to", "") - if date_from or date_to: - params["publicationDateOrYear"] = f"{date_from}:{date_to}" - - # Fields of study - categories = args.get("categories") - if categories: - params["fieldsOfStudy"] = categories - - # Min citations - min_cites = args.get("min_citations") - if min_cites: - params["minCitationCount"] = str(min_cites) - - # Sort - sort_by = args.get("sort_by") - if sort_by and sort_by != "relevance": - params["sort"] = f"{sort_by}:desc" - - async with httpx.AsyncClient(timeout=15) as client: - resp = await _s2_request( - client, "GET", "/graph/v1/paper/search/bulk", params=params - ) - if not resp or resp.status_code != 200: - return None - data = resp.json() - - papers = data.get("data") or [] - if not papers: - return { - "formatted": f"No papers found for '{query}' with the given filters.", - "totalResults": 0, - "resultsShared": 0, - } - - formatted = _format_s2_paper_list( - papers[:limit], f"Papers matching '{query}' (Semantic Scholar)" - ) - return { - "formatted": formatted, - "totalResults": data.get("total", len(papers)), - "resultsShared": min(limit, len(papers)), - } - - async def _op_search(args: dict[str, Any], limit: int) -> ToolResult: query = args.get("query") if not query: return _error("'query' is required for search operation.") - # Route to S2 when filters are present - use_s2 = any( - args.get(k) - for k in ("date_from", "date_to", "categories", "min_citations", "sort_by") - ) - if use_s2: - result = await _s2_bulk_search(query, args, limit) - if result is not None: - return result - # Fall back to HF search (without filters) if S2 fails - async with httpx.AsyncClient(timeout=15) as client: resp = await client.get( f"{HF_API}/papers/search", params={"q": query, "limit": limit} @@ -768,116 +545,6 @@ async def _op_read_paper(args: dict[str, Any], limit: int) -> ToolResult: return {"formatted": formatted, "totalResults": 1, "resultsShared": 1} -# --------------------------------------------------------------------------- -# Citation graph (Semantic Scholar) -# --------------------------------------------------------------------------- - - -def _format_citation_entry(entry: dict, show_context: bool = False) -> str: - """Format a single citation/reference entry.""" - paper = entry.get("citingPaper") or entry.get("citedPaper") or {} - title = paper.get("title") or "(untitled)" - year = paper.get("year") or "?" 
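The removed _s2_bulk_search translates tool arguments into Semantic Scholar query parameters; the mapping is easiest to see with concrete, illustrative values:

args = {"date_from": "2024-01-01", "date_to": "", "min_citations": 50, "sort_by": "citationCount"}
params: dict[str, object] = {"query": "preference optimization", "limit": 10}

if args.get("date_from") or args.get("date_to"):
    # Open-ended ranges are allowed: "2024-01-01:" means "since 2024-01-01".
    params["publicationDateOrYear"] = f"{args.get('date_from', '')}:{args.get('date_to', '')}"
if args.get("min_citations"):
    params["minCitationCount"] = str(args["min_citations"])
if args.get("sort_by") and args["sort_by"] != "relevance":
    params["sort"] = f"{args['sort_by']}:desc"

# params -> {'query': 'preference optimization', 'limit': 10,
#            'publicationDateOrYear': '2024-01-01:',
#            'minCitationCount': '50', 'sort': 'citationCount:desc'}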
- cites = paper.get("citationCount", 0) - ext_ids = paper.get("externalIds") or {} - aid = ext_ids.get("ArXiv", "") - influential = " **[influential]**" if entry.get("isInfluential") else "" - - parts = [f"- **{title}** ({year}, {cites} cites){influential}"] - if aid: - parts[0] += f" arxiv:{aid}" - - if show_context: - intents = entry.get("intents") or [] - if intents: - parts.append(f" Intent: {', '.join(intents)}") - contexts = entry.get("contexts") or [] - for ctx in contexts[:2]: - if ctx: - parts.append(f" > {_truncate(ctx, 200)}") - - return "\n".join(parts) - - -def _format_citation_graph( - arxiv_id: str, - references: list[dict] | None, - citations: list[dict] | None, -) -> str: - lines = [f"# Citation Graph for {arxiv_id}"] - lines.append(f"https://arxiv.org/abs/{arxiv_id}\n") - - if references is not None: - lines.append(f"## References ({len(references)})") - if references: - for entry in references: - lines.append(_format_citation_entry(entry)) - else: - lines.append("No references found.") - lines.append("") - - if citations is not None: - lines.append(f"## Citations ({len(citations)})") - if citations: - for entry in citations: - lines.append(_format_citation_entry(entry, show_context=True)) - else: - lines.append("No citations found.") - lines.append("") - - lines.append( - "**Tip:** Use paper_details with an arxiv_id from above to explore further." - ) - return "\n".join(lines) - - -async def _op_citation_graph(args: dict[str, Any], limit: int) -> ToolResult: - arxiv_id = _validate_arxiv_id(args) - if not arxiv_id: - return _error("'arxiv_id' is required for citation_graph.") - - direction = args.get("direction", "both") - s2_id = _s2_paper_id(arxiv_id) - fields = "title,externalIds,year,citationCount,influentialCitationCount,contexts,intents,isInfluential" - params = {"fields": fields, "limit": limit} - - async with httpx.AsyncClient(timeout=15) as client: - refs, cites = None, None - coros = [] - if direction in ("references", "both"): - coros.append( - _s2_get_json(client, f"/graph/v1/paper/{s2_id}/references", params) - ) - if direction in ("citations", "both"): - coros.append( - _s2_get_json(client, f"/graph/v1/paper/{s2_id}/citations", params) - ) - - results = await asyncio.gather(*coros, return_exceptions=True) - idx = 0 - if direction in ("references", "both"): - r = results[idx] - if isinstance(r, dict): - refs = r.get("data", []) - idx += 1 - if direction in ("citations", "both"): - r = results[idx] - if isinstance(r, dict): - cites = r.get("data", []) - - if refs is None and cites is None: - return _error( - f"Could not fetch citation data for {arxiv_id}. Paper may not be indexed by Semantic Scholar." 
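_op_citation_graph above uses asyncio.gather(..., return_exceptions=True) so that one failing side cannot sink the other. Isolated from the index bookkeeping, the pattern looks like this:

import asyncio

async def fetch_citation_sides(get_refs, get_cites):
    # get_refs / get_cites: zero-arg coroutine factories returning a dict
    # payload or raising. Exceptions become values, not crashes.
    refs_raw, cites_raw = await asyncio.gather(
        get_refs(), get_cites(), return_exceptions=True
    )
    refs = refs_raw.get("data", []) if isinstance(refs_raw, dict) else None
    cites = cites_raw.get("data", []) if isinstance(cites_raw, dict) else None
    return refs, cites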
- ) - - total = (len(refs) if refs else 0) + (len(cites) if cites else 0) - return { - "formatted": _format_citation_graph(arxiv_id, refs, cites), - "totalResults": total, - "resultsShared": total, - } - - async def _op_find_datasets(args: dict[str, Any], limit: int) -> ToolResult: arxiv_id = _validate_arxiv_id(args) if not arxiv_id: @@ -1036,154 +703,6 @@ async def _op_find_all_resources(args: dict[str, Any], limit: int) -> ToolResult return {"formatted": formatted, "totalResults": total, "resultsShared": total} -# --------------------------------------------------------------------------- -# Snippet search (Semantic Scholar) -# --------------------------------------------------------------------------- - - -def _format_snippets(snippets: list[dict], query: str) -> str: - lines = [f"# Snippet Search: '{query}'"] - lines.append(f"Found {len(snippets)} matching passage(s)\n") - - for i, item in enumerate(snippets, 1): - paper = item.get("paper") or {} - ptitle = paper.get("title") or "(untitled)" - year = paper.get("year") or "?" - cites = paper.get("citationCount", 0) - ext_ids = paper.get("externalIds") or {} - aid = ext_ids.get("ArXiv", "") - - snippet = item.get("snippet") or {} - text = snippet.get("text", "") - section = snippet.get("section") or "" - - lines.append(f"### {i}. {ptitle} ({year}, {cites} cites)") - if aid: - lines.append(f"arxiv:{aid}") - if section: - lines.append(f"Section: {section}") - if text: - lines.append(f"> {_truncate(text, 400)}") - lines.append("") - - lines.append( - "Use paper_details or read_paper with arxiv_id to explore a paper further." - ) - return "\n".join(lines) - - -async def _op_snippet_search(args: dict[str, Any], limit: int) -> ToolResult: - query = args.get("query") - if not query: - return _error("'query' is required for snippet_search.") - - params: dict[str, Any] = { - "query": query, - "limit": limit, - "fields": "title,externalIds,year,citationCount", - } - - # Optional filters (same as search) - date_from = args.get("date_from", "") - date_to = args.get("date_to", "") - if date_from or date_to: - params["publicationDateOrYear"] = f"{date_from}:{date_to}" - if args.get("categories"): - params["fieldsOfStudy"] = args["categories"] - if args.get("min_citations"): - params["minCitationCount"] = str(args["min_citations"]) - - async with httpx.AsyncClient(timeout=15) as client: - resp = await _s2_request( - client, "GET", "/graph/v1/snippet/search", params=params - ) - if not resp or resp.status_code != 200: - return _error("Snippet search failed. 
Semantic Scholar may be unavailable.") - data = resp.json() - - snippets = data.get("data") or [] - if not snippets: - return { - "formatted": f"No snippets found for '{query}'.", - "totalResults": 0, - "resultsShared": 0, - } - - return { - "formatted": _format_snippets(snippets, query), - "totalResults": len(snippets), - "resultsShared": len(snippets), - } - - -# --------------------------------------------------------------------------- -# Recommendations (Semantic Scholar) -# --------------------------------------------------------------------------- - - -async def _op_recommend(args: dict[str, Any], limit: int) -> ToolResult: - positive_ids = args.get("positive_ids") - arxiv_id = _validate_arxiv_id(args) - - if not arxiv_id and not positive_ids: - return _error("'arxiv_id' or 'positive_ids' is required for recommend.") - - fields = "title,externalIds,year,citationCount,tldr,venue" - - async with httpx.AsyncClient(timeout=15) as client: - if positive_ids and not arxiv_id: - # Multi-paper recommendations (POST, not cached) - pos = [ - _s2_paper_id(pid.strip()) - for pid in positive_ids.split(",") - if pid.strip() - ] - neg_raw = args.get("negative_ids", "") - neg = ( - [_s2_paper_id(pid.strip()) for pid in neg_raw.split(",") if pid.strip()] - if neg_raw - else [] - ) - resp = await _s2_request( - client, - "POST", - "/recommendations/v1/papers/", - json={"positivePaperIds": pos, "negativePaperIds": neg}, - params={"fields": fields, "limit": limit}, - ) - if not resp or resp.status_code != 200: - return _error( - "Recommendation request failed. Semantic Scholar may be unavailable." - ) - data = resp.json() - else: - # Single-paper recommendations (cached) - data = await _s2_get_json( - client, - f"/recommendations/v1/papers/forpaper/{_s2_paper_id(arxiv_id)}", - {"fields": fields, "limit": limit, "from": "recent"}, - ) - if not data: - return _error( - "Recommendation request failed. Semantic Scholar may be unavailable." - ) - - papers = data.get("recommendedPapers") or [] - if not papers: - return { - "formatted": "No recommendations found.", - "totalResults": 0, - "resultsShared": 0, - } - - title = f"Recommended papers based on {arxiv_id or positive_ids}" - return { - "formatted": _format_s2_paper_list(papers[:limit], title), - "totalResults": len(papers), - "resultsShared": min(limit, len(papers)), - } - - # --------------------------------------------------------------------------- # Operation dispatch # --------------------------------------------------------------------------- @@ -1193,9 +712,6 @@ _OPERATIONS = { "search": _op_search, "paper_details": _op_paper_details, "read_paper": _op_read_paper, - "citation_graph": _op_citation_graph, - "snippet_search": _op_snippet_search, - "recommend": _op_recommend, "find_datasets": _op_find_datasets, "find_models": _op_find_models, "find_collections": _op_find_collections, @@ -1210,25 +726,22 @@ _OPERATIONS = { HF_PAPERS_TOOL_SPEC = { "name": "hf_papers", "description": ( - "Discover ML research papers, analyze citations, search paper contents, and find linked resources.\n\n" - "Combines HuggingFace Hub, arXiv, and Semantic Scholar. 
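The _OPERATIONS table above is the whole routing layer. The module's actual entry point falls outside this hunk, so the function name here is assumed; a dispatcher consistent with the table would look roughly like this:

async def dispatch(operation: str, args: dict, limit: int):
    handler = _OPERATIONS.get(operation)
    if handler is None:
        valid = ", ".join(sorted(_OPERATIONS))
        return _error(f"Unknown operation: {operation}. Valid: {valid}")
    return await handler(args, limit)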
Use for exploring research areas, "
-        "finding datasets for a task, tracing citation chains, or implementing a paper's approach.\n\n"
-        "Typical flows:\n"
-        "  search → read_paper → find_all_resources → hf_inspect_dataset\n"
-        "  search → paper_details → citation_graph → read_paper (trace influence)\n"
-        "  snippet_search → paper_details → read_paper (find specific claims)\n\n"
+        "Discover ML research papers, find their linked resources (datasets, models, collections), "
+        "and read paper contents on HuggingFace Hub and arXiv.\n\n"
+        "Use this when exploring a research area, looking for datasets for a task, "
+        "implementing a paper's approach, or trying to improve performance on something. "
+        "Typical flow:\n"
+        "  hf_papers(search/trending) → hf_papers(read_paper) → hf_papers(find_all_resources) → hf_inspect_dataset\n\n"
         "Operations:\n"
         "- trending: Get trending daily papers, optionally filter by topic keyword\n"
-        "- search: Search papers. Uses HF by default (ML-tuned). Add date_from/min_citations/categories to use Semantic Scholar with filters\n"
-        "- paper_details: Metadata, abstract, AI summary, github link\n"
-        "- read_paper: Read paper contents — without section: abstract + TOC; with section: full text\n"
-        "- citation_graph: Get references and citations for a paper with influence flags and citation intents\n"
-        "- snippet_search: Semantic search over full-text passages from 12M+ papers\n"
-        "- recommend: Find similar papers (single paper or positive/negative examples)\n"
+        "- search: Full-text search for papers by query\n"
+        "- paper_details: Get metadata, abstract, AI summary, and github link for a paper\n"
+        "- read_paper: Read paper contents — without section: returns abstract + table of contents; "
+        "with section: returns full section text\n"
         "- find_datasets: Find datasets linked to a paper\n"
         "- find_models: Find models linked to a paper\n"
         "- find_collections: Find collections that include a paper\n"
-        "- find_all_resources: Parallel fetch of datasets + models + collections for a paper"
+        "- find_all_resources: Parallel fetch of datasets + models + collections for a paper (unified view)"
     ),
     "parameters": {
         "type": "object",
         "properties": {
@@ -1241,69 +754,36 @@
             "query": {
                 "type": "string",
                 "description": (
-                    "Search query. Required for: search, snippet_search. "
-                    "Optional for: trending (filters by keyword). "
-                    "Supports boolean syntax for Semantic Scholar: '\"exact phrase\" term1 | term2'."
+                    "Search query. Required for: search. "
+                    "Optional for: trending (filters results by keyword match on title, summary, and AI-generated keywords)."
                 ),
             },
             "arxiv_id": {
                 "type": "string",
                 "description": (
                     "ArXiv paper ID (e.g. '2305.18290'). "
-                    "Required for: paper_details, read_paper, citation_graph, find_datasets, find_models, find_collections, find_all_resources. "
-                    "Optional for: recommend (single-paper recs). Get IDs from search results first."
+                    "Required for: paper_details, read_paper, find_datasets, find_models, find_collections, find_all_resources. "
+                    "Get IDs from trending or search results first."
                 ),
             },
             "section": {
                 "type": "string",
                 "description": (
                     "Section name or number to read (e.g. '3', 'Experiments', '4.2'). "
-                    "Optional for: read_paper. Without this, returns abstract + TOC."
+                    "Optional for: read_paper. Without this, read_paper returns the abstract + table of contents "
+                    "so you can choose which section to read." 
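Two calls the post-change spec accepts, reusing the arxiv id and section name from the spec's own examples:

read_call = {"operation": "read_paper", "arxiv_id": "2305.18290", "section": "Experiments"}
links_call = {"operation": "find_all_resources", "arxiv_id": "2305.18290"}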
), }, - "direction": { - "type": "string", - "enum": ["citations", "references", "both"], - "description": "Direction for citation_graph. Default: both.", - }, "date": { "type": "string", "description": "Date in YYYY-MM-DD format. Optional for: trending (defaults to recent papers).", }, - "date_from": { - "type": "string", - "description": "Start date (YYYY-MM-DD). Triggers Semantic Scholar search. For: search, snippet_search.", - }, - "date_to": { - "type": "string", - "description": "End date (YYYY-MM-DD). Triggers Semantic Scholar search. For: search, snippet_search.", - }, - "categories": { - "type": "string", - "description": "Field of study filter (e.g. 'Computer Science'). Triggers Semantic Scholar search.", - }, - "min_citations": { - "type": "integer", - "description": "Minimum citation count filter. Triggers Semantic Scholar search.", - }, - "sort_by": { - "type": "string", - "enum": ["relevance", "citationCount", "publicationDate"], - "description": "Sort order for Semantic Scholar search. Default: relevance.", - }, - "positive_ids": { - "type": "string", - "description": "Comma-separated arxiv IDs for multi-paper recommendations. For: recommend.", - }, - "negative_ids": { - "type": "string", - "description": "Comma-separated arxiv IDs as negative examples. For: recommend.", - }, "sort": { "type": "string", "enum": ["downloads", "likes", "trending"], "description": ( - "Sort order for find_datasets and find_models. Default: downloads." + "Sort order for find_datasets and find_models. Default: downloads. " + "Use 'downloads' for most-used, 'likes' for community favorites, 'trending' for recently popular." ), }, "limit": { diff --git a/agent/tools/research_tool.py b/agent/tools/research_tool.py index f5815be8332ef371d3e863652bfc6cdd5127bbc2..17af9d60321055dc7dc766b46c4b728446e07c70 100644 --- a/agent/tools/research_tool.py +++ b/agent/tools/research_tool.py @@ -9,15 +9,12 @@ Inspired by claude-code's code-explorer agent pattern. import json import logging -import time +import os from typing import Any from litellm import Message, acompletion -from agent.core import telemetry from agent.core.doom_loop import check_for_doom_loop -from agent.core.llm_params import _resolve_llm_params -from agent.core.prompt_caching import with_prompt_caching from agent.core.session import Event logger = logging.getLogger(__name__) @@ -39,56 +36,47 @@ RESEARCH_TOOL_NAMES = { "github_find_examples", "github_list_repos", "github_read_file", - "web_search", "hf_inspect_dataset", "hf_repo_files", } RESEARCH_SYSTEM_PROMPT = """\ You are a research sub-agent for an ML engineering assistant. -Your primary job: mine the literature to find the best training recipes β€” -then back them up with working code and up to date documantation. The main agent will use +Your job: explore documentation, code examples, APIs, and repos, +then return a concise, actionable summary. The main agent will use your findings to implement the actual solution. -# Start from the literature +# Being up to date is critical -Your default approach is a deep literature crawl. Do not start from docs or -example scripts β€” start from papers. Papers contain the results, and results -tell you what actually works. +Always prioritize finding the most current, state-of-the-art approaches. +ML moves fast β€” a method from 6 months ago may already be obsolete. 
-## The crawl +- Search for **recent papers** (use `hf_papers`) to find SOTA methods, models, and datasets for the task +- Compare what you find in docs/examples against what recent papers recommend β€” prefer the newer approach +- When multiple approaches exist, identify which is SOTA and why (benchmark results, adoption, recency) +- Include in your findings: what is the current best model, dataset, and method for the task -1. **Find anchor papers**: Search for the task/domain. Identify the landmark paper(s) β€” high citations, recent, or both. -2. **Crawl the citation graph**: Use `citation_graph` on the anchor paper(s). Look DOWNSTREAM (papers that cite it) β€” these are the ones that built on it, improved it, or applied it to new domains. Prioritize recent papers and papers with many citations. -3. **Read methodology sections**: For the most promising papers (strong results, recent, relevant), use `read_paper` with section parameter to read sections 3, 4, 5 (Methodology, Experiments, Results β€” not the abstract). Extract: - - The exact dataset(s) used (name, source, size, any filtering/preprocessing) - - The training method and configuration (optimizer, lr, schedule, epochs, batch size) - - The results those choices produced (benchmark scores, metrics, comparisons) -4. **Attribute results to recipes**: This is the critical step. Every finding must link a RESULT to the RECIPE that produced it. "Dataset X + method Y + lr Z β†’ score W on benchmark V" is useful. "They used SFT" is not. -5. **Validate datasets**: For the most promising datasets, check if they exist on HF Hub with `hf_inspect_dataset`. Verify format matches the training method. Report if doesnt. -6. **Find code**: Now find working implementation code via `github_find_examples` and `github_read_file`. Use docs (`explore_hf_docs`, `fetch_hf_docs`) to fill in API details. +# Research methodology -## When to go deeper - -- If the anchor paper is old (>1 year), its citation graph is your main source β€” the downstream papers will have better methods. -- If a downstream paper reports significantly better results, crawl ITS citation graph too. -- Use `snippet_search` to find specific claims across papers (e.g., "does dataset X consistently outperform Y for this task?"). -- Use `recommend` to find related papers the citation graph might miss. +1. **Discovery**: Find relevant entry points β€” example scripts, doc pages, API endpoints, **and recent papers for SOTA approaches** +2. **Tracing**: Follow the chain from entry point to implementation detail +3. **Analysis**: Identify patterns, current API usage, key dependencies. **Compare against SOTA from recent papers** +4. **Synthesis**: Summarize findings in a structured format, highlighting what is current best practice vs. 
outdated # How to use your tools -## Papers & citations (USE FIRST) -- `hf_papers(operation="search", query=...)`: Search papers (HF-tuned for ML) -- `hf_papers(operation="search", query=..., min_citations=50, sort_by="citationCount")`: Find highly-cited papers via Semantic Scholar -- `hf_papers(operation="search", query=..., date_from="2024-01-01")`: Search with date filter -- `hf_papers(operation="paper_details", arxiv_id=...)`: Metadata, citations, TL;DR -- `hf_papers(operation="citation_graph", arxiv_id=...)`: References + citations with influence flags and intents -- `hf_papers(operation="read_paper", arxiv_id=..., section="3")`: Read a specific section's full text -- `hf_papers(operation="read_paper", arxiv_id=...)`: Get TOC (abstract + section list) β€” use this to find which section numbers contain methodology/experiments -- `hf_papers(operation="snippet_search", query=...)`: Semantic search across 12M+ full-text paper passages -- `hf_papers(operation="recommend", arxiv_id=...)`: Find related papers -- `hf_papers(operation="find_datasets", arxiv_id=...)`: Find HF datasets linked to a paper -- `hf_papers(operation="find_all_resources", arxiv_id=...)`: Datasets + models + collections for a paper +## GitHub code research (USE FIRST for any ML implementation task) +- `github_find_examples`: Find working example scripts in HF repos (trl, transformers, etc.) + Example: `github_find_examples({"repo": "trl", "keyword": "sft"})` + Returns: file paths in examples/, scripts/, notebooks/ directories +- `github_read_file`: Read the actual implementation code + Example: `github_read_file({"repo": "huggingface/trl", "path": "examples/scripts/sft.py"})` + Use line_start/line_end for large files + +## Documentation +- `explore_hf_docs(endpoint)`: Search docs for a library. Endpoints: trl, transformers, datasets, peft, accelerate, trackio, vllm, inference-endpoints, etc. +- `fetch_hf_docs(url)`: Fetch full page content from explore results +- `find_hf_api(query=..., tag=...)`: Find REST API endpoints ## Dataset inspection - `hf_inspect_dataset`: Check dataset schema, splits, sample rows @@ -97,77 +85,38 @@ tell you what actually works. - DPO: needs "prompt", "chosen", "rejected" - GRPO: needs "prompt" only -## GitHub code research -- `github_find_examples`: Find working example scripts in HF repos (trl, transformers, etc.) -- `github_read_file`: Read the actual implementation code. Use line_start/line_end for large files. - -## Documentation -- `explore_hf_docs(endpoint)`: Search docs for a library. Endpoints: trl, transformers, datasets, peft, accelerate, trackio, vllm, inference-endpoints, etc. -- `fetch_hf_docs(url)`: Fetch full page content from explore results -- `find_hf_api(query=..., tag=...)`: Find REST API endpoints -- `web_search(query=..., allowed_domains=[...], blocked_domains=[...])`: - Search the current web when papers/docs/GitHub are not enough. +## Papers +- `hf_papers`: Search papers, get details, find linked datasets/models ## Hub repo inspection - `hf_repo_files`: List/read files in any HF repo (model, dataset, space) -# Correct research pattern +# Correct research pattern for ML tasks ``` -# 1. Find anchor paper(s) for the task -hf_papers({"operation": "search", "query": "GPQA graduate questions", "sort_by": "citationCount"}) - -# 2. Crawl citation graph β€” look downstream -hf_papers({"operation": "citation_graph", "arxiv_id": "2311.12022", "direction": "citations"}) - -# 3. 
Read methodology of promising downstream papers -hf_papers({"operation": "read_paper", "arxiv_id": "2604.01348"}) # TOC first -hf_papers({"operation": "read_paper", "arxiv_id": "2604.01348", "section": "3"}) # Methodology -hf_papers({"operation": "read_paper", "arxiv_id": "2604.01348", "section": "4"}) # Experiments - -# 4. Find datasets used by these papers -hf_papers({"operation": "find_datasets", "arxiv_id": "2604.01348"}) -hf_papers({"operation": "find_all_resources", "arxiv_id": "2604.01348"}) - -# 5. Validate datasets exist and have correct format -hf_inspect_dataset({"dataset": "org/dataset-name", "split": "train", "sample_rows": 3}) - -# 6. Now get working code for the training method +# 1. Find working example code FIRST github_find_examples({"repo": "trl", "keyword": "sft"}) + +# 2. Read the implementation github_read_file({"repo": "huggingface/trl", "path": "examples/scripts/sft.py"}) + +# 3. Check docs for parameters/config details explore_hf_docs("trl") +fetch_hf_docs("https://huggingface.co/docs/trl/sft_trainer") + +# 4. Validate dataset format if relevant +hf_inspect_dataset({"dataset": "org/name", "split": "train", "sample_rows": 3}) ``` # Output format - - -Your output MUST be structured as a ranked list of training recipes, each attributed to published results: - -## Recipe table (REQUIRED) -For each promising approach found, report: -- **Paper**: title, arxiv_id, date, venue -- **Result**: exact benchmark scores and what they were measured on -- **Dataset(s)**: name, size, source, HF Hub availability, format verified (yes/no) -- **Method**: training approach, key hyperparameters (lr, epochs, batch size, optimizer, schedule) -- **What made it work**: the specific insight or trick that drove the result (data curation, curriculum, loss function, etc.) - -Rank recipes by result quality. The main agent will pick the best one that's feasible. - -## Code patterns -- Key imports, configurations, and usage patterns from working examples -- Specific file paths, URLs, function names from docs - -## Recommendations -- Which recipe to implement first and why -- What datasets to use (with HF Hub paths, verified) -- Any gaps: datasets that need preprocessing, methods that need adaptation - -Additionally include: +Your output MUST include: - **SOTA landscape**: Current best models, datasets, and methods for the task (from recent papers). Flag anything outdated. +- **Key findings**: The most important things you discovered (current API usage, working patterns) - **Essential references**: Specific file paths, URLs, function names, doc sections, code snippets that the main agent should use directly - **Code patterns**: Key imports, configurations, and usage patterns from working examples +- **Recommendations**: What to do next based on your findings, preferring SOTA approaches Be concise. Your output goes into another agent's context β€” every token counts. Aim for 500-1500 words max. 
Include actual code snippets from examples you read, @@ -219,18 +168,34 @@ RESEARCH_TOOL_SPEC = { } +def _resolve_llm_params(model_name: str) -> dict: + """Build LiteLLM kwargs, reusing the HF router logic from agent_loop.""" + if not model_name.startswith("huggingface/"): + return {"model": model_name} + + parts = model_name.split("/", 2) # ["huggingface", "", "/"] + if len(parts) < 3: + return {"model": model_name} + + provider = parts[1] + model_id = parts[2] + return { + "model": f"openai/{model_id}", + "api_base": f"https://router.huggingface.co/{provider}/v3/openai", + "api_key": os.environ.get("INFERENCE_TOKEN", ""), + } + + def _get_research_model(main_model: str) -> str: """Pick a cheaper model for research based on the main model.""" - if main_model.startswith("anthropic/"): + if "anthropic/" in main_model: return "anthropic/claude-sonnet-4-6" - if main_model.startswith("bedrock/") and "anthropic" in main_model: - return "bedrock/us.anthropic.claude-sonnet-4-6" # For non-Anthropic models (HF router etc.), use the same model return main_model async def research_handler( - arguments: dict[str, Any], session=None, tool_call_id: str | None = None, **_kw + arguments: dict[str, Any], session=None, **_kw ) -> tuple[str, bool]: """Execute a research sub-agent with its own context.""" task = arguments.get("task", "") @@ -254,17 +219,7 @@ async def research_handler( # Use a cheaper/faster model for research main_model = session.config.model_name research_model = _get_research_model(main_model) - # Research is a cheap sub-call β€” cap the main session's effort at "high" - # so a user preference of ``max``/``xhigh`` (valid for Opus 4.6/4.7) doesn't - # propagate to a Sonnet research model that may not accept those levels. - # We also haven't probed this sub-model so we don't know its ceiling. - _pref = getattr(session.config, "reasoning_effort", None) - _capped = "high" if _pref in ("max", "xhigh") else _pref - llm_params = _resolve_llm_params( - research_model, - getattr(session, "hf_token", None), - reasoning_effort=_capped, - ) + llm_params = _resolve_llm_params(research_model) # Get read-only tool specs from the session's tool router tool_specs = [ @@ -273,32 +228,11 @@ async def research_handler( if spec["function"]["name"] in RESEARCH_TOOL_NAMES ] - # Unique ID + short label so parallel agents show separate status lines. - # Use the tool_call_id when available β€” it's unique per invocation and lets - # the frontend match a research tool card to its agent state. Fall back to - # uuid for offline/test paths. Previously used md5(task), which collided - # when the same task string was researched in parallel. 
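To make the inlined router logic above concrete, here is a hedged illustration of its input/output mapping, assuming `_resolve_llm_params` as added in this diff is in scope; the model IDs are made-up examples, not values from this PR:

```python
# Hypothetical call (sketch): an HF router model name of the form
# "huggingface/<provider>/<org>/<model>" is split into provider + model id
# and rewritten as an OpenAI-compatible endpoint.
params = _resolve_llm_params("huggingface/together/meta-llama/Llama-3.1-8B")
# params == {
#     "model": "openai/meta-llama/Llama-3.1-8B",
#     "api_base": "https://router.huggingface.co/together/v3/openai",
#     "api_key": os.environ.get("INFERENCE_TOKEN", ""),
# }

# Non-HF model names pass through untouched:
assert _resolve_llm_params("anthropic/claude-sonnet-4-6") == {
    "model": "anthropic/claude-sonnet-4-6"
}
```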
- if tool_call_id: - _agent_id = tool_call_id - else: - import uuid - - _agent_id = uuid.uuid4().hex[:8] - _agent_label = "research: " + (task[:50] + "…" if len(task) > 50 else task) - async def _log(text: str) -> None: """Send a progress event to the UI so it doesn't look frozen.""" try: await session.send_event( - Event( - event_type="tool_log", - data={ - "tool": "research", - "log": text, - "agent_id": _agent_id, - "label": _agent_label, - }, - ) + Event(event_type="tool_log", data={"tool": "research", "log": text}) ) except Exception: pass @@ -315,10 +249,8 @@ async def research_handler( # ── Doom-loop detection ── doom_prompt = check_for_doom_loop(messages) if doom_prompt: - logger.warning( - "Research sub-agent repetition guard activated at iteration %d", - _iteration, - ) + logger.warning("Research sub-agent doom loop detected at iteration %d", _iteration) + await _log("Doom loop detected — injecting corrective prompt") messages.append(Message(role="user", content=doom_prompt)) # ── Context budget: warn at 75%, hard-stop at 95% ── if _total_tokens >= _RESEARCH_CONTEXT_MAX: logger.warning( "Research sub-agent hit context max (%d tokens) — forcing summary", _total_tokens, ) - await _log( - f"Context limit reached ({_total_tokens} tokens) — forcing wrap-up" - ) + await _log(f"Context limit reached ({_total_tokens} tokens) — forcing wrap-up") # Ask for a final summary with no tools - messages.append( - Message( - role="user", - content=( - "[SYSTEM: CONTEXT LIMIT REACHED] You have used all available context. " - "Summarize your findings NOW. Do NOT call any more tools." - ), - ) - ) + messages.append(Message( + role="user", + content=( + "[SYSTEM: CONTEXT LIMIT REACHED] You have used all available context. " + "Summarize your findings NOW. Do NOT call any more tools." + ), + )) try: - _msgs, _ = with_prompt_caching(messages, None, llm_params.get("model")) - _t0 = time.monotonic() response = await acompletion( - messages=_msgs, + messages=messages, tools=None, # no tools — force text response stream=False, timeout=120, **llm_params, ) - # Telemetry is best-effort; a logging blip must never mask a - # valid LLM response (the surrounding except would convert it - # to "summary call failed"). - try: - await telemetry.record_llm_call( - session, - model=research_model, - response=response, - latency_ms=int((time.monotonic() - _t0) * 1000), - finish_reason=response.choices[0].finish_reason - if response.choices - else None, - kind="research", - ) - except Exception as _telem_err: - logger.debug("research telemetry failed: %s", _telem_err) content = response.choices[0].message.content or "" - return ( - content or "Research context exhausted — no summary produced.", - bool(content), - ) + return content or "Research context exhausted — no summary produced.", bool(content) except Exception: return "Research context exhausted and summary call failed.", False if not _warned_context and _total_tokens >= _RESEARCH_CONTEXT_WARN: _warned_context = True await _log(f"Context at {_total_tokens} tokens — nudging to wrap up") - messages.append( - Message( - role="user", - content=( - "[SYSTEM: You have used 75% of your context budget. " - "Start wrapping up: finish any critical lookups, then " - "produce your final summary within the next 1-2 iterations.]" - ), - ) - ) + messages.append(Message( + role="user", + content=( + "[SYSTEM: You have used 75% of your context budget.
" + "Start wrapping up: finish any critical lookups, then " + "produce your final summary within the next 1-2 iterations.]" + ), + )) try: - _msgs, _tools = with_prompt_caching( - messages, tool_specs if tool_specs else None, llm_params.get("model") - ) - _t0 = time.monotonic() response = await acompletion( - messages=_msgs, - tools=_tools, + messages=messages, + tools=tool_specs if tool_specs else None, tool_choice="auto", stream=False, timeout=120, **llm_params, ) - try: - await telemetry.record_llm_call( - session, - model=research_model, - response=response, - latency_ms=int((time.monotonic() - _t0) * 1000), - finish_reason=response.choices[0].finish_reason - if response.choices - else None, - kind="research", - ) - except Exception as _telem_err: - logger.debug("research telemetry failed: %s", _telem_err) except Exception as e: logger.error("Research sub-agent LLM error: %s", e) return f"Research agent LLM error: {e}", False @@ -432,18 +320,8 @@ async def research_handler( content = msg.content or "Research completed but no summary generated." return content, True - # Execute tool calls and add results. - # Rebuild the assistant message with only the wire-safe fields — - # LiteLLM's raw Message carries `provider_specific_fields` and - # `reasoning_content`, which the HF router's OpenAI schema rejects - # if we echo them back in the next request. - messages.append( - Message( - role="assistant", - content=msg.content, - tool_calls=msg.tool_calls, - ) - ) + # Execute tool calls and add results + messages.append(msg) for tc in msg.tool_calls: try: tool_args = json.loads(tc.function.arguments) @@ -477,7 +355,7 @@ async def research_handler( await _log(f"▸ {tool_name} {args_str}") output, _success = await session.tool_router.call_tool( - tool_name, tool_args, session=session, tool_call_id=tc.id + tool_name, tool_args, session=session ) _tool_uses += 1 await _log(f"tools:{_tool_uses}") @@ -498,38 +376,21 @@ async def research_handler( # ── Iteration limit: try to salvage findings ── await _log("Iteration limit reached — extracting summary") - messages.append( - Message( - role="user", - content=( - "[SYSTEM: ITERATION LIMIT] You have reached the maximum number of research " - "iterations. Summarize ALL findings so far. Do NOT call any more tools." - ), - ) - ) + messages.append(Message( + role="user", + content=( + "[SYSTEM: ITERATION LIMIT] You have reached the maximum number of research " + "iterations. Summarize ALL findings so far. Do NOT call any more tools."
+ ), + )) try: - _msgs, _ = with_prompt_caching(messages, None, llm_params.get("model")) - _t0 = time.monotonic() response = await acompletion( - messages=_msgs, + messages=messages, tools=None, stream=False, timeout=120, **llm_params, ) - try: - await telemetry.record_llm_call( - session, - model=research_model, - response=response, - latency_ms=int((time.monotonic() - _t0) * 1000), - finish_reason=response.choices[0].finish_reason - if response.choices - else None, - kind="research", - ) - except Exception as _telem_err: - logger.debug("research telemetry failed: %s", _telem_err) content = response.choices[0].message.content or "" if content: return content, True diff --git a/agent/tools/sandbox_client.py b/agent/tools/sandbox_client.py index 24f85e1fb719549ddca4925490e44e9a275986af..7e3eaf715f9de8f7217aa42efe6bf1b0770585fa 100644 --- a/agent/tools/sandbox_client.py +++ b/agent/tools/sandbox_client.py @@ -13,7 +13,7 @@ Architecture: - Optionally deletes the Space when done Lifecycle: - sb = Sandbox.create(owner="burtenshaw") # duplicate private Space, wait, connect + sb = Sandbox.create(owner="burtenshaw") # duplicate, wait, connect sb = Sandbox.create(owner="burtenshaw", # with options hardware="t4-small", private=True, @@ -37,7 +37,6 @@ Tools: bash, read, write, edit, upload from __future__ import annotations import io -import secrets as secrets_lib import sys import time import uuid @@ -65,70 +64,6 @@ MAX_TIMEOUT = 1200 WAIT_TIMEOUT = 600 WAIT_INTERVAL = 5 API_WAIT_TIMEOUT = 180 -HARDWARE_REQUEST_TIMEOUT = 60 -CPU_BASIC_HARDWARE = "cpu-basic" - - -def _is_transient_space_visibility_error(error: Exception) -> bool: - """Return True when a newly duplicated Space is not queryable yet.""" - response = getattr(error, "response", None) - if getattr(response, "status_code", None) == 404: - return True - message = str(error) - return "Repository Not Found" in message or "404 Client Error" in message - - -def _is_transient_space_management_error(error: Exception) -> bool: - """Return True when a just-created private Space is not manageable yet.""" - response = getattr(error, "response", None) - if getattr(response, "status_code", None) in {401, 404}: - return True - message = str(error) - return ( - "Repository Not Found" in message - or "401 Client Error" in message - or "404 Client Error" in message - ) - - -def _request_space_hardware_with_retry( - api: HfApi, - space_id: str, - *, - hardware: str, - sleep_time: int | None, - log: Callable[[str], object], - check_cancel: Callable[[], object], -) -> None: - """Request hardware, retrying while Hub permissions propagate for a new Space.""" - deadline = time.time() + HARDWARE_REQUEST_TIMEOUT - attempt = 0 - while True: - check_cancel() - try: - api.request_space_hardware( - space_id, - hardware=hardware, - sleep_time=sleep_time, - ) - return - except Exception as e: - if not _is_transient_space_management_error(e): - raise - - remaining = deadline - time.time() - if remaining <= 0: - raise - - attempt += 1 - status_code = getattr(getattr(e, "response", None), "status_code", None) - status = f"HTTP {status_code}" if status_code else type(e).__name__ - log( - f" Hardware request not accepted yet ({status}); " - f"retrying ({attempt})..." 
- ) - time.sleep(min(WAIT_INTERVAL, remaining)) - _DOCKERFILE = """\ FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim @@ -164,8 +99,8 @@ CMD ["python", "sandbox_server.py"] _SANDBOX_SERVER = '''\ """Minimal FastAPI server for sandbox operations.""" -import hmac, os, subprocess, pathlib, signal, threading, re, tempfile -from fastapi import Depends, FastAPI, HTTPException, Request +import os, subprocess, pathlib, signal, threading, re, tempfile +from fastapi import FastAPI from pydantic import BaseModel from typing import Optional import uvicorn @@ -221,24 +156,6 @@ def _atomic_write(path: pathlib.Path, content: str): app = FastAPI() -def _bearer_token(header: str) -> str: - scheme, _, supplied = header.partition(" ") - if scheme.lower() != "bearer" or not supplied: - return "" - return supplied - -def _require_auth(request: Request) -> None: - sandbox_token = os.environ.get("SANDBOX_API_TOKEN") or "" - if not sandbox_token: - raise HTTPException(status_code=503, detail="Sandbox API token not configured") - supplied = _bearer_token(request.headers.get("x-sandbox-authorization", "")) - if not supplied: - raise HTTPException(status_code=401, detail="Missing bearer token") - if not hmac.compare_digest(supplied, sandbox_token): - raise HTTPException(status_code=401, detail="Invalid bearer token") - -_AUTH = [Depends(_require_auth)] - # Track active bash processes so they can be killed on cancel _active_procs = {} # pid -> subprocess.Popen _proc_lock = threading.Lock() @@ -427,7 +344,7 @@ def _validate_python(content, path=""): def health(): return {"status": "ok"} -@app.post("/api/bash", dependencies=_AUTH) +@app.post("/api/bash") def bash(req: BashReq): try: proc = subprocess.Popen( @@ -454,7 +371,7 @@ def bash(req: BashReq): except Exception as e: return {"success": False, "output": "", "error": str(e)} -@app.post("/api/kill", dependencies=_AUTH) +@app.post("/api/kill") def kill_all(): """Kill all active bash processes. 
Called when user cancels.""" with _proc_lock: @@ -472,7 +389,7 @@ def kill_all(): pass return {"success": True, "output": f"Killed {len(killed)} process(es): {killed}", "error": ""} -@app.post("/api/read", dependencies=_AUTH) +@app.post("/api/read") def read(req: ReadReq): try: p = pathlib.Path(req.path) @@ -489,7 +406,7 @@ def read(req: ReadReq): except Exception as e: return {"success": False, "output": "", "error": str(e)} -@app.post("/api/write", dependencies=_AUTH) +@app.post("/api/write") def write(req: WriteReq): try: p = pathlib.Path(req.path) @@ -503,7 +420,7 @@ def write(req: WriteReq): except Exception as e: return {"success": False, "output": "", "error": str(e)} -@app.post("/api/edit", dependencies=_AUTH) +@app.post("/api/edit") def edit(req: EditReq): try: p = pathlib.Path(req.path) @@ -530,7 +447,7 @@ def edit(req: EditReq): except Exception as e: return {"success": False, "output": "", "error": str(e)} -@app.post("/api/exists", dependencies=_AUTH) +@app.post("/api/exists") def exists(req: ExistsReq): return {"success": True, "output": str(pathlib.Path(req.path).exists()).lower(), "error": ""} @@ -565,7 +482,6 @@ class Sandbox: space_id: str token: str | None = None - api_token: str | None = field(default=None, repr=False) work_dir: str = "/app" timeout: int = DEFAULT_TIMEOUT _owns_space: bool = field(default=False, repr=False) @@ -581,26 +497,12 @@ class Sandbox: self._base_url = f"https://{slug}.hf.space/api/" self._client = httpx.Client( base_url=self._base_url, - headers=self._auth_headers(), + headers={"Authorization": f"Bearer {self.token}"} if self.token else {}, timeout=httpx.Timeout(MAX_TIMEOUT, connect=30), follow_redirects=True, ) self._hf_api = HfApi(token=self.token) - def _auth_headers(self) -> dict[str, str]: - """Return headers for private HF Space access plus sandbox API auth. - - Private Spaces require the HF token in ``Authorization`` at the Hub - edge. The sandbox server requires its control-plane token in the - dedicated ``X-Sandbox-Authorization`` header. - """ - headers: dict[str, str] = {} - if self.token: - headers["Authorization"] = f"Bearer {self.token}" - if self.api_token: - headers["X-Sandbox-Authorization"] = f"Bearer {self.api_token}" - return headers - # ── Lifecycle ───────────────────────────────────────────────── class Cancelled(Exception): @@ -613,11 +515,10 @@ class Sandbox: *, name: str | None = None, template: str = TEMPLATE_SPACE, - hardware: str = CPU_BASIC_HARDWARE, - private: bool = True, + hardware: str = "cpu-basic", + private: bool = False, sleep_time: int | None = None, token: str | None = None, - secrets: dict[str, str] | None = None, wait_timeout: int = WAIT_TIMEOUT, log: "Callable[[str], object] | None" = None, cancel_event: "Any | None" = None, @@ -634,7 +535,7 @@ class Sandbox: A unique suffix is always appended. template: Source Space to duplicate (default: burtenshaw/sandbox). hardware: Hardware tier (cpu-basic, t4-small, etc.). - private: Whether the Space should be private. Defaults to True. + private: Whether the Space should be private. sleep_time: Auto-sleep after N seconds of inactivity. token: HF API token (from user's OAuth session). wait_timeout: Max seconds to wait for Space to start (default: 300). 
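For orientation, a hedged usage sketch of the simplified `create()` signature documented above; since `private` now defaults to False, callers wanting a private Space must opt in explicitly. The owner name is a placeholder, and `bash()`/`delete()` follow the client's documented tool surface rather than being verified here:

```python
import os

# Sketch only: exercises the create() parameters shown in the docstring above.
sb = Sandbox.create(
    owner="my-user",          # placeholder HF namespace
    hardware="t4-small",      # GPU tier; cpu-basic remains the default
    private=True,             # no longer the default after this change
    sleep_time=2700,          # auto-sleep after 45 min idle
    token=os.environ.get("HF_TOKEN"),
)
try:
    # bash() is one of the client's operation tools; assumed to return a
    # ToolResult with an .output field.
    print(sb.bash("python -V").output)
finally:
    sb.delete()               # removes the backing Space when owned
```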
@@ -661,7 +562,6 @@ class Sandbox: base = name or "sandbox" suffix = uuid.uuid4().hex[:8] space_id = f"{owner}/{base}-{suffix}" - sandbox_api_token = secrets_lib.token_urlsafe(32) _log(f"Creating sandbox: {space_id} (from {template})...") @@ -679,33 +579,6 @@ class Sandbox: _check_cancel() - # ``duplicate_space`` already receives the target hardware. The extra - # /hardware call is useful for paid tiers, but hosted OAuth tokens can - # 401 on that endpoint for a fresh private Space even after duplication - # succeeds. Avoid the redundant call for default CPU sandboxes when no - # auto-sleep timer is requested; with sleep_time set, the hardware - # endpoint is still needed to configure auto-sleep. - if hardware == CPU_BASIC_HARDWARE and sleep_time is None: - _log(f"Using duplicated Space hardware: {hardware}") - else: - _request_space_hardware_with_retry( - api, - space_id, - hardware=hardware, - sleep_time=sleep_time, - log=_log, - check_cancel=_check_cancel, - ) - _log(f"Requested hardware: {hardware}") - - # Inject secrets BEFORE uploading server files (which triggers rebuild). - # Secrets added after a Space is running aren't available until restart, - # so they must be set before the build/start cycle. - sandbox_secrets = {**(secrets or {}), "SANDBOX_API_TOKEN": sandbox_api_token} - if sandbox_secrets: - for key, val in sandbox_secrets.items(): - api.add_space_secret(space_id, key, val) - # Upload sandbox server and Dockerfile (triggers rebuild) cls._setup_server(space_id, api, log=_log) @@ -716,22 +589,8 @@ class Sandbox: deadline = time.time() + wait_timeout while time.time() < deadline: _check_cancel() - try: - runtime = api.get_space_runtime(space_id) - except Exception as e: - if _is_transient_space_visibility_error(e): - _log(" Space runtime not visible yet...") - time.sleep(WAIT_INTERVAL) - continue - raise + runtime = api.get_space_runtime(space_id) if runtime.stage == "RUNNING": - current_hardware = runtime.hardware or getattr( - runtime, "requested_hardware", None - ) - if current_hardware != hardware: - _log(f" RUNNING on {current_hardware}; waiting for {hardware}...") - time.sleep(WAIT_INTERVAL) - continue _log(f"Space is running (hardware: {runtime.hardware})") break if runtime.stage in ("RUNTIME_ERROR", "BUILD_ERROR"): @@ -750,12 +609,7 @@ class Sandbox: _check_cancel() # Wait for the API server to be responsive (non-fatal) - sb = cls( - space_id=space_id, - token=token, - api_token=sandbox_api_token, - _owns_space=True, - ) + sb = cls(space_id=space_id, token=token, _owns_space=True) try: sb._wait_for_api(timeout=API_WAIT_TIMEOUT, log=_log) except TimeoutError as e: @@ -765,9 +619,7 @@ class Sandbox: return sb @staticmethod - def _setup_server( - space_id: str, api: HfApi, *, log: Callable[[str], object] = print - ) -> None: + def _setup_server(space_id: str, api: HfApi, *, log: Callable[[str], object] = print) -> None: """Upload embedded sandbox server + Dockerfile to the Space (single commit).""" log(f"Uploading sandbox server to {space_id}...") api.create_commit( @@ -788,30 +640,17 @@ class Sandbox: log("Server files uploaded, rebuild triggered.") @classmethod - def connect( - cls, - space_id: str, - *, - token: str | None = None, - api_token: str | None = None, - ) -> Sandbox: + def connect(cls, space_id: str, *, token: str | None = None) -> Sandbox: """ Connect to an existing running Space. Does a health check to verify the Space is reachable. 
sb = cls( - space_id=space_id, - token=token, - api_token=api_token, - _owns_space=False, - ) + sb = cls(space_id=space_id, token=token, _owns_space=False) sb._wait_for_api(timeout=60) return sb - def _wait_for_api( - self, timeout: int = API_WAIT_TIMEOUT, log: Callable[[str], object] = print - ): + def _wait_for_api(self, timeout: int = API_WAIT_TIMEOUT, log: Callable[[str], object] = print): """Poll the health endpoint until the server responds.""" deadline = time.time() + timeout last_err = None @@ -840,10 +679,6 @@ class Sandbox: ) print(f"Deleting sandbox: {self.space_id}...") self._hf_api.delete_repo(self.space_id, repo_type="space") - # Clear ownership so a second cleanup call (e.g. delete_session + - # _run_session.finally both fire) early-returns instead of retrying - # a 404 delete and emitting a spurious ERROR log. - self._owns_space = False self._client.close() print("Deleted.") @@ -988,12 +823,7 @@ class Sandbox: return result def edit( - self, - path: str, - old_str: str, - new_str: str, - *, - replace_all: bool = False, + self, path: str, old_str: str, new_str: str, *, replace_all: bool = False, mode: str = "replace", ) -> ToolResult: if old_str == new_str: diff --git a/agent/tools/sandbox_tool.py b/agent/tools/sandbox_tool.py index fbc6a41f9fd9edf05b1565d5782983bde167fa3c..35676c7d75b1d5441ce73096c7e1f1badd111ae7 100644 --- a/agent/tools/sandbox_tool.py +++ b/agent/tools/sandbox_tool.py @@ -2,60 +2,23 @@ Sandbox tools — expose the Sandbox client as agent tools. 5 tools total: - sandbox_create — create/replace sandbox for non-default hardware - bash, read, write, edit — operations on the active sandbox + sandbox_create — explicit sandbox creation (requires approval) + bash, read, write, edit — operations on the sandbox -A cpu-basic sandbox is preloaded for each session. Operation tools wait for it -if startup is still in progress. +Operation tools require an active sandbox; without one they return an +error asking the agent to call sandbox_create first. """ from __future__ import annotations import asyncio -import logging -import re import threading -import weakref -from datetime import datetime, timedelta, timezone from typing import Any from huggingface_hub import HfApi, SpaceHardware -from agent.core.hub_artifacts import wrap_shell_command_with_hub_artifact_bootstrap from agent.core.session import Event from agent.tools.sandbox_client import Sandbox -from agent.tools.trackio_seed import ensure_trackio_dashboard - -logger = logging.getLogger(__name__) - -DEFAULT_CPU_SANDBOX_HARDWARE = "cpu-basic" - -# Match the exact suffix pattern Sandbox.create produces: "sandbox-<8 hex>". -# Used to identify orphan sandboxes from prior sessions safely (won't match -# user-renamed lookalikes). -SANDBOX_SPACE_NAME_RE = re.compile(r"^sandbox-[a-f0-9]{8}$") - -# How stale a sandbox must be before we treat it as definitely orphan. -# Anything more recent could be tied to a still-live session in another tab, -# so we leave it alone. -_ORPHAN_STALE_AFTER = timedelta(hours=1) - -# HF Space duplication/build APIs can behave poorly when multiple private -# sandboxes are created concurrently for the same namespace. Keep session -# creation non-blocking, but serialize the actual Hub create path per owner.
-_SANDBOX_CREATE_LOCKS: weakref.WeakKeyDictionary[ - asyncio.AbstractEventLoop, dict[str, asyncio.Lock] -] = weakref.WeakKeyDictionary() - - -def _get_sandbox_create_lock(owner: str) -> asyncio.Lock: - loop = asyncio.get_running_loop() - locks = _SANDBOX_CREATE_LOCKS.setdefault(loop, {}) - lock = locks.get(owner) - if lock is None: - lock = asyncio.Lock() - locks[owner] = lock - return lock def _looks_like_path(script: str) -> bool: @@ -99,138 +62,11 @@ async def resolve_sandbox_script( return None, f"Failed to read {script} from sandbox: {e}" -async def _seed_trackio_dashboard_safe(session: Any, space_id: str) -> None: - """Idempotently seed *space_id* with trackio dashboard files using the - session's HF token. Logs progress, swallows errors — a failed seed should - not block sandbox creation.""" - if not session or not getattr(session, "hf_token", None): - return - loop = asyncio.get_running_loop() - - def _log(msg: str) -> None: - loop.call_soon_threadsafe( - session.event_queue.put_nowait, - Event(event_type="tool_log", data={"tool": "sandbox_create", "log": msg}), - ) - - try: - await asyncio.to_thread( - ensure_trackio_dashboard, space_id, session.hf_token, _log - ) - except Exception as e: - _log(f"trackio dashboard seed failed: {e}") - - -async def _update_persisted_sandbox_fields(session: Any, **fields: Any) -> None: - """Best-effort update of sandbox metadata on the durable session record.""" - store = getattr(session, "persistence_store", None) - session_id = getattr(session, "session_id", None) - if not (store and session_id and hasattr(store, "update_session_fields")): - return - try: - await store.update_session_fields(session_id, **fields) - except Exception as e: - logger.warning("Failed to persist sandbox metadata for %s: %s", session_id, e) - - -async def _persist_active_sandbox( - session: Any, - sandbox: Sandbox, - *, - hardware: str, -) -> None: - space_id = getattr(sandbox, "space_id", None) - if not space_id: - return - owner = space_id.split("/", 1)[0] if "/" in space_id else None - await _update_persisted_sandbox_fields( - session, - sandbox_space_id=space_id, - sandbox_hardware=hardware, - sandbox_owner=owner, - sandbox_created_at=datetime.now(timezone.utc), - sandbox_status="active", - ) - - -async def _clear_persisted_sandbox(session: Any) -> None: - await _update_persisted_sandbox_fields( - session, - sandbox_space_id=None, - sandbox_hardware=None, - sandbox_owner=None, - sandbox_created_at=None, - sandbox_status="destroyed", - ) - - # ── Tool name mapping (short agent names → Sandbox client names) ────── -def _cleanup_user_orphan_sandboxes( - api: HfApi, - owner: str, - log: Any, -) -> int: - """Delete stale ``sandbox-<8 hex>`` Spaces in ``owner``'s account. - - "Stale" = not modified in the last hour. The naming pattern + staleness - filter together make this safe: - - * Naming: only matches ``sandbox-<8 hex>``, the - pattern Sandbox.create produces. Won't touch user-renamed Spaces. - * Staleness: anything modified in the last hour might still be tied - to a live session in another tab/replica, so we leave it alone. - - Runs blocking — call via ``asyncio.to_thread``. Best-effort: failures - are logged but never raised, so a flaky HF API never blocks creation.
- """ - cutoff = datetime.now(timezone.utc) - _ORPHAN_STALE_AFTER - deleted = 0 - try: - spaces = list(api.list_spaces(author=owner, limit=200, full=True)) - except Exception as e: - log(f"orphan sweep: list_spaces failed: {e}") - return 0 - - for space in spaces: - space_name = space.id.rsplit("/", 1)[-1] - if not SANDBOX_SPACE_NAME_RE.match(space_name): - continue - - last_mod = getattr(space, "lastModified", None) or getattr( - space, "last_modified", None - ) - if isinstance(last_mod, str): - try: - last_mod = datetime.fromisoformat(last_mod.replace("Z", "+00:00")) - except ValueError: - last_mod = None - if last_mod is None: - log(f"orphan sweep: skipping {space.id}; missing lastModified") - continue - if last_mod and last_mod > cutoff: - # Recent — could be a concurrent live session. Skip. - continue - - try: - api.delete_repo(repo_id=space.id, repo_type="space") - deleted += 1 - log(f"orphan sweep: deleted {space.id}") - except Exception as e: - log(f"orphan sweep: failed to delete {space.id}: {e}") - - if deleted: - log(f"orphan sweep: cleaned up {deleted} stale sandbox(es) before create") - return deleted - - async def _ensure_sandbox( - session: Any, - hardware: str = DEFAULT_CPU_SANDBOX_HARDWARE, - extra_secrets: dict[str, str] | None = None, - cancel_event: threading.Event | None = None, - **create_kwargs, + session: Any, hardware: str = "cpu-basic", **create_kwargs ) -> tuple[Sandbox | None, str | None]: """ Ensure a sandbox exists on the session. Auto-creates with given hardware if needed. @@ -254,45 +90,6 @@ async def _ensure_sandbox( if not owner: return None, "Could not determine HF username from token." - create_lock = _get_sandbox_create_lock(owner) - if create_lock.locked(): - await session.send_event( - Event( - event_type="tool_log", - data={ - "tool": "sandbox", - "log": "Waiting for sandbox creation slot...", - }, - ) - ) - - async with create_lock: - if getattr(session, "sandbox", None): - return session.sandbox, None - - return await _create_sandbox_locked( - session, - api=api, - owner=owner, - hardware=hardware, - extra_secrets=extra_secrets, - cancel_event=cancel_event, - **create_kwargs, - ) - - -async def _create_sandbox_locked( - session: Any, - *, - api: HfApi, - owner: str, - hardware: str, - extra_secrets: dict[str, str] | None = None, - cancel_event: threading.Event | None = None, - **create_kwargs, -) -> tuple[Sandbox | None, str | None]: - """Create the Space while the per-owner sandbox creation lock is held.""" - token = session.hf_token await session.send_event( Event( event_type="tool_log", @@ -315,7 +112,7 @@ async def _create_sandbox_locked( # Bridge asyncio cancel event to a threading.Event for the blocking create call. # We poll session._cancelled from the main loop in a background task and set # a threading.Event that Sandbox.create checks during its polling loops.
- cancel_flag = cancel_event or threading.Event() + cancel_flag = threading.Event() async def _watch_cancel(): await session._cancelled.wait() @@ -323,57 +120,39 @@ async def _create_sandbox_locked( watcher_task = asyncio.create_task(_watch_cancel()) - secrets: dict[str, str] = {"HF_TOKEN": token} - if extra_secrets: - secrets.update({k: v for k, v in extra_secrets.items() if v}) - - create_kwargs["private"] = True # enforce: overrides any caller-supplied value kwargs = { "owner": owner, "hardware": hardware, "token": token, - "secrets": secrets, "log": _log, "cancel_event": cancel_flag, **create_kwargs, } - if hardware != DEFAULT_CPU_SANDBOX_HARDWARE: + if hardware != "cpu-basic": kwargs["sleep_time"] = 2700 - import time as _t - - _t_start = _t.monotonic() try: sb = await asyncio.to_thread(Sandbox.create, **kwargs) except Sandbox.Cancelled: return None, "Sandbox creation cancelled by user." finally: watcher_task.cancel() + session.sandbox = sb - if cancel_flag.is_set(): - if getattr(sb, "_owns_space", False): - try: - await asyncio.to_thread(sb.delete) - except Exception as e: - logger.warning( - "Failed to delete cancelled sandbox %s: %s", sb.space_id, e - ) - return None, "Sandbox creation cancelled by user." + # Set a descriptive title (template title is inherited on duplicate) + from huggingface_hub import metadata_update - session.sandbox = sb - session.sandbox_hardware = hardware - session.sandbox_preload_error = None - await _persist_active_sandbox(session, sb, hardware=hardware) - - # Telemetry: sandbox creation (infra consumption signal) - from agent.core import telemetry - - await telemetry.record_sandbox_create( - session, - sb, - hardware=hardware, - create_latency_s=int(_t.monotonic() - _t_start), + await asyncio.to_thread( + metadata_update, + sb.space_id, + {"title": "ml-agent sandbox"}, + repo_type="space", + overwrite=True, + token=token, ) + # Inject the OAuth token into the sandbox so Hub operations work inside it + await asyncio.to_thread(api.add_space_secret, sb.space_id, "HF_TOKEN", token) + await session.send_event( Event( event_type="tool_log", @@ -384,166 +163,24 @@ async def _create_sandbox_locked( return sb, None -def start_cpu_sandbox_preload(session: Any) -> asyncio.Task | None: - """Start a background ``cpu-basic`` sandbox for this session.""" - if not session or getattr(session, "sandbox", None): - return None - - existing_task = getattr(session, "sandbox_preload_task", None) - if existing_task and not existing_task.done(): - return existing_task - - cancel_event = threading.Event() - session.sandbox_preload_cancel_event = cancel_event - session.sandbox_preload_error = None - - async def _preload() -> Sandbox | None: - try: - sb, error = await _ensure_sandbox( - session, - hardware=DEFAULT_CPU_SANDBOX_HARDWARE, - cancel_event=cancel_event, - ) - if error: - session.sandbox_preload_error = error - return None - return sb - except asyncio.CancelledError: - cancel_event.set() - session.sandbox_preload_error = "Sandbox creation cancelled by user." 
- raise - except Exception as e: - session.sandbox_preload_error = f"Failed to create sandbox: {e}" - logger.warning("CPU sandbox preload failed: %s", e) - return None - - task = asyncio.create_task(_preload()) - session.sandbox_preload_task = task - return task - - -async def cancel_sandbox_preload(session: Any) -> None: - """Best-effort cancellation for an in-flight CPU sandbox preload.""" - cancel_event = getattr(session, "sandbox_preload_cancel_event", None) - if cancel_event is not None: - cancel_event.set() - - task = getattr(session, "sandbox_preload_task", None) - if not task or task.done(): - return - - current_task = asyncio.current_task() - if task is current_task: - return - - try: - await asyncio.wait_for(asyncio.shield(task), timeout=30) - except asyncio.TimeoutError: - logger.warning( - "Timed out waiting for CPU sandbox preload cancellation; " - "task is still live, cancelling asyncio wrapper" - ) - task.cancel() - except asyncio.CancelledError: - raise - except Exception: - pass - - -async def get_active_or_preloaded_sandbox( - session: Any, -) -> tuple[Sandbox | None, str | None]: - """Return the active sandbox, waiting for the startup preload if needed.""" - if not session: - return None, "No session available." - if getattr(session, "sandbox", None): - return session.sandbox, None - - task = getattr(session, "sandbox_preload_task", None) - if task: - try: - await asyncio.shield(task) - except asyncio.CancelledError: - raise - except Exception as e: - session.sandbox_preload_error = f"Failed to create sandbox: {e}" - - if getattr(session, "sandbox", None): - return session.sandbox, None - - preload_error = getattr(session, "sandbox_preload_error", None) - if preload_error: - return None, preload_error - - return None, "Sandbox is still starting. Please retry shortly." - - -async def teardown_session_sandbox(session: Any) -> None: - """Cancel sandbox preload and delete the active owned sandbox, if present.""" - if not session: - return - - await cancel_sandbox_preload(session) - - sandbox = getattr(session, "sandbox", None) - session.sandbox = None - session.sandbox_hardware = None - - if not sandbox: - return - - try: - if not getattr(sandbox, "_owns_space", False): - return - - space_id = getattr(sandbox, "space_id", None) - last_err: Exception | None = None - for attempt in range(3): - try: - logger.info( - "Deleting sandbox %s (attempt %s/3)...", - space_id, - attempt + 1, - ) - await asyncio.to_thread(sandbox.delete) - from agent.core import telemetry - - await telemetry.record_sandbox_destroy(session, sandbox) - return - except Exception as e: - last_err = e - if attempt < 2: - await asyncio.sleep(2**attempt) - logger.error( - "Failed to delete sandbox %s after 3 attempts: %s. " - "Orphan — sweep script will pick it up.", - space_id, - last_err, - ) - finally: - await _clear_persisted_sandbox(session) - - # ── sandbox_create tool ────────────────────────────────────────────── SANDBOX_CREATE_TOOL_SPEC = { "name": "sandbox_create", "description": ( - "Create or replace the session sandbox when non-default hardware is needed.\n\n" - "A private cpu-basic sandbox is already started automatically for each session. " - "For normal CPU code execution, call bash/read/write/edit directly; do NOT call sandbox_create first.\n\n" - "Use sandbox_create when: you need GPU hardware, cpu-upgrade, or Trackio secrets before running code. " - "The active sandbox persists across tool calls within the session. pip install works out of the box.
" - "Sandboxes are always created as private HF Spaces.\n\n" + "Create a persistent remote Linux environment for developing and testing scripts.\n\n" + "Workflow: sandbox_create → write script → pip install → test with small run → fix errors → hf_jobs at scale.\n" + "The sandbox persists across tool calls within the session. pip install works out of the box.\n\n" + "Use this when: you need to develop, test, and iterate on scripts before launching via hf_jobs. " + "Especially for training scripts where you need to verify imports, test on a small subset, and fix errors interactively.\n\n" + "Skip this when: the task is a simple one-shot operation (status check, resource search, quick data query), " + "or the script is copied from a verified working example with minimal changes.\n\n" "For ML code that uses CUDA, bf16, or model loading: use GPU hardware (t4-small minimum). " "CPU sandboxes cannot run GPU code paths — your test will not catch GPU-related errors.\n\n" "Before choosing hardware, estimate your VRAM needs (models you run, training data size). Rule of thumb: bf16/fp16 ≈ 2 bytes/param, " "fp32 ≈ 4 bytes/param, plus ~20% overhead for optimizer states during training.\n" "Common picks: t4-small (16GB VRAM, fits ≤1-3B), a10g-small (24GB, ≤7B), a100-large (80GB, ≤30B). " "If the model won't fit, pick larger hardware upfront — OOM on a sandbox wastes time.\n\n" - "If you intend to run a training script in this sandbox that uses report_to='trackio', " - "pass `trackio_space_id` (e.g. '<user>/mlintern-<8char>') and `trackio_project` so they " - "are set as TRACKIO_SPACE_ID/TRACKIO_PROJECT secrets in the sandbox and the UI can embed the live dashboard.\n\n" "Hardware: " + ", ".join([e.value for e in SpaceHardware]) + ".\n" ), "parameters": { @@ -554,27 +191,11 @@ SANDBOX_CREATE_TOOL_SPEC = { "hardware": { "type": "string", "enum": [e.value for e in SpaceHardware], - "description": ( - "Hardware tier for the sandbox. Omit for the existing auto-started " - "cpu-basic sandbox; choose GPU/cpu-upgrade only when needed." - ), + "description": "Hardware tier for the sandbox (default: cpu-basic)", }, - "trackio_space_id": { - "type": "string", - "description": ( - "Optional. The HF Space hosting the trackio dashboard for runs in this sandbox " - "(e.g. '<user>/mlintern-<8char>', under YOUR HF namespace). Injected as " - "TRACKIO_SPACE_ID secret and surfaced to the UI. The Space is auto-created and " - "seeded with the trackio dashboard — DO NOT pre-create it via hf_repo_git, " - "that produces an empty Space that breaks the embed." - ), - }, - "trackio_project": { - "type": "string", - "description": ( - "Optional. The trackio project name. Injected as TRACKIO_PROJECT secret and " - "used by the UI to filter the embedded dashboard to this project."
- ), + "private": { + "type": "boolean", + "description": "If true, create a private Space", }, }, }, @@ -582,127 +203,35 @@ SANDBOX_CREATE_TOOL_SPEC = { async def sandbox_create_handler( - args: dict[str, Any], session: Any = None, tool_call_id: str | None = None + args: dict[str, Any], session: Any = None ) -> tuple[str, bool]: """Handle sandbox_create tool calls.""" - hardware = args.get("hardware", DEFAULT_CPU_SANDBOX_HARDWARE) - trackio_space_id = args.get("trackio_space_id") or None - trackio_project = args.get("trackio_project") or None - - async def _emit_trackio_state(sb: Sandbox) -> None: - """Tell the frontend which trackio dashboard to embed for this sandbox.""" - if not (session and tool_call_id and trackio_space_id): - return - data: dict[str, Any] = { - "tool_call_id": tool_call_id, - "tool": "sandbox_create", - "state": "running", - "trackioSpaceId": trackio_space_id, - } - if trackio_project: - data["trackioProject"] = trackio_project - await session.send_event(Event(event_type="tool_state_change", data=data)) - - preload_task = getattr(session, "sandbox_preload_task", None) - if ( - session - and not getattr(session, "sandbox", None) - and preload_task - and not preload_task.done() - and hardware == DEFAULT_CPU_SANDBOX_HARDWARE - ): - sb, error = await get_active_or_preloaded_sandbox(session) - if error: - return error, False - if sb: - await _emit_trackio_state(sb) - return ( - f"Sandbox already active: {sb.space_id}\n" - f"URL: {sb.url}\n" - f"Hardware: {DEFAULT_CPU_SANDBOX_HARDWARE}\n" - f"Use bash/read/write/edit to interact with it." - ), True - - if ( - session - and not getattr(session, "sandbox", None) - and preload_task - and not preload_task.done() - and hardware != DEFAULT_CPU_SANDBOX_HARDWARE - ): - await cancel_sandbox_preload(session) - - # If sandbox already exists, return its info or replace the auto CPU sandbox + # If sandbox already exists, return its info if session and getattr(session, "sandbox", None): sb = session.sandbox - active_hardware = getattr(session, "sandbox_hardware", None) - if active_hardware == hardware: - await _emit_trackio_state(sb) - return ( - f"Sandbox already active: {sb.space_id}\n" - f"URL: {sb.url}\n" - f"Hardware: {active_hardware}\n" - f"Use bash/read/write/edit to interact with it." - ), True - - requested_hardware = args.get("hardware") - lockout_note = "" - if ( - active_hardware == DEFAULT_CPU_SANDBOX_HARDWARE - and hardware != DEFAULT_CPU_SANDBOX_HARDWARE - ): - await teardown_session_sandbox(session) - elif requested_hardware: - lockout_note = ( - f"\nRequested hardware: {requested_hardware}\n" - "Hardware cannot be changed by calling sandbox_create again. " - "Delete the existing sandbox first if you need a different tier." - ) - await _emit_trackio_state(sb) - return ( - f"Sandbox already active: {sb.space_id}\n" - f"URL: {sb.url}\n" - f"{lockout_note}\n" - f"Use bash/read/write/edit to interact with it." - ), True - else: - await _emit_trackio_state(sb) - return ( - f"Sandbox already active: {sb.space_id}\n" - f"URL: {sb.url}\n" - f"Hardware: {active_hardware or 'unknown'}\n" - f"Use bash/read/write/edit to interact with it." 
- ), True - - create_kwargs: dict[str, Any] = {} - - extra_secrets: dict[str, str] = {} - if trackio_space_id: - extra_secrets["TRACKIO_SPACE_ID"] = trackio_space_id - await _seed_trackio_dashboard_safe(session, trackio_space_id) - if trackio_project: - extra_secrets["TRACKIO_PROJECT"] = trackio_project + return ( + f"Sandbox already active: {sb.space_id}\n" + f"URL: {sb.url}\n" + f"Use bash/read/write/edit to interact with it." + ), True + + hardware = args.get("hardware", "cpu-basic") + create_kwargs = {} + if "private" in args: + create_kwargs["private"] = args["private"] try: - sb, error = await _ensure_sandbox( - session, - hardware=hardware, - extra_secrets=extra_secrets or None, - **create_kwargs, - ) + sb, error = await _ensure_sandbox(session, hardware=hardware, **create_kwargs) except Exception as e: return f"Failed to create sandbox: {e}", False if error: return error, False - await _emit_trackio_state(sb) - return ( f"Sandbox created: {sb.space_id}\n" f"URL: {sb.url}\n" f"Hardware: {hardware}\n" - "Visibility: private\n" f"Use bash/read/write/edit to interact with it." ), True @@ -711,21 +240,13 @@ def _make_tool_handler(sandbox_tool_name: str): """Factory: create a handler for a sandbox operation tool.""" async def handler(args: dict[str, Any], session: Any = None) -> tuple[str, bool]: - sb, error = await get_active_or_preloaded_sandbox(session) - if error: - return error, False - if not sb: - return "Sandbox is still starting. Please retry shortly.", False + # Require sandbox to exist — user must approve sandbox_create first + if not session or not getattr(session, "sandbox", None): + return "No sandbox running. Call sandbox_create first to start one.", False + + sb = session.sandbox try: - if sandbox_tool_name == "bash" and args.get("command"): - args = { - **args, - "command": wrap_shell_command_with_hub_artifact_bootstrap( - args["command"], - session, - ), - } result = await asyncio.to_thread(sb.call_tool, sandbox_tool_name, args) if result.success: output = result.output or "(no output)" @@ -748,7 +269,7 @@ def get_sandbox_tools(): tools = [] - # sandbox_create (for GPU or other non-default hardware) + # sandbox_create (explicit creation, requires approval) tools.append( ToolSpec( name=SANDBOX_CREATE_TOOL_SPEC["name"], @@ -761,15 +282,10 @@ def get_sandbox_tools(): # Operation tools (auto-execute, no approval needed) for name in Sandbox.TOOLS.keys(): spec = Sandbox.TOOLS[name] - description = ( - "Uses the session's active sandbox. A private cpu-basic sandbox is " - "started automatically for normal CPU work; call sandbox_create only " - "for GPU or other non-default hardware.\n\n" + spec["description"] - ) tools.append( ToolSpec( name=name, - description=description, + description=spec["description"], parameters=spec["parameters"], handler=_make_tool_handler(name), ) diff --git a/agent/tools/trackio_seed.py b/agent/tools/trackio_seed.py deleted file mode 100644 index 1062e1b5eda2701833aad7c1c895727d7fbd191e..0000000000000000000000000000000000000000 --- a/agent/tools/trackio_seed.py +++ /dev/null @@ -1,205 +0,0 @@ -"""Seed an HF Space with the trackio dashboard. - -Background: when the agent creates a Space via `hf_repo_git create_repo` (or -the user pre-creates one), it ships with no app.py — so the iframe shows the -default Gradio "Get started" template instead of charts. Trackio's `init()` -detects the existing Space but does NOT auto-bootstrap dashboard files into it, -so the dashboard never materializes.
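For context on the run-side flow this deleted module was seeding dashboards for, here is a hedged sketch using trackio's wandb-style API; the `space_id` argument and exact signatures are assumptions, and the project and Space names below are hypothetical:

```python
# Hedged sketch: a training job logging to a seeded dashboard Space.
# Assumes trackio exposes init()/log()/finish() in the wandb style and
# that init() accepts a space_id; "my-user/trackio-demo" is made up.
import trackio

trackio.init(project="sft-demo", space_id="my-user/trackio-demo")
for step in range(100):
    trackio.log({"loss": 1.0 / (step + 1)})  # rendered by the seeded app.py
trackio.finish()
```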
- -This helper writes the three files trackio's runtime expects (README.md, -requirements.txt, app.py) into the Space, idempotently, BEFORE the job that -will call `trackio.init()` runs. We deliberately omit `hf_oauth: true` from -the README so the embedded iframe in ml-intern renders without a login click β€” -per-user privacy is enforced by namespace ownership instead. - -Beyond the dashboard files, the helper also creates the metrics bucket and -mounts it on the Space at `/data` (with `TRACKIO_DIR` / `TRACKIO_BUCKET_ID` -Space variables). Without this, the running job writes metrics into a bucket -that the dashboard Space can't read, and the iframe shows "No projects". -""" - -from __future__ import annotations - -import io -from typing import Callable, Optional - -from huggingface_hub import ( - HfApi, - Volume, - add_space_variable, - create_bucket, - create_repo, -) -from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError - - -_README = """--- -title: Trackio Dashboard -emoji: πŸ“Š -colorFrom: pink -colorTo: gray -sdk: gradio -app_file: app.py -pinned: false -tags: - - trackio ---- - -Embedded trackio dashboard for ml-intern runs. -""" - -_REQUIREMENTS = "trackio\n" -_APP_PY = "import trackio\ntrackio.show()\n" - -# ml-intern brand mark surfaced inside the trackio dashboard. Trackio reads -# `TRACKIO_LOGO_LIGHT_URL` / `TRACKIO_LOGO_DARK_URL` from Space variables and -# renders them in place of its own logo. We point at the publicly-resolvable -# copy on the smolagents/ml-intern Space repo so any seeded dashboard inherits -# the ml-intern branding without each user having to host the asset. -_LOGO_URL = ( - "https://huggingface.co/spaces/smolagents/ml-intern/" - "resolve/main/frontend/public/smolagents.webp" -) - -_FILES = { - "README.md": _README, - "requirements.txt": _REQUIREMENTS, - "app.py": _APP_PY, -} - - -def _already_seeded(api: HfApi, space_id: str) -> bool: - """Cheap check: does the Space already have a trackio dashboard app.py? - - Avoids re-uploading the same three files on every job submission. We look - for the literal `trackio.show` call which is the load-bearing line β€” any - other app.py shape (the default gradio shell, a stale custom one) means - we should re-seed. - """ - try: - path = api.hf_hub_download( - repo_id=space_id, repo_type="space", filename="app.py" - ) - except (EntryNotFoundError, RepositoryNotFoundError, OSError): - return False - try: - with open(path, "r", encoding="utf-8") as f: - return "trackio.show" in f.read() - except OSError: - return False - - -def _get_space_volumes(api: HfApi, space_id: str) -> list: - """Return mounted volumes for a Space. - - `get_space_runtime()` doesn't always populate `volumes` even when the - mount exists; mirror trackio's fallback to `space_info().runtime.volumes`. - """ - runtime = api.get_space_runtime(space_id) - if getattr(runtime, "volumes", None): - return list(runtime.volumes) - info = api.space_info(space_id) - if info.runtime and getattr(info.runtime, "volumes", None): - return list(info.runtime.volumes) - return [] - - -def _ensure_bucket_mounted( - api: HfApi, - space_id: str, - bucket_id: str, - hf_token: str, - log: Optional[Callable[[str], None]] = None, -) -> None: - """Create the bucket if missing, mount it at `/data` on the Space, and - set the `TRACKIO_DIR` / `TRACKIO_BUCKET_ID` Space variables. Idempotent β€” - skips work that has already been done. 
- """ - create_bucket(bucket_id, private=True, exist_ok=True, token=hf_token) - - existing = _get_space_volumes(api, space_id) - already_mounted = any( - getattr(v, "type", None) == "bucket" - and getattr(v, "source", None) == bucket_id - and getattr(v, "mount_path", None) == "/data" - for v in existing - ) - if not already_mounted: - preserved = [ - v - for v in existing - if not ( - getattr(v, "type", None) == "bucket" - and ( - getattr(v, "source", None) == bucket_id - or getattr(v, "mount_path", None) == "/data" - ) - ) - ] - api.set_space_volumes( - space_id, - preserved + [Volume(type="bucket", source=bucket_id, mount_path="/data")], - ) - if log: - log(f"mounted bucket {bucket_id} at /data on {space_id}") - - variables = api.get_space_variables(space_id) - desired = { - "TRACKIO_DIR": "/data/trackio", - "TRACKIO_BUCKET_ID": bucket_id, - "TRACKIO_LOGO_LIGHT_URL": _LOGO_URL, - "TRACKIO_LOGO_DARK_URL": _LOGO_URL, - } - for key, value in desired.items(): - if getattr(variables.get(key), "value", None) != value: - add_space_variable(space_id, key, value, token=hf_token) - - -def ensure_trackio_dashboard( - space_id: str, - hf_token: str, - log: Optional[Callable[[str], None]] = None, -) -> bool: - """Make sure *space_id* is fully wired for trackio: - 1. Space exists with our dashboard files (README without `hf_oauth`, - `requirements.txt`, `app.py` calling `trackio.show`). - 2. Bucket `-bucket` exists, is mounted at `/data`, and the - Space has `TRACKIO_DIR` / `TRACKIO_BUCKET_ID` variables set. - - Idempotent β€” re-running is cheap. Returns True if any seeding happened - in step (1), False if the dashboard files were already in place. Bucket - mount is always re-checked. - """ - api = HfApi(token=hf_token) - - create_repo( - repo_id=space_id, - repo_type="space", - space_sdk="gradio", - exist_ok=True, - token=hf_token, - ) - - seeded_files = False - if _already_seeded(api, space_id): - if log: - log(f"trackio dashboard already seeded on {space_id}") - else: - if log: - log(f"seeding trackio dashboard files into {space_id}") - for path_in_repo, content in _FILES.items(): - api.upload_file( - path_or_fileobj=io.BytesIO(content.encode("utf-8")), - path_in_repo=path_in_repo, - repo_id=space_id, - repo_type="space", - commit_message=f"ml-intern: seed trackio dashboard ({path_in_repo})", - ) - seeded_files = True - - bucket_id = f"{space_id}-bucket" - _ensure_bucket_mounted(api, space_id, bucket_id, hf_token, log) - - if log: - log(f"trackio dashboard ready: https://huggingface.co/spaces/{space_id}") - return seeded_files diff --git a/agent/tools/web_search_tool.py b/agent/tools/web_search_tool.py deleted file mode 100644 index 5c18410855bebdee305997d90de4c9e56f942461..0000000000000000000000000000000000000000 --- a/agent/tools/web_search_tool.py +++ /dev/null @@ -1,276 +0,0 @@ -"""DuckDuckGo HTML web search tool. - -This mirrors Claw Code's Rust WebSearch behavior: fetch DuckDuckGo's HTML -endpoint, extract result links, optionally filter domains, and return a -JSON payload the model can cite. 
-""" - -from __future__ import annotations - -import asyncio -import html -import json -import os -import time -from dataclasses import dataclass -from html.parser import HTMLParser -from typing import Any -from urllib.parse import parse_qsl, parse_qs, urlencode, urlparse, urlunparse - -import requests - -DEFAULT_SEARCH_URL = "https://html.duckduckgo.com/html/" -WEB_SEARCH_BASE_URL_ENV = "CLAWD_WEB_SEARCH_BASE_URL" -USER_AGENT = "clawd-rust-tools/0.1" -REQUEST_TIMEOUT_SECONDS = 20 -MAX_RESULTS = 8 - - -@dataclass(frozen=True) -class SearchHit: - title: str - url: str - - def as_json(self) -> dict[str, str]: - return {"title": self.title, "url": self.url} - - -class _AnchorParser(HTMLParser): - def __init__(self, *, require_result_class: bool) -> None: - super().__init__(convert_charrefs=True) - self.require_result_class = require_result_class - self.hits: list[tuple[str, str]] = [] - self._active_href: str | None = None - self._active_text: list[str] = [] - - def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None: - if tag.lower() != "a": - return - attr_map = {key.lower(): value or "" for key, value in attrs} - href = attr_map.get("href") - if not href: - return - if self.require_result_class and "result__a" not in attr_map.get("class", ""): - return - self._active_href = href - self._active_text = [] - - def handle_data(self, data: str) -> None: - if self._active_href is not None: - self._active_text.append(data) - - def handle_entityref(self, name: str) -> None: - if self._active_href is not None: - self._active_text.append(f"&{name};") - - def handle_charref(self, name: str) -> None: - if self._active_href is not None: - self._active_text.append(f"&#{name};") - - def handle_endtag(self, tag: str) -> None: - if tag.lower() != "a" or self._active_href is None: - return - title = collapse_whitespace(html.unescape("".join(self._active_text))).strip() - self.hits.append((self._active_href, title)) - self._active_href = None - self._active_text = [] - - -def build_search_url(query: str) -> str: - base = os.environ.get(WEB_SEARCH_BASE_URL_ENV, DEFAULT_SEARCH_URL) - parsed = urlparse(base) - if parsed.scheme not in {"http", "https"} or not parsed.netloc: - raise ValueError(f"invalid search base URL: {base}") - - query_pairs = parse_qsl(parsed.query, keep_blank_values=True) - query_pairs.append(("q", query)) - return urlunparse(parsed._replace(query=urlencode(query_pairs))) - - -def collapse_whitespace(value: str) -> str: - return " ".join(value.split()) - - -def decode_duckduckgo_redirect(url: str) -> str | None: - if url.startswith("http://") or url.startswith("https://"): - return html.unescape(url) - if url.startswith("//"): - joined = f"https:{url}" - elif url.startswith("/"): - joined = f"https://duckduckgo.com{url}" - else: - return None - - parsed = urlparse(joined) - if parsed.path in {"/l", "/l/"}: - uddg = parse_qs(parsed.query).get("uddg", []) - if uddg: - return html.unescape(uddg[0]) - return joined - - -def _extract_links(search_html: str, *, require_result_class: bool) -> list[SearchHit]: - parser = _AnchorParser(require_result_class=require_result_class) - parser.feed(search_html) - - hits: list[SearchHit] = [] - for raw_url, title in parser.hits: - if not title: - continue - decoded_url = decode_duckduckgo_redirect(raw_url) - if decoded_url and ( - decoded_url.startswith("http://") or decoded_url.startswith("https://") - ): - hits.append(SearchHit(title=title, url=decoded_url)) - return hits - - -def extract_search_hits(search_html: str) -> 
list[SearchHit]: - return _extract_links(search_html, require_result_class=True) - - -def extract_search_hits_from_generic_links(search_html: str) -> list[SearchHit]: - return _extract_links(search_html, require_result_class=False) - - -def normalize_domain_filter(domain: str) -> str: - trimmed = domain.strip() - parsed = urlparse(trimmed) - candidate = parsed.hostname if parsed.scheme and parsed.hostname else trimmed - return candidate.strip().lstrip(".").rstrip("/").lower() - - -def host_matches_list(url: str, domains: list[str]) -> bool: - host = urlparse(url).hostname - if not host: - return False - normalized_host = host.lower() - for domain in domains: - normalized = normalize_domain_filter(domain) - if normalized and ( - normalized_host == normalized or normalized_host.endswith(f".{normalized}") - ): - return True - return False - - -def dedupe_hits(hits: list[SearchHit]) -> list[SearchHit]: - seen: set[str] = set() - deduped: list[SearchHit] = [] - for hit in hits: - if hit.url in seen: - continue - seen.add(hit.url) - deduped.append(hit) - return deduped - - -def execute_web_search( - query: str, - allowed_domains: list[str] | None = None, - blocked_domains: list[str] | None = None, - tool_use_id: str = "web_search_1", -) -> dict[str, Any]: - started = time.monotonic() - search_url = build_search_url(query) - response = requests.get( - search_url, - headers={"User-Agent": USER_AGENT}, - timeout=REQUEST_TIMEOUT_SECONDS, - allow_redirects=True, - ) - - hits = extract_search_hits(response.text) - if not hits and urlparse(response.url or search_url).hostname: - hits = extract_search_hits_from_generic_links(response.text) - - if allowed_domains is not None: - hits = [hit for hit in hits if host_matches_list(hit.url, allowed_domains)] - if blocked_domains is not None: - hits = [hit for hit in hits if not host_matches_list(hit.url, blocked_domains)] - - hits = dedupe_hits(hits)[:MAX_RESULTS] - rendered_hits = "\n".join(f"- [{hit.title}]({hit.url})" for hit in hits) - if hits: - summary = ( - f"Search results for {query!r}. Include a Sources section in the final answer.\n" - f"{rendered_hits}" - ) - else: - summary = f"No web search results matched the query {query!r}." - - return { - "query": query, - "results": [ - summary, - { - "tool_use_id": tool_use_id, - "content": [hit.as_json() for hit in hits], - }, - ], - "durationSeconds": time.monotonic() - started, - } - - -WEB_SEARCH_TOOL_SPEC = { - "name": "web_search", - "description": "Search the web for current information and return cited results.", - "parameters": { - "type": "object", - "properties": { - "query": {"type": "string", "minLength": 2}, - "allowed_domains": { - "type": "array", - "items": {"type": "string"}, - "description": "Optional allowlist of domains or URLs. Subdomains match.", - }, - "blocked_domains": { - "type": "array", - "items": {"type": "string"}, - "description": "Optional blocklist of domains or URLs. 
Subdomains match.", - }, - }, - "required": ["query"], - "additionalProperties": False, - }, -} - - -def _optional_string_list(arguments: dict[str, Any], key: str) -> list[str] | None: - value = arguments.get(key) - if value is None: - return None - if not isinstance(value, list) or not all(isinstance(item, str) for item in value): - raise ValueError(f"{key} must be an array of strings") - return value - - -async def web_search_handler( - arguments: dict[str, Any], - session: Any = None, - tool_call_id: str | None = None, - **_kw: Any, -) -> tuple[str, bool]: - query_value = arguments.get("query", "") - if not isinstance(query_value, str): - return ( - "Error: web_search requires a query string with at least 2 characters.", - False, - ) - - query = query_value.strip() - if len(query) < 2: - return "Error: web_search requires a query with at least 2 characters.", False - - try: - output = await asyncio.to_thread( - execute_web_search, - query=query, - allowed_domains=_optional_string_list(arguments, "allowed_domains"), - blocked_domains=_optional_string_list(arguments, "blocked_domains"), - tool_use_id=tool_call_id or "web_search_1", - ) - except Exception as exc: - return f"Error executing web search: {exc}", False - - return json.dumps(output, indent=2), True diff --git a/agent/utils/boot_timing.py b/agent/utils/boot_timing.py deleted file mode 100644 index 0c0884d03380f07a05a26c059247fa1b393552e9..0000000000000000000000000000000000000000 --- a/agent/utils/boot_timing.py +++ /dev/null @@ -1,15 +0,0 @@ -"""Shared timing and color helpers for startup visual effects.""" - -import math - - -def settle_curve(progress: float, sharpness: float = 3.0) -> float: - """Return noise amount in range 1..0 for normalized progress 0..1.""" - t = max(0.0, min(1.0, progress)) - return math.exp(-sharpness * t) - - -def warm_gold_from_white(progress: float) -> tuple[int, int, int]: - """Interpolate from white to warm gold for progress 0..1.""" - t = max(0.0, min(1.0, progress)) - return 255, int(255 - 55 * t), int(255 - 175 * t) diff --git a/agent/utils/braille.py b/agent/utils/braille.py deleted file mode 100644 index 4621b735b7cff25d453afbc93f443f2bae4e7e4b..0000000000000000000000000000000000000000 --- a/agent/utils/braille.py +++ /dev/null @@ -1,121 +0,0 @@ -"""Braille-character canvas for high-resolution terminal graphics. - -Each terminal cell maps to a 2x4 dot grid using Unicode braille characters -(U+2800–U+28FF), giving 2Γ— horizontal and 4Γ— vertical resolution. 
-""" - -# Braille dot positions: (0,0) (1,0) dots 1,4 -# (0,1) (1,1) dots 2,5 -# (0,2) (1,2) dots 3,6 -# (0,3) (1,3) dots 7,8 -_DOT_MAP = ( - (0x01, 0x08), - (0x02, 0x10), - (0x04, 0x20), - (0x40, 0x80), -) - - -class BrailleCanvas: - """A pixel canvas that renders to braille characters.""" - - def __init__(self, term_width: int, term_height: int): - self.term_width = term_width - self.term_height = term_height - self.pixel_width = term_width * 2 - self.pixel_height = term_height * 4 - self._buf = bytearray(term_width * term_height) - - def clear(self) -> None: - for i in range(len(self._buf)): - self._buf[i] = 0 - - def set_pixel(self, x: int, y: int) -> None: - if 0 <= x < self.pixel_width and 0 <= y < self.pixel_height: - cx, rx = divmod(x, 2) - cy, ry = divmod(y, 4) - self._buf[cy * self.term_width + cx] |= _DOT_MAP[ry][rx] - - def render(self) -> list[str]: - lines = [] - for row in range(self.term_height): - offset = row * self.term_width - line = "".join( - chr(0x2800 + self._buf[offset + col]) for col in range(self.term_width) - ) - lines.append(line) - return lines - - -# ── Bitmap font (5Γ—7 uppercase + digits) ────────────────────────────── - -_FONT: dict[str, list[str]] = {} - - -def _define_font() -> None: - """Define a simple 5Γ—7 bitmap font for uppercase ASCII.""" - glyphs = { - "A": [" ## ", "# #", "# #", "####", "# #", "# #", "# #"], - "B": ["### ", "# #", "# #", "### ", "# #", "# #", "### "], - "C": [" ## ", "# #", "# ", "# ", "# ", "# #", " ## "], - "D": ["### ", "# #", "# #", "# #", "# #", "# #", "### "], - "E": ["####", "# ", "# ", "### ", "# ", "# ", "####"], - "F": ["####", "# ", "# ", "### ", "# ", "# ", "# "], - "G": [" ## ", "# #", "# ", "# ##", "# #", "# #", " ###"], - "H": ["# #", "# #", "# #", "####", "# #", "# #", "# #"], - "I": ["###", " # ", " # ", " # ", " # ", " # ", "###"], - "J": [" ##", " # ", " # ", " # ", " # ", "# # ", " # "], - "K": ["# #", "# # ", "## ", "## ", "# # ", "# #", "# #"], - "L": ["# ", "# ", "# ", "# ", "# ", "# ", "####"], - "M": ["# #", "## ##", "# # #", "# # #", "# #", "# #", "# #"], - "N": ["# #", "## #", "## #", "# ##", "# ##", "# #", "# #"], - "O": [" ## ", "# #", "# #", "# #", "# #", "# #", " ## "], - "P": ["### ", "# #", "# #", "### ", "# ", "# ", "# "], - "Q": [" ## ", "# #", "# #", "# #", "# ##", "# #", " ## "], - "R": ["### ", "# #", "# #", "### ", "# # ", "# #", "# #"], - "S": [" ## ", "# #", "# ", " ## ", " #", "# #", " ## "], - "T": ["#####", " # ", " # ", " # ", " # ", " # ", " # "], - "U": ["# #", "# #", "# #", "# #", "# #", "# #", " ## "], - "V": ["# #", "# #", "# #", " # # ", " # # ", " # ", " # "], - "W": ["# #", "# #", "# #", "# # #", "# # #", "## ##", "# #"], - "X": ["# #", "# #", " ## ", " ## ", " ## ", "# #", "# #"], - "Y": ["# #", "# #", " # # ", " # ", " # ", " # ", " # "], - "Z": ["####", " #", " # ", " # ", "# ", "# ", "####"], - " ": [" ", " ", " ", " ", " ", " ", " "], - "0": [" ## ", "# #", "# #", "# #", "# #", "# #", " ## "], - "1": [" # ", "## ", " # ", " # ", " # ", " # ", "###"], - "2": [" ## ", "# #", " #", " # ", " # ", "# ", "####"], - "3": [" ## ", "# #", " #", " ## ", " #", "# #", " ## "], - "4": ["# #", "# #", "# #", "####", " #", " #", " #"], - "5": ["####", "# ", "### ", " #", " #", "# #", " ## "], - "6": [" ## ", "# ", "### ", "# #", "# #", "# #", " ## "], - "7": ["####", " #", " # ", " # ", " # ", " # ", " # "], - "8": [" ## ", "# #", "# #", " ## ", "# #", "# #", " ## "], - "9": [" ## ", "# #", "# #", " ###", " #", " #", " ## "], - } - _FONT.update(glyphs) - - -_define_font() - - -def 
text_to_pixels(text: str, scale: int = 1) -> list[tuple[int, int]]: - """Convert text string to a list of (x, y) pixel positions using bitmap font.""" - pixels = [] - cursor_x = 0 - for ch in text.upper(): - glyph = _FONT.get(ch) - if glyph is None: - cursor_x += 4 * scale - continue - for row_idx, row in enumerate(glyph): - for col_idx, cell in enumerate(row): - if cell == "#": - for sy in range(scale): - for sx in range(scale): - pixels.append( - (cursor_x + col_idx * scale + sx, row_idx * scale + sy) - ) - glyph_width = max(len(r) for r in glyph) - cursor_x += (glyph_width + 1) * scale - return pixels diff --git a/agent/utils/crt_boot.py b/agent/utils/crt_boot.py deleted file mode 100644 index da0867188961ff08952005c7d098879dfd2a4279..0000000000000000000000000000000000000000 --- a/agent/utils/crt_boot.py +++ /dev/null @@ -1,116 +0,0 @@ -"""CRT / glitch boot sequence effect for CLI startup. - -Simulates an old CRT terminal booting up: text appearing character by character -with noise artifacts, then settling into a clean display. -""" - -import random -import time - -from rich.console import Console -from rich.text import Text -from rich.live import Live - -from agent.utils.boot_timing import settle_curve - - -def _glitch_text(text: str, intensity: float, rng: random.Random) -> str: - """Add random glitch characters to text.""" - glitch_chars = "β–ˆβ–“β–’β–‘β”ƒβ”«β”£β•‹β•β•Žβ”€β”β”…β”„" - result = list(text) - for i in range(len(result)): - if rng.random() < intensity: - result[i] = rng.choice(glitch_chars) - return "".join(result) - - -def run_boot_sequence(console: Console, boot_lines: list[tuple[str, str]]) -> None: - """Run the CRT boot sequence effect. - - Args: - console: Rich console instance. - boot_lines: List of (text, rich_style) tuples to display. 
- """ - term_height = min(console.height - 2, 40) - rng = random.Random(42) - - with Live(console=console, refresh_per_second=30, transient=True) as live: - displayed_lines: list[tuple[str, str]] = [] - - for line_text, line_style in boot_lines: - if not line_text: - displayed_lines.append(("", "")) - continue - - line_len = max(1, len(line_text)) - # Type out each character - for char_idx in range(len(line_text) + 1): - result = Text() - progress = char_idx / line_len - noise = settle_curve(progress) - prev_glitch_chance = 0.01 + 0.06 * noise - prev_glitch_intensity = 0.02 + 0.12 * noise - scanline_chance = 0.005 + 0.03 * noise - - # Render previously completed lines - for prev_text, prev_style in displayed_lines: - if rng.random() < prev_glitch_chance: - result.append( - _glitch_text(prev_text, prev_glitch_intensity, rng), - style=prev_style, - ) - else: - result.append(prev_text, style=prev_style) - result.append("\n") - - # Current line being typed - typed = line_text[:char_idx] - cursor = "β–ˆ" if char_idx < len(line_text) else "" - - # Noise after cursor - noise_tail = "" - if char_idx < len(line_text): - noise_len = rng.randint(0, int(1 + 5 * noise)) - noise_tail = "".join(rng.choice("β–‘β–’β–“") for _ in range(noise_len)) - - result.append(typed, style=line_style) - result.append(cursor, style="bold rgb(255,200,80)") - result.append(noise_tail, style="dim rgb(180,140,40)") - result.append("\n") - - # Faint scanlines in remaining space - remaining = term_height - len(displayed_lines) - 2 - for _ in range(max(0, remaining)): - if rng.random() < scanline_chance: - scan_len = rng.randint(5, 30) - result.append("─" * scan_len, style="dim rgb(180,140,40)") - result.append("\n") - - live.update(result) - - # Variable typing speed - if line_text[char_idx - 1 : char_idx] in " .": - time.sleep(0.025) - else: - time.sleep(0.010) - - displayed_lines.append((line_text, line_style)) - time.sleep(0.06) - - # Hold with blinking cursor - for frame in range(20): - result = Text() - for prev_text, prev_style in displayed_lines: - result.append(prev_text, style=prev_style) - result.append("\n") - if frame % 8 < 4: - result.append("β–ˆ", style="rgb(255,200,80)") - live.update(result) - time.sleep(0.05) - - # Print final clean frame - final = Text() - for prev_text, prev_style in displayed_lines: - final.append(prev_text, style=prev_style) - final.append("\n") - console.print(final) diff --git a/agent/utils/particle_logo.py b/agent/utils/particle_logo.py deleted file mode 100644 index 9c3338152a8b2fd29031c4eadaa19e9078f6da2b..0000000000000000000000000000000000000000 --- a/agent/utils/particle_logo.py +++ /dev/null @@ -1,230 +0,0 @@ -"""Particle coalesce effect for the HUGGING FACE ML INTERN logo. - -Random particles swirl in from the edges, converge to form the text -"HUGGING FACE / ML INTERN", hold briefly, then the final frame is printed. -Rendered with braille characters for high detail. - -Based on Leandro's particle_coalesce.py demo. 
-""" - -import math -import random -import time - -from rich.console import Console -from rich.text import Text -from rich.align import Align -from rich.live import Live - -from agent.utils.braille import BrailleCanvas, text_to_pixels -from agent.utils.boot_timing import settle_curve, warm_gold_from_white - - -class Particle: - __slots__ = ("x", "y", "target_x", "target_y", "vx", "vy", "phase", "delay") - - def __init__( - self, x: float, y: float, target_x: float, target_y: float, delay: float = 0 - ): - self.x = x - self.y = y - self.target_x = target_x - self.target_y = target_y - self.vx = 0.0 - self.vy = 0.0 - self.phase = random.uniform(0, math.pi * 2) - self.delay = delay - - def update_converge(self, t: float, strength: float = 0.08, damping: float = 0.92): - """Move toward target with spring-like physics.""" - if t < self.delay: - # Still in swirl phase - self.x += self.vx - self.y += self.vy - self.vx *= 0.99 - self.vy *= 0.99 - # Gentle spiral - angle = self.phase + t * 2 - self.vx += math.cos(angle) * 0.3 - self.vy += math.sin(angle) * 0.3 - return - - # Spring toward target - dx = self.target_x - self.x - dy = self.target_y - self.y - self.vx += dx * strength - self.vy += dy * strength - self.vx *= damping - self.vy *= damping - self.x += self.vx - self.y += self.vy - - @property - def at_target(self) -> bool: - return abs(self.x - self.target_x) < 1.5 and abs(self.y - self.target_y) < 1.5 - - -def run_particle_logo(console: Console, hold_seconds: float = 1.5) -> None: - """Run the particle coalesce effect.""" - term_width = min(console.width, 120) - term_height = min(console.height - 4, 35) - - canvas = BrailleCanvas(term_width, term_height) - - # Get target positions from text - text_pixels_line1 = text_to_pixels("HUGGING FACE", scale=2) - text_pixels_line2 = text_to_pixels("ML INTERN", scale=2) - - # Calculate dimensions for centering - def get_bounds(pixels): - if not pixels: - return 0, 0, 0, 0 - xs = [p[0] for p in pixels] - ys = [p[1] for p in pixels] - return min(xs), max(xs), min(ys), max(ys) - - min_x1, max_x1, min_y1, max_y1 = get_bounds(text_pixels_line1) - min_x2, max_x2, min_y2, max_y2 = get_bounds(text_pixels_line2) - - w1, h1 = max_x1 - min_x1 + 1, max_y1 - min_y1 + 1 - w2, h2 = max_x2 - min_x2 + 1, max_y2 - min_y2 + 1 - - total_h = h1 + 6 + h2 # gap between lines - start_y = (canvas.pixel_height - total_h) // 2 - - # Center line 1 - offset_x1 = (canvas.pixel_width - w1) // 2 - min_x1 - offset_y1 = start_y - min_y1 - targets_1 = [(p[0] + offset_x1, p[1] + offset_y1) for p in text_pixels_line1] - - # Center line 2 - offset_x2 = (canvas.pixel_width - w2) // 2 - min_x2 - offset_y2 = start_y + h1 + 6 - min_y2 - targets_2 = [(p[0] + offset_x2, p[1] + offset_y2) for p in text_pixels_line2] - - all_targets = targets_1 + targets_2 - - # Subsample for performance β€” take every Nth pixel - step = max(1, len(all_targets) // 1500) - sampled_targets = all_targets[::step] - - # Create particles at random edge positions - rng = random.Random(42) - particles = [] - pw, ph = canvas.pixel_width, canvas.pixel_height - - for i, (tx, ty) in enumerate(sampled_targets): - # Spawn from random edge - side = rng.choice(["top", "bottom", "left", "right"]) - if side == "top": - sx, sy = rng.uniform(0, pw), rng.uniform(-20, -5) - elif side == "bottom": - sx, sy = rng.uniform(0, pw), rng.uniform(ph + 5, ph + 20) - elif side == "left": - sx, sy = rng.uniform(-20, -5), rng.uniform(0, ph) - else: - sx, sy = rng.uniform(pw + 5, pw + 20), rng.uniform(0, ph) - - delay = rng.uniform(0, 0.4) # 
staggered start - p = Particle(sx, sy, tx, ty, delay=delay) - # Initial velocity β€” gentle swirl - angle = math.atan2(ph / 2 - sy, pw / 2 - sx) + rng.gauss(0, 0.8) - speed = rng.uniform(1.0, 2.5) - p.vx = math.cos(angle) * speed - p.vy = math.sin(angle) * speed - particles.append(p) - - # Also add some extra ambient particles that never converge - ambient = [] - for _ in range(200): - ax = rng.uniform(0, pw) - ay = rng.uniform(0, ph) - ap = Particle(ax, ay, ax, ay) - ap.vx = rng.gauss(0, 1) - ap.vy = rng.gauss(0, 1) - ambient.append(ap) - - # Timing: 1s converge + 2s hold = 3s total - fps = 24 - converge_frames = int(fps * 0.9) - hold_frames = int(fps * hold_seconds) - total_frames = converge_frames + hold_frames - - with Live(console=console, refresh_per_second=fps, transient=True) as live: - for frame in range(total_frames): - canvas.clear() - t = frame * 0.03 - - # Update ambient particles (always drifting) - for ap in ambient: - ap.x += ap.vx + math.sin(t + ap.phase) * 0.5 - ap.y += ap.vy + math.cos(t + ap.phase * 1.3) * 0.5 - # Wrap around - ap.x = ap.x % pw - ap.y = ap.y % ph - - # Fade out ambient during hold phase - if frame < converge_frames: - alpha = 0.3 + 0.2 * math.sin(t * 2 + ap.phase) - else: - fade = (frame - converge_frames) / hold_frames - alpha = (0.3 + 0.2 * math.sin(t * 2 + ap.phase)) * (1 - fade) - if alpha > 0.25: - canvas.set_pixel(int(ap.x), int(ap.y)) - - if frame < converge_frames: - # Converge phase - progress = frame / converge_frames - noise = settle_curve(progress) - for p in particles: - p.update_converge(t, strength=0.06, damping=0.90) - canvas.set_pixel(int(p.x), int(p.y)) - - # Trail effect - trail_scale = 0.2 + 0.5 * noise - trail_x = int(p.x - p.vx * trail_scale) - trail_y = int(p.y - p.vy * trail_scale) - canvas.set_pixel(trail_x, trail_y) - - # Color transitions from white to warm gold - r, g, b = warm_gold_from_white(progress) - else: - # Hold phase β€” settle into solid logo - settle_t = (frame - converge_frames) / hold_frames - for p in particles: - # Jitter decays to zero - jitter = (1 - settle_t) * 0.7 - jx = p.target_x + math.sin(t * 3 + p.phase) * jitter - jy = p.target_y + math.cos(t * 3 + p.phase * 1.5) * jitter - canvas.set_pixel(int(jx), int(jy)) - canvas.set_pixel(int(p.target_x), int(p.target_y)) - - r, g, b = 255, 200, 80 - - # Render with color - lines = canvas.render() - result = Text() - for line in lines: - for ch in line: - if ch == chr(0x2800): - result.append(ch) - else: - result.append(ch, style=f"rgb({r},{g},{b})") - result.append("\n") - - live.update(Align.center(result)) - time.sleep(1.0 / fps) - - # Print final settled frame - canvas.clear() - for p in particles: - canvas.set_pixel(int(p.target_x), int(p.target_y)) - final = Text() - for line in canvas.render(): - for ch in line: - if ch == chr(0x2800): - final.append(ch) - else: - final.append(ch, style="rgb(255,200,80)") - final.append("\n") - console.print(Align.center(final)) diff --git a/agent/utils/terminal_display.py b/agent/utils/terminal_display.py index d464fd8a727de131268a420753c811273943bfde..6d4d87b211697f599cb3577d8a0cb9d5988559e3 100644 --- a/agent/utils/terminal_display.py +++ b/agent/utils/terminal_display.py @@ -2,82 +2,19 @@ Terminal display utilities β€” rich-powered CLI formatting. 
""" -import asyncio -import re - from rich.console import Console -from rich.markdown import Heading, Markdown +from rich.markdown import Markdown from rich.panel import Panel from rich.theme import Theme - -class _LeftHeading(Heading): - """Rich's default Markdown renders h1/h2 centered via Align.center. - Yield the styled text directly so headings stay left-aligned.""" - - def __rich_console__(self, console, options): - self.text.justify = "left" - yield self.text - - -Markdown.elements["heading_open"] = _LeftHeading - - -_ANSI_RE = re.compile(r"\x1b\[[0-9;]*[a-zA-Z]") - - -def _clip_to_width(s: str, width: int) -> str: - """Truncate a string to `width` visible columns, preserving ANSI styles. - - Needed for the sub-agent live redraw: cursor-up-and-erase assumes one - logical line == one terminal row. If a line wraps, cursor-up undershoots - and the next redraw corrupts the display. Truncating prevents wrap. - """ - if width <= 0: - return s - out: list[str] = [] - visible = 0 - i = 0 - # Reserve 1 char for the trailing ellipsis - limit = width - 1 - truncated = False - while i < len(s): - m = _ANSI_RE.match(s, i) - if m: - out.append(m.group()) - i = m.end() - continue - if visible >= limit: - truncated = True - break - out.append(s[i]) - visible += 1 - i += 1 - if truncated: - # Strip styles (so ellipsis isn't left hanging inside a style run) - out.append("\033[0m…") - return "".join(out) - - -_THEME = Theme( - { - "tool.name": "bold rgb(255,200,80)", - "tool.args": "dim", - "tool.ok": "dim green", - "tool.fail": "dim red", - "info": "dim", - "muted": "dim", - # Markdown emphasis colors - "markdown.strong": "bold rgb(255,200,80)", - "markdown.emphasis": "italic rgb(180,140,40)", - "markdown.code": "rgb(120,220,255)", - "markdown.code_block": "rgb(120,220,255)", - "markdown.link": "underline rgb(90,180,255)", - "markdown.h1": "bold rgb(255,200,80)", - "markdown.h2": "bold rgb(240,180,95)", - "markdown.h3": "bold rgb(220,165,100)", - } -) +_THEME = Theme({ + "tool.name": "bold cyan", + "tool.args": "dim", + "tool.ok": "dim green", + "tool.fail": "dim red", + "info": "dim", + "muted": "dim", +}) _console = Console(theme=_THEME, highlight=False) @@ -91,86 +28,31 @@ def get_console() -> Console: # ── Banner ───────────────────────────────────────────────────────────── - -def print_banner(model: str | None = None, hf_user: str | None = None) -> None: - """Print particle logo then CRT boot sequence with system info.""" - from agent.utils.particle_logo import run_particle_logo - from agent.utils.crt_boot import run_boot_sequence - - # Particle coalesce logo β€” 1.5s converge, 2s hold - run_particle_logo(_console, hold_seconds=2.0) - - # Clear screen for CRT boot β€” starts from top - _console.file.write("\033[2J\033[H") - _console.file.flush() - - model_label = model or "unknown" - user_label = hf_user or "not logged in" - - # Warm gold palette matching the shimmer highlight (255, 200, 80) - gold = "rgb(255,200,80)" - dim_gold = "rgb(180,140,40)" - - boot_lines = [ - (f"{_I}Initializing agent runtime...", gold), - (f"{_I} User: {user_label}", dim_gold), - (f"{_I} Model: {model_label}", dim_gold), - (f"{_I} Tools: loading...", dim_gold), - ("", ""), - (f"{_I}/help for commands Β· /model to switch Β· /quit to exit", gold), - ] - - run_boot_sequence(_console, boot_lines) +def print_banner() -> None: + Y = "\033[38;2;255;200;50m" # HF yellow + D = "\033[38;2;180;140;40m" # dimmer gold for accents + R = "\033[0m" + art = f""" +{_I}{Y} _ _ _ ___ _ _ {R} +{_I}{Y}| || |_ _ __ _ __ _(_)_ _ __ _ | __|_ 
_ __ ___ /_\\ __ _ ___ _ _| |_ {R} +{_I}{Y}| __ | || / _` / _` | | ' \\/ _` | | _/ _` / _/ -_) / _ \\/ _` / -_) ' \\ _|{R} +{_I}{Y}|_||_|\\_,_\\__, \\__, |_|_||_\\__, | |_|\\__,_\\__\\___| /_/ \\_\\__, \\___|_||_\\__|{R} +{_I}{D} |___/|___/ |___/ |___/ {R} +""" + _console.print(art, highlight=False) + _console.print(f"{_I}[dim]πŸ€— /help for commands Β· /quit to exit[/dim]\n") # ── Init progress ────────────────────────────────────────────────────── - -def print_init_done(tool_count: int = 0) -> None: - import time - - f = _console.file - # Overwrite the "Tools: loading..." line with actual count - f.write( - "\033[A\033[A\033[A\033[K" - ) # Move up 3 lines (blank + help + blank) then up to tools line - f.write("\033[A\033[K") - gold = "\033[38;2;180;140;40m" - reset = "\033[0m" - tool_text = f"{_I} Tools: {tool_count} loaded" - for ch in tool_text: - f.write(f"{gold}{ch}{reset}") - f.flush() - time.sleep(0.012) - f.write("\n\n") - # Reprint the help line - f.write( - f"{_I}\033[38;2;255;200;80m/help for commands Β· /model to switch Β· /quit to exit{reset}\n\n" - ) - # Ready message β€” minimal padding - f.write( - f"{_I}\033[38;2;255;200;80mReady. Let's build something impressive.{reset}\n" - ) - f.flush() +def print_init_done() -> None: + _console.print(f"{_I}[dim]Ready.[/dim]\n") # ── Tool calls ───────────────────────────────────────────────────────── - def print_tool_call(tool_name: str, args_preview: str) -> None: - import time - - f = _console.file - # CRT-style: type out tool name in HF yellow - gold = "\033[38;2;255;200;80m" - reset = "\033[0m" - f.write(f"{_I}{gold}β–Έ ") - for ch in tool_name: - f.write(ch) - f.flush() - time.sleep(0.015) - f.write(f"{reset} \033[2m{args_preview}{reset}\n") - f.flush() + _console.print(f"{_I}[tool.name]β–Έ {tool_name}[/tool.name] [tool.args]{args_preview}[/tool.args]") def print_tool_output(output: str, success: bool, truncate: bool = True) -> None: @@ -182,86 +64,74 @@ def print_tool_output(output: str, success: bool, truncate: bool = True) -> None _console.print(f"[{style}]{indented}[/{style}]") -class SubAgentDisplayManager: - """Manages multiple concurrent sub-agent displays. - - Each agent gets its own stats and rolling tool-call log. - All agents are rendered together so terminal escape-code - erase/redraw stays consistent. 
- """ +class SubAgentDisplay: + """Live-updating display: header with stats (ticks every second) + rolling 2-line tool calls.""" - _MAX_VISIBLE = 4 # tool-call lines shown per agent + _MAX_VISIBLE = 2 def __init__(self): - self._agents: dict[str, dict] = {} # agent_id -> state dict + self._calls: list[str] = [] + self._tool_count = 0 + self._token_count = 0 + self._start_time: float | None = None self._lines_on_screen = 0 + self._ticker_task = None - def start(self, agent_id: str, label: str = "research") -> None: + def start(self) -> None: + """Begin the display with a 1-second ticker.""" + import asyncio import time + self._calls = [] + self._tool_count = 0 + self._token_count = 0 + self._start_time = time.monotonic() + self._redraw() + self._ticker_task = asyncio.ensure_future(self._tick()) + + def set_tokens(self, tokens: int) -> None: + self._token_count = tokens + # no redraw β€” ticker handles it + + def set_tool_count(self, count: int) -> None: + self._tool_count = count + # no redraw β€” ticker handles it - self._agents[agent_id] = { - "label": label, - "calls": [], - "tool_count": 0, - "token_count": 0, - "start_time": time.monotonic(), - } + def add_call(self, tool_desc: str) -> None: + self._calls.append(tool_desc) self._redraw() - def set_tokens(self, agent_id: str, tokens: int) -> None: - if agent_id in self._agents: - self._agents[agent_id]["token_count"] = tokens - - def set_tool_count(self, agent_id: str, count: int) -> None: - if agent_id in self._agents: - self._agents[agent_id]["tool_count"] = count - - def add_call(self, agent_id: str, tool_desc: str) -> None: - if agent_id in self._agents: - self._agents[agent_id]["calls"].append(tool_desc) - self._redraw() - - def clear(self, agent_id: str) -> None: - # On completion: erase the live region, freeze a single-line summary - # for this agent ("βœ“ research: … (stats)") above the live region so - # the user sees each sub-agent finish cleanly without the tool-call - # noise, then redraw remaining live agents. 
- agent = self._agents.pop(agent_id, None) + def clear(self) -> None: + if self._ticker_task: + self._ticker_task.cancel() + self._ticker_task = None self._erase() - if agent is not None: - width = max(10, _console.width) - line = _clip_to_width(self._render_completion_line(agent), width) - _console.file.write(line + "\n") - _console.file.flush() self._lines_on_screen = 0 - if self._agents: - self._redraw() - - @staticmethod - def _render_completion_line(agent: dict) -> str: - stats = SubAgentDisplayManager._format_stats(agent) - label = agent["label"] - # dim green check + dim label; stats in parens - line = f"{_I}\033[38;2;120;200;140mβœ“\033[0m \033[2m{label}\033[0m" - if stats: - line += f" \033[2m({stats})\033[0m" - return line - - @staticmethod - def _format_stats(agent: dict) -> str: + self._calls = [] + self._start_time = None + + async def _tick(self) -> None: + import asyncio + try: + while True: + await asyncio.sleep(1.0) + self._redraw() + except asyncio.CancelledError: + pass + + def _format_stats(self) -> str: import time - - start = agent["start_time"] - if start is None: + if self._start_time is None: return "" - elapsed = time.monotonic() - start + elapsed = time.monotonic() - self._start_time if elapsed < 60: time_str = f"{elapsed:.0f}s" else: time_str = f"{elapsed / 60:.0f}m {elapsed % 60:.0f}s" - tok = agent["token_count"] - tok_str = f"{tok / 1000:.1f}k" if tok >= 1000 else str(tok) - return f"{agent['tool_count']} tool uses Β· {tok_str} tokens Β· {time_str}" + if self._token_count >= 1000: + tok_str = f"{self._token_count / 1000:.1f}k" + else: + tok_str = str(self._token_count) + return f"{self._tool_count} tool uses Β· {tok_str} tokens Β· {time_str}" def _erase(self) -> None: if self._lines_on_screen > 0: @@ -270,134 +140,52 @@ class SubAgentDisplayManager: f.write("\033[A\033[K") f.flush() - def _render_agent_lines(self, agent: dict, compact: bool = False) -> list[str]: - """Render one agent's block. - - compact=True β†’ single line (label + stats + most-recent tool name); - compact=False β†’ header + up to _MAX_VISIBLE rolling tool-call lines. - We use compact mode when multiple agents are live so the total live - region stays small enough to fit on one screen. Otherwise cursor-up - can't reach lines that have scrolled into scrollback, and every - redraw pollutes history with a stale copy. 
- """ - stats = self._format_stats(agent) - label = agent["label"] - header = f"{_I}\033[38;2;255;200;80mβ–Έ {label}\033[0m" + def _redraw(self) -> None: + f = _console.file + self._erase() + lines = [] + # Header: β–Έ research (stats) + stats = self._format_stats() + header = f"{_I}\033[1;36mβ–Έ research\033[0m" if stats: header += f" \033[2m({stats})\033[0m" - if compact: - latest = agent["calls"][-1] if agent["calls"] else "" - if latest: - # Strip long json tails for the inline view - short = latest.split(" ")[0] if " " in latest else latest - header += f" \033[2mΒ·\033[0m \033[2m{short}\033[0m" - return [header] - lines = [header] - visible = agent["calls"][-self._MAX_VISIBLE :] + lines.append(header) + # Last 2 tool calls, gray + visible = self._calls[-self._MAX_VISIBLE:] for desc in visible: lines.append(f"{_I} \033[2m{desc}\033[0m") - return lines - - def _redraw(self) -> None: - f = _console.file - self._erase() - compact = len(self._agents) > 1 - width = max(10, _console.width) - lines: list[str] = [] - for agent in self._agents.values(): - for ln in self._render_agent_lines(agent, compact=compact): - lines.append(_clip_to_width(ln, width)) for line in lines: f.write(line + "\n") f.flush() self._lines_on_screen = len(lines) -_subagent_display = SubAgentDisplayManager() +_subagent_display = SubAgentDisplay() -def print_tool_log(tool: str, log: str, agent_id: str = "", label: str = "") -> None: +def print_tool_log(tool: str, log: str) -> None: """Handle tool log events β€” sub-agent calls get the rolling display.""" if tool == "research": - aid = agent_id or "research" if log == "Starting research sub-agent...": - _subagent_display.start(aid, label or "research") + _subagent_display.start() elif log == "Research complete.": - _subagent_display.clear(aid) + _subagent_display.clear() elif log.startswith("tokens:"): - _subagent_display.set_tokens(aid, int(log[7:])) + _subagent_display.set_tokens(int(log[7:])) elif log.startswith("tools:"): - _subagent_display.set_tool_count(aid, int(log[6:])) + _subagent_display.set_tool_count(int(log[6:])) else: - _subagent_display.add_call(aid, log) + _subagent_display.add_call(log) else: _console.print(f"{_I}[dim]{tool}: {log}[/dim]") # ── Messages ─────────────────────────────────────────────────────────── - -async def print_markdown( - text: str, - cancel_event: "asyncio.Event | None" = None, - instant: bool = False, -) -> None: - import io - import random +def print_markdown(text: str) -> None: from rich.padding import Padding - _console.print() - - # Render markdown to a string buffer so we can type it out - buf = io.StringIO() - # Important: StringIO is not a TTY, so Rich would normally strip styles. - # Force terminal rendering so ANSI style codes are preserved for typewriter output. - buf_console = Console( - file=buf, - width=_console.width, - highlight=False, - theme=_THEME, - force_terminal=True, - color_system=_console.color_system or "truecolor", - ) - buf_console.print(Padding(Markdown(text), (0, 0, 0, 2))) - rendered = buf.getvalue() - - # Strip trailing whitespace from each line so we don't type across the full width - lines = rendered.split("\n") - rendered = "\n".join(line.rstrip() for line in lines) - - f = _console.file - - # Headless / non-interactive: dump the rendered markdown in one write. - if instant: - f.write(rendered) - f.write("\n") - f.flush() - return - - # CRT typewriter effect β€” async so the event loop can service signal - # handlers (Ctrl+C during streaming) between characters. 
If cancelled - # mid-type, stop cleanly: write an ANSI reset so half-open color state - # doesn't bleed onto the "interrupted" line, and return. - rng = random.Random(42) - cancelled = False - for ch in rendered: - if cancel_event is not None and cancel_event.is_set(): - cancelled = True - break - f.write(ch) - f.flush() - if ch == "\n": - await asyncio.sleep(0.002) - elif ch == " ": - await asyncio.sleep(0.002) - elif rng.random() < 0.03: - await asyncio.sleep(0.015) - else: - await asyncio.sleep(0.004) - f.write("\033[0m\n" if cancelled else "\n") - f.flush() + _console.print(Padding(Markdown(text), (0, 0, 0, 2))) def print_error(message: str) -> None: @@ -413,35 +201,23 @@ def print_interrupted() -> None: def print_compacted(old_tokens: int, new_tokens: int) -> None: - _console.print( - f"{_I}[dim]context compacted: {old_tokens:,} β†’ {new_tokens:,} tokens[/dim]" - ) + _console.print(f"{_I}[dim]context compacted: {old_tokens:,} β†’ {new_tokens:,} tokens[/dim]") # ── Approval ─────────────────────────────────────────────────────────── - def print_approval_header(count: int) -> None: label = f"Approval required β€” {count} item{'s' if count != 1 else ''}" _console.print() - _console.print( - f"{_I}", - Panel( - f"[bold yellow]{label}[/bold yellow]", border_style="yellow", expand=False - ), - ) + _console.print(f"{_I}", Panel(f"[bold yellow]{label}[/bold yellow]", border_style="yellow", expand=False)) def print_approval_item(index: int, total: int, tool_name: str, operation: str) -> None: - _console.print( - f"\n{_I}[bold]\\[{index}/{total}][/bold] [tool.name]{tool_name}[/tool.name] {operation}" - ) + _console.print(f"\n{_I}[bold]\\[{index}/{total}][/bold] [tool.name]{tool_name}[/tool.name] {operation}") def print_yolo_approve(count: int) -> None: - _console.print( - f"{_I}[bold yellow]yolo β†’[/bold yellow] auto-approved {count} item(s)" - ) + _console.print(f"{_I}[bold yellow]yolo β†’[/bold yellow] auto-approved {count} item(s)") # ── Help ─────────────────────────────────────────────────────────────── @@ -452,10 +228,8 @@ HELP_TEXT = f"""\ {_I} [cyan]/undo[/cyan] Undo last turn {_I} [cyan]/compact[/cyan] Compact context window {_I} [cyan]/model[/cyan] [id] Show available models or switch -{_I} [cyan]/effort[/cyan] [level] Reasoning effort (minimal|low|medium|high|xhigh|max|off) {_I} [cyan]/yolo[/cyan] Toggle auto-approve mode {_I} [cyan]/status[/cyan] Current model & turn count -{_I} [cyan]/share-traces[/cyan] [public|private] Show/flip visibility of your HF trace dataset {_I} [cyan]/quit[/cyan] Exit""" @@ -467,7 +241,6 @@ def print_help() -> None: # ── Plan display ─────────────────────────────────────────────────────── - def format_plan_display() -> str: """Format the current plan for display.""" from agent.tools.plan_tool import get_current_plan @@ -501,7 +274,6 @@ def print_plan() -> None: # ── Formatting for plan_tool output (used by plan_tool handler) ──────── - def format_plan_tool_output(todos: list) -> str: if not todos: return "Plan is empty." 
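A note on the buffered-render trick the deleted `print_markdown` relied on, since it is easy to lose in the diff: Rich treats a `StringIO` as a non-TTY and strips styling, so the old path forced terminal rendering into the buffer before typing it back out. A minimal, self-contained sketch of that technique (illustrative only, not code from this repository):

```python
import io

from rich.console import Console
from rich.markdown import Markdown

# StringIO is not a TTY, so a default Console would emit unstyled text here.
# force_terminal=True preserves the ANSI escape codes in the captured string,
# which is what makes a styled character-by-character replay possible.
buf = io.StringIO()
capture = Console(
    file=buf,
    width=80,
    force_terminal=True,
    color_system="truecolor",
    highlight=False,
)
capture.print(Markdown("# Hello\nSome *styled* markdown."))

rendered = buf.getvalue()  # plain str, ANSI styles intact, ready to replay
```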
@@ -524,7 +296,6 @@ def format_plan_tool_output(todos: list) -> str: # ── Internal helpers ─────────────────────────────────────────────────── - def _truncate(text: str, max_lines: int = 6) -> str: lines = text.split("\n") if len(lines) <= max_lines: diff --git a/backend/dependencies.py b/backend/dependencies.py index 58e02b7ee9a108e586496516bbd076e85b735fa2..ce32eb4ea02175bbcea0ff8d3cd73e77def3a210 100644 --- a/backend/dependencies.py +++ b/backend/dependencies.py @@ -7,22 +7,15 @@ import logging import os import time -from collections.abc import Iterable -from hashlib import sha256 from typing import Any import httpx from fastapi import HTTPException, Request, status -from agent.core.hf_tokens import bearer_token_from_header, clean_hf_token - -from agent.core.hf_access import fetch_whoami_v2 - logger = logging.getLogger(__name__) OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co") AUTH_ENABLED = bool(os.environ.get("OAUTH_CLIENT_ID", "")) -HF_EMPLOYEE_ORG = os.environ.get("HF_EMPLOYEE_ORG", "huggingface") # Simple in-memory token cache: token -> (user_info, expiry_time) _token_cache: dict[str, tuple[dict[str, Any], float]] = {} @@ -35,63 +28,8 @@ DEV_USER: dict[str, Any] = { "user_id": "dev", "username": "dev", "authenticated": True, - "plan": "pro", # Dev runs at the Pro quota tier so local testing isn't capped. } -INTERNAL_HF_TOKEN_KEY = "_hf_token" -OAUTH_SCOPE_COOKIE = "hf_oauth_scope_hash" -REQUIRED_OAUTH_SCOPES: tuple[str, ...] = ( - "openid", - "profile", - "read-repos", - "write-repos", - "contribute-repos", - "manage-repos", - "write-collections", - "inference-api", - "jobs", - "write-discussions", -) - -# Log the whoami-v2 shape once at DEBUG so we can confirm the production Pro -# signal without hammering the HF API. -_WHOAMI_SHAPE_LOGGED = False - - -def normalize_oauth_scopes(scopes: Iterable[str]) -> tuple[str, ...]: - """Return stable, de-duplicated OAuth scopes preserving declaration order.""" - seen: set[str] = set() - normalized: list[str] = [] - for scope in scopes: - value = str(scope).strip() - if not value or value in seen: - continue - seen.add(value) - normalized.append(value) - return tuple(normalized) - - -def configured_oauth_scopes() -> tuple[str, ...]: - """Return the scopes this backend should request from HF OAuth. - - Spaces expose README ``hf_oauth_scopes`` through ``OAUTH_SCOPES``. Unioning - that value with the app-required scopes keeps the local request and Space - metadata in sync while ensuring new required scopes are never omitted. - """ - env_scopes = os.environ.get("OAUTH_SCOPES", "").split() - return normalize_oauth_scopes((*env_scopes, *REQUIRED_OAUTH_SCOPES)) - - -def oauth_scope_fingerprint(scopes: Iterable[str] | None = None) -> str: - """Return a non-secret fingerprint for the current OAuth scope contract.""" - scope_list = configured_oauth_scopes() if scopes is None else scopes - payload = " ".join(sorted(normalize_oauth_scopes(scope_list))) - return sha256(payload.encode("utf-8")).hexdigest()[:16] - - -def _cookie_has_current_oauth_scope_marker(request: Request) -> bool: - return request.cookies.get(OAUTH_SCOPE_COOKIE) == oauth_scope_fingerprint() - async def _validate_token(token: str) -> dict[str, Any] | None: """Validate a token against HF OAuth userinfo endpoint. 
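The hunk above removes the scope-fingerprint machinery; for clarity, the removed mechanism reduces to the following (a sketch with an illustrative scope list; the real code hashed `configured_oauth_scopes()` and stored the digest in the `hf_oauth_scope_hash` cookie):

```python
from hashlib import sha256

REQUIRED_SCOPES = ("openid", "profile", "read-repos")  # illustrative subset

def scope_fingerprint(scopes: tuple[str, ...]) -> str:
    # Sorted join makes the digest independent of declaration order; a 16-char
    # hex prefix is enough to detect a scope-contract change and is not secret.
    payload = " ".join(sorted(scopes))
    return sha256(payload.encode("utf-8")).hexdigest()[:16]

# Minted alongside the OAuth cookie at login:
marker = scope_fingerprint(REQUIRED_SCOPES)

# Checked on every cookie-authenticated request: if the app's required scopes
# changed since login, the stored marker no longer matches and the backend
# answers 401 so the browser re-runs the OAuth flow with the new scopes.
assert scope_fingerprint(REQUIRED_SCOPES) == marker
```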
@@ -136,84 +74,12 @@ def _user_from_info(user_info: dict[str, Any]) -> dict[str, Any]: } -def _normalize_user_plan(whoami: Any) -> str: - """Normalize a whoami-v2 payload to the app's personal quota tiers.""" - if not isinstance(whoami, dict): - return "free" - - if whoami.get("isPro") is True: - return "pro" - - return "free" - - -async def _fetch_user_plan(token: str) -> str: - """Look up the user's HF plan via /api/whoami-v2. - - Returns 'free' | 'pro'. Non-200, network errors, or an unknown - payload shape all collapse to 'free' β€” safe default; we'd rather under- - grant the Pro cap than over-grant it on bad data. - """ - global _WHOAMI_SHAPE_LOGGED - whoami = await fetch_whoami_v2(token) - if whoami is None: - return "free" - - if not _WHOAMI_SHAPE_LOGGED: - _WHOAMI_SHAPE_LOGGED = True - logger.debug( - "whoami-v2 payload keys: %s (sample values: isPro=%r)", - sorted(whoami.keys()) - if isinstance(whoami, dict) - else type(whoami).__name__, - whoami.get("isPro") if isinstance(whoami, dict) else None, - ) - - return _normalize_user_plan(whoami) - - async def _extract_user_from_token(token: str) -> dict[str, Any] | None: """Validate a token and return a user dict, or None.""" user_info = await _validate_token(token) - if user_info is None: - return None - user = _user_from_info(user_info) - user["plan"] = await _fetch_user_plan(token) - user[INTERNAL_HF_TOKEN_KEY] = clean_hf_token(token) - return user - - -async def _dev_user_from_env() -> dict[str, Any]: - """Use HF_TOKEN as the dev identity when available. - - Local dev often runs without OAuth, but session trace uploads still need a - real HF namespace. Deriving the dev user from HF_TOKEN keeps local uploads - pointed at the token owner's dataset instead of dev/ml-intern-sessions. - """ - token = clean_hf_token(os.environ.get("HF_TOKEN", "")) - if not token: - return dict(DEV_USER) - - whoami = await fetch_whoami_v2(token) - if not isinstance(whoami, dict): - return dict(DEV_USER) - - username = None - for key in ("name", "user", "preferred_username"): - value = whoami.get(key) - if isinstance(value, str) and value: - username = value - break - if not username: - return dict(DEV_USER) - - return { - "user_id": username, - "username": username, - "authenticated": True, - "plan": await _fetch_user_plan(token), - INTERNAL_HF_TOKEN_KEY: token, - } + if user_info: + return _user_from_info(user_info) + return None async def check_org_membership(token: str, org_name: str) -> bool: @@ -248,15 +114,15 @@ async def get_current_user(request: Request) -> dict[str, Any]: 1. Authorization: Bearer header 2. hf_access_token cookie - In dev mode (AUTH_ENABLED=False), uses HF_TOKEN as the user when possible. + In dev mode (AUTH_ENABLED=False), returns a default dev user. """ if not AUTH_ENABLED: - return await _dev_user_from_env() + return DEV_USER - # Bearer callers manage token lifecycle themselves; only browser cookie - # auth is forced through the scope-freshness marker below. 
- token = bearer_token_from_header(request.headers.get("Authorization", "")) - if token: + # Try Authorization header + auth_header = request.headers.get("Authorization", "") + if auth_header.startswith("Bearer "): + token = auth_header[7:] user = await _extract_user_from_token(token) if user: return user @@ -264,15 +130,6 @@ async def get_current_user(request: Request) -> dict[str, Any]: # Try cookie token = request.cookies.get("hf_access_token") if token: - if not _cookie_has_current_oauth_scope_marker(request): - logger.info( - "Rejecting stale HF OAuth cookie; current scopes require refresh." - ) - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Authentication scopes changed. Please log in again.", - headers={"WWW-Authenticate": "Bearer"}, - ) user = await _extract_user_from_token(token) if user: return user @@ -284,27 +141,3 @@ async def get_current_user(request: Request) -> dict[str, Any]: ) -def _extract_token(request: Request) -> str | None: - """Pull the HF access token from the Authorization header or cookie. - - Mirrors the lookup order used by ``get_current_user``. - """ - token = bearer_token_from_header(request.headers.get("Authorization", "")) - if token: - return token - return request.cookies.get("hf_access_token") - - -async def require_huggingface_org_member(request: Request) -> bool: - """Return True if the caller is a member of the ``huggingface`` org. - - Used to gate endpoints that can push a session onto an Anthropic model - billed to the Space's ``ANTHROPIC_API_KEY``. Returns True unconditionally - in dev mode so local testing isn't blocked. - """ - if not AUTH_ENABLED: - return True - token = _extract_token(request) - if not token: - return False - return await check_org_membership(token, HF_EMPLOYEE_ORG) diff --git a/backend/kpis_scheduler.py b/backend/kpis_scheduler.py deleted file mode 100644 index 9b2199c69151118762ed2cfaddde579fb5a694d3..0000000000000000000000000000000000000000 --- a/backend/kpis_scheduler.py +++ /dev/null @@ -1,148 +0,0 @@ -"""In-process hourly KPI rollup, owned by the backend Space lifespan. - -Replaces an external GitHub Actions cron so the rollup lives next to the data -and reuses the Space's existing HF token β€” no production secrets on the -public source repo. See ``scripts/build_kpis.py`` for the data-flow diagram -and metric definitions. - -Behaviour:: - - lifespan startup β†’ start APScheduler with cron("5 * * * *", UTC) - β†’ fire a best-effort 6-hour backfill (fire-and-forget) - each :05 β†’ run ``build_kpis.run_for_hour`` for the just-completed hour - lifespan shutdown β†’ scheduler.shutdown(wait=False) - -Environment:: - - HF_KPI_WRITE_TOKEN | HF_SESSION_UPLOAD_TOKEN | HF_TOKEN | HF_ADMIN_TOKEN - First one found is used. Least-privilege first. - KPI_SOURCE_REPO default smolagents/ml-intern-sessions - KPI_TARGET_REPO default smolagents/ml-intern-kpis - ML_INTERN_KPIS_DISABLED if truthy, the scheduler is not started -""" - -from __future__ import annotations - -import asyncio -import importlib.util -import logging -import os -from datetime import datetime, timedelta, timezone -from pathlib import Path -from typing import Optional - -logger = logging.getLogger(__name__) - -_PROJECT_ROOT = Path(__file__).resolve().parent.parent - -# Hold strong refs to backfill tasks so asyncio doesn't GC them mid-run. -_background_tasks: set[asyncio.Task] = set() - -_scheduler = None # AsyncIOScheduler instance (lazy import) - - -def _resolve_token() -> Optional[str]: - """Pick the first available HF token. 
Least-privilege first.""" - for var in ( - "HF_KPI_WRITE_TOKEN", - "HF_SESSION_UPLOAD_TOKEN", - "HF_TOKEN", - "HF_ADMIN_TOKEN", - ): - val = os.environ.get(var) - if val: - return val - return None - - -def _load_build_kpis(): - """Import ``scripts/build_kpis.py`` without putting ``scripts/`` on sys.path.""" - spec = importlib.util.spec_from_file_location( - "build_kpis", - _PROJECT_ROOT / "scripts" / "build_kpis.py", - ) - mod = importlib.util.module_from_spec(spec) - assert spec.loader is not None - spec.loader.exec_module(mod) - return mod - - -async def _run_hour(hour_dt: datetime) -> None: - """Run one hourly rollup off the event loop. Best-effort, never raises.""" - token = _resolve_token() - if not token: - logger.warning("kpis_scheduler: no HF token available, skipping %s", hour_dt) - return - try: - mod = _load_build_kpis() - from huggingface_hub import HfApi - - api = HfApi() - source = os.environ.get("KPI_SOURCE_REPO", "smolagents/ml-intern-sessions") - target = os.environ.get("KPI_TARGET_REPO", "smolagents/ml-intern-kpis") - await asyncio.to_thread(mod.run_for_hour, api, source, target, hour_dt, token) - except Exception as e: - logger.warning("kpis_scheduler: rollup for %s failed: %s", hour_dt, e) - - -async def run_last_completed_hour() -> None: - """The scheduled-at-:05 job. Rolls up the previous whole hour.""" - now = datetime.now(timezone.utc).replace(minute=0, second=0, microsecond=0) - await _run_hour(now - timedelta(hours=1)) - - -async def backfill(hours: int = 6) -> None: - """Catch-up pass for hours the Space was down. Idempotent (overwrites).""" - now = datetime.now(timezone.utc).replace(minute=0, second=0, microsecond=0) - for i in range(1, hours + 1): - await _run_hour(now - timedelta(hours=i)) - - -def start(backfill_hours: int = 6) -> None: - """Called from FastAPI lifespan startup.""" - global _scheduler - if os.environ.get("ML_INTERN_KPIS_DISABLED"): - logger.info("kpis_scheduler: disabled via ML_INTERN_KPIS_DISABLED") - return - if _scheduler is not None: - return - - try: - from apscheduler.schedulers.asyncio import AsyncIOScheduler - from apscheduler.triggers.cron import CronTrigger - except ImportError: - logger.warning("kpis_scheduler: apscheduler not installed, skipping") - return - - _scheduler = AsyncIOScheduler(timezone="UTC") - _scheduler.add_job( - run_last_completed_hour, - CronTrigger(minute=5), - id="kpis_hourly", - misfire_grace_time=600, # tolerate a 10-min misfire window - coalesce=True, # collapse multiple missed fires into one - max_instances=1, - replace_existing=True, - ) - _scheduler.start() - logger.info("kpis_scheduler: started (cron '5 * * * *' UTC)") - - # Non-blocking backfill. Hold a strong ref until done so asyncio doesn't - # GC the task before it finishes. - try: - task = asyncio.get_running_loop().create_task(backfill(backfill_hours)) - _background_tasks.add(task) - task.add_done_callback(_background_tasks.discard) - except RuntimeError: - # Not in an event loop (tests); skip backfill. 
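# --- editor's note ------------------------------------------------------
# run_last_completed_hour() and backfill() above both hinge on the same
# floor-to-the-hour arithmetic: truncate "now" to the top of the hour, then
# step back whole hours. A standalone sketch of that computation with a
# concrete check (the timestamp is illustrative):
from datetime import datetime, timedelta, timezone


def last_completed_hour(now: datetime) -> datetime:
    """Return the start of the most recent fully completed UTC hour."""
    floor = now.replace(minute=0, second=0, microsecond=0)
    return floor - timedelta(hours=1)


now = datetime(2024, 5, 1, 13, 7, 42, tzinfo=timezone.utc)
assert last_completed_hour(now) == datetime(2024, 5, 1, 12, 0, tzinfo=timezone.utc)
# Firing at :05 (per the CronTrigger above) gives the previous hour's data
# five minutes to settle before it is rolled up.
# ------------------------------------------------------------------------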
- pass - - -async def shutdown() -> None: - """Called from FastAPI lifespan shutdown.""" - global _scheduler - if _scheduler is None: - return - _scheduler.shutdown(wait=False) - _scheduler = None - logger.info("kpis_scheduler: stopped") diff --git a/backend/main.py b/backend/main.py index 3a8983055db6e1dbd8c263d6bba6022edc3a1a01..888740e539ae72ad12099f7ec6a2190cefbc4ac5 100644 --- a/backend/main.py +++ b/backend/main.py @@ -9,15 +9,12 @@ from dotenv import load_dotenv from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles +from routes.agent import router as agent_router +from routes.auth import router as auth_router -# Load .env before importing routes/session_manager so persistence and quota -# modules see local Mongo settings during startup. +# Load .env from project root (parent directory) load_dotenv(Path(__file__).parent.parent / ".env") -from routes.agent import router as agent_router # noqa: E402 -from routes.auth import router as auth_router # noqa: E402 -from session_manager import session_manager # noqa: E402 - # Configure logging logging.basicConfig( level=logging.INFO, @@ -30,54 +27,15 @@ logger = logging.getLogger(__name__) async def lifespan(app: FastAPI): """Application lifespan handler.""" logger.info("Starting HF Agent backend...") - await session_manager.start() - # Start in-process hourly KPI rollup. Replaces an external cron so the - # rollup lives next to the data and reuses the Space's HF token. - try: - import kpis_scheduler - - kpis_scheduler.start() - except Exception as e: - logger.warning("KPI scheduler failed to start: %s", e) yield - logger.info("Shutting down HF Agent backend...") - try: - import kpis_scheduler - - await kpis_scheduler.shutdown() - except Exception as e: - logger.warning("KPI scheduler shutdown failed: %s", e) - - # Final-flush: save every still-active session so we don't lose traces on - # server restart. Uploads are detached subprocesses β€” this is fast. - try: - for sid, agent_session in list(session_manager.sessions.items()): - sess = agent_session.session - if sess.config.save_sessions: - try: - sess.save_and_upload_detached(sess.config.session_dataset_repo) - logger.info("Flushed session %s on shutdown", sid) - except Exception as e: - logger.warning("Failed to flush session %s: %s", sid, e) - except Exception as e: - logger.warning("Lifespan final-flush skipped: %s", e) - await session_manager.close() - - -# Disable FastAPI auto-docs when running on HF Spaces (SPACE_ID is set by the -# platform) to avoid exposing the full API surface to anonymous visitors. Local -# dev keeps /docs and /redoc available. 
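# --- editor's note ------------------------------------------------------
# For readers unfamiliar with the lifespan hook edited in backend/main.py
# above: everything before `yield` runs at startup, everything after runs
# at shutdown, which is why the scheduler start, the final-flush loop, and
# close() all bracket the yield in the removed code. A minimal
# self-contained sketch of the same shape (resource names are illustrative):
from contextlib import asynccontextmanager

from fastapi import FastAPI


@asynccontextmanager
async def lifespan(app: FastAPI):
    # startup: acquire long-lived resources (session store, scheduler, ...)
    print("starting background resources")
    yield
    # shutdown: flush durable state first, then release resources
    print("flushing sessions, stopping scheduler")


app = FastAPI(lifespan=lifespan)
# ------------------------------------------------------------------------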
-_DOCS_DISABLED = os.environ.get("SPACE_ID") is not None + app = FastAPI( title="HF Agent", description="ML Engineering Assistant API", version="1.0.0", lifespan=lifespan, - docs_url=None if _DOCS_DISABLED else "/docs", - redoc_url=None if _DOCS_DISABLED else "/redoc", - openapi_url=None if _DOCS_DISABLED else "/openapi.json", ) # CORS middleware for development diff --git a/backend/models.py b/backend/models.py index 1126c2b92b93e1f7f429abfd02e97be4b145ae1f..3b76bf54e61dcc3a158c3f5f3dfa621d711596c8 100644 --- a/backend/models.py +++ b/backend/models.py @@ -3,7 +3,7 @@ from enum import Enum from typing import Any -from pydantic import BaseModel, Field +from pydantic import BaseModel class OpType(str, Enum): @@ -38,7 +38,6 @@ class ToolApproval(BaseModel): approved: bool feedback: str | None = None edited_script: str | None = None - namespace: str | None = None class ApprovalRequest(BaseModel): @@ -52,10 +51,7 @@ class SubmitRequest(BaseModel): """Request to submit user input.""" session_id: str - # Cap text size to prevent context-bloat / cost-amplification: a malicious - # or runaway client could otherwise attach megabytes that then ride along - # in every subsequent turn until /api/compact is called. - text: str = Field(..., min_length=1, max_length=100_000) + text: str class TruncateRequest(BaseModel): @@ -69,7 +65,6 @@ class SessionResponse(BaseModel): session_id: str ready: bool = True - model: str | None = None class PendingApprovalTool(BaseModel): @@ -80,15 +75,6 @@ class PendingApprovalTool(BaseModel): arguments: dict[str, Any] = {} -class SessionAutoApprovalInfo(BaseModel): - """Per-session auto-approval budget state.""" - - enabled: bool = False - cost_cap_usd: float | None = None - estimated_spend_usd: float = 0.0 - remaining_usd: float | None = None - - class SessionInfo(BaseModel): """Session metadata.""" @@ -99,25 +85,6 @@ class SessionInfo(BaseModel): message_count: int user_id: str = "dev" pending_approval: list[PendingApprovalTool] | None = None - model: str | None = None - title: str | None = None - notification_destinations: list[str] = Field(default_factory=list) - auto_approval: SessionAutoApprovalInfo = Field( - default_factory=SessionAutoApprovalInfo - ) - - -class SessionNotificationsRequest(BaseModel): - """Replace the session's auto-notification destinations.""" - - destinations: list[str] - - -class SessionYoloRequest(BaseModel): - """Update a session's auto-approval policy.""" - - enabled: bool - cost_cap_usd: float | None = Field(default=None, ge=0) class HealthResponse(BaseModel): @@ -134,6 +101,4 @@ class LLMHealthResponse(BaseModel): status: str # "ok" | "error" model: str error: str | None = None - error_type: str | None = ( - None # "auth" | "credits" | "rate_limit" | "network" | "unknown" - ) + error_type: str | None = None # "auth" | "credits" | "rate_limit" | "network" | "unknown" diff --git a/backend/routes/agent.py b/backend/routes/agent.py index 0a742b7c793a1949c2f8dcd72f7213bc4403eaf2..9a8c2bc99e0120553966baeefa4aab83fa78c5a9 100644 --- a/backend/routes/agent.py +++ b/backend/routes/agent.py @@ -7,229 +7,68 @@ dependency. In dev mode (no OAUTH_CLIENT_ID), auth is bypassed automatically. 
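# --- editor's note ------------------------------------------------------
# The models.py hunk above drops the Field(min_length/max_length) bound on
# SubmitRequest.text. The bound matters because attached text rides along
# in every subsequent turn until /api/compact is called, so an unbounded
# field is a cost-amplification vector. A minimal sketch of the pattern
# being removed, with a quick check (the 100_000 cap mirrors the deleted
# code):
from pydantic import BaseModel, Field, ValidationError


class BoundedSubmit(BaseModel):
    session_id: str
    text: str = Field(..., min_length=1, max_length=100_000)


BoundedSubmit(session_id="s1", text="hello")  # within bounds: accepted
try:
    BoundedSubmit(session_id="s1", text="")  # empty text is rejected
except ValidationError as e:
    assert any(err["loc"] == ("text",) for err in e.errors())
# ------------------------------------------------------------------------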
import asyncio import json import logging +import os from typing import Any -from dependencies import ( - INTERNAL_HF_TOKEN_KEY, - get_current_user, -) +from dependencies import get_current_user from fastapi import ( APIRouter, Depends, HTTPException, Request, ) -from fastapi.exceptions import RequestValidationError from fastapi.responses import StreamingResponse from litellm import acompletion -from pydantic import ValidationError from models import ( ApprovalRequest, HealthResponse, LLMHealthResponse, SessionInfo, - SessionNotificationsRequest, SessionResponse, - SessionYoloRequest, SubmitRequest, TruncateRequest, ) -from session_manager import ( - MAX_SESSIONS, - AgentSession, - SessionCapacityError, - session_manager, -) +from session_manager import MAX_SESSIONS, SessionCapacityError, session_manager -import user_quotas - -from agent.core.hf_access import get_jobs_access -from agent.core.hf_tokens import resolve_hf_request_token, resolve_hf_router_token -from agent.core.llm_params import _resolve_llm_params +from agent.core.agent_loop import _resolve_hf_router_params logger = logging.getLogger(__name__) router = APIRouter(prefix="/api", tags=["agent"]) -_background_teardown_tasks: set[asyncio.Task] = set() - -DEFAULT_CLAUDE_MODEL_ID = "bedrock/us.anthropic.claude-opus-4-6-v1" -DEFAULT_FREE_MODEL_ID = "moonshotai/Kimi-K2.6" -PREMIUM_MODEL_IDS = { - DEFAULT_CLAUDE_MODEL_ID, - "openai/gpt-5.5", -} - - -def _claude_picker_model_id() -> str: - """Return the model ID used by the Claude option in the UI. - - The frontend config sets ``session_manager.config.model_name`` from - ``ML_INTERN_CLAUDE_MODEL_ID`` when that env var is present, otherwise it - falls back to the production Bedrock Claude model. This function only - exposes that resolved config value for the Claude picker; non-Claude models - are listed separately in the model switcher. - """ - return session_manager.config.model_name - - -def _available_models() -> list[dict[str, Any]]: - models = [ - { - "id": "moonshotai/Kimi-K2.6", - "label": "Kimi K2.6", - "provider": "huggingface", - "tier": "free", - "recommended": True, - }, - { - "id": _claude_picker_model_id(), - "label": "Claude Opus 4.6", - "provider": "anthropic", - "tier": "pro", - "recommended": True, - }, - { - "id": "openai/gpt-5.5", - "label": "GPT-5.5", - "provider": "openai", - "tier": "pro", - }, - { - "id": "MiniMaxAI/MiniMax-M2.7", - "label": "MiniMax M2.7", - "provider": "huggingface", - "tier": "free", - }, - { - "id": "zai-org/GLM-5.1", - "label": "GLM 5.1", - "provider": "huggingface", - "tier": "free", - }, - { - "id": "deepseek-ai/DeepSeek-V4-Pro:deepinfra", - "label": "DeepSeek V4 Pro", - "provider": "huggingface", - "tier": "free", - }, - ] - return models - - -AVAILABLE_MODELS = _available_models() - -def _is_premium_model(model_id: str) -> bool: - return model_id in PREMIUM_MODEL_IDS - - -async def _model_override_for_new_session( - request: Request, - requested_model: str | None, -) -> str | None: - """Return the model override to use when creating a new session. - - Explicit premium model requests are allowed and charged at message-submit - time. Implicit default sessions are more forgiving: when the configured - default is premium, start them on the first free model instead of spending - premium quota accidentally. 
- """ - resolved_model = requested_model or session_manager.config.model_name - if not _is_premium_model(resolved_model): - return requested_model - if requested_model: - return requested_model - - logger.info( - "Default premium model %s would spend quota; " - "creating session with free fallback %s", - resolved_model, - DEFAULT_FREE_MODEL_ID, - ) - return DEFAULT_FREE_MODEL_ID - - -async def _enforce_premium_model_quota( - user: dict[str, Any], - agent_session: AgentSession, -) -> None: - """Charge the user's daily premium-model quota on first use in a session. - - Runs at *message-submit* time, not session-create time β€” so spinning up a - premium-model session to look around doesn't burn quota. The - ``claude_counted`` flag on ``AgentSession`` guards against re-counting the - same session; the stored field name is kept for persistence compatibility. - - No-ops when the session's current model isn't premium, or when this - session has already been charged. Raises 429 when the user has hit - their daily cap. - """ - if agent_session.claude_counted: - return - model_name = agent_session.session.config.model_name - if not _is_premium_model(model_name): - return - user_id = user["user_id"] - plan = user.get("plan", "free") - cap = user_quotas.daily_cap_for(plan) - new_count = await user_quotas.try_increment_claude(user_id, cap) - if new_count is None: - if plan == "pro": - message = ( - "Daily premium model limit reached. Use a free model and try " - "premium models again tomorrow." - ) - else: - message = ( - "Daily premium model limit reached. Upgrade to HF Pro for " - f"{user_quotas.CLAUDE_PRO_DAILY}/day or use a free model." - ) - raise HTTPException( - status_code=429, - detail={ - "error": "premium_model_daily_cap", - "plan": plan, - "cap": cap, - "message": message, - }, - ) - agent_session.claude_counted = True - await session_manager.persist_session_snapshot(agent_session) - - -def _user_hf_token(user: dict[str, Any] | None) -> str | None: - if not isinstance(user, dict): - return None - return user.get(INTERNAL_HF_TOKEN_KEY) - - -async def _check_session_access( - session_id: str, - user: dict[str, Any], - request: Request | None = None, - preload_sandbox: bool = True, -) -> AgentSession: - """Verify and lazily load the user's session. Raises 403 or 404.""" - hf_token = ( - resolve_hf_request_token(request) - if request is not None - else _user_hf_token(user) - ) - agent_session = await session_manager.ensure_session_loaded( - session_id, - user["user_id"], - hf_token=hf_token, - hf_username=user.get("username"), - preload_sandbox=preload_sandbox, - ) - if not agent_session: +AVAILABLE_MODELS = [ + { + "id": "anthropic/claude-opus-4-6", + "label": "Claude Opus 4.6", + "provider": "anthropic", + "recommended": True, + }, + { + "id": "huggingface/fireworks-ai/MiniMaxAI/MiniMax-M2.5", + "label": "MiniMax M2.5", + "provider": "huggingface", + "recommended": True, + }, + { + "id": "huggingface/novita/moonshotai/kimi-k2.5", + "label": "Kimi K2.5", + "provider": "huggingface", + }, + { + "id": "huggingface/novita/zai-org/glm-5", + "label": "GLM 5", + "provider": "huggingface", + }, +] + + +def _check_session_access(session_id: str, user: dict[str, Any]) -> None: + """Verify the user has access to the given session. 
Raises 403 or 404.""" + info = session_manager.get_session_info(session_id) + if not info: raise HTTPException(status_code=404, detail="Session not found") - if user["user_id"] != "dev" and agent_session.user_id not in { - user["user_id"], - "dev", - }: + if not session_manager.verify_session_access(session_id, user["user_id"]): raise HTTPException(status_code=403, detail="Access denied to this session") - return agent_session @router.get("/health", response_model=HealthResponse) @@ -254,7 +93,7 @@ async def llm_health_check() -> LLMHealthResponse: """ model = session_manager.config.model_name try: - llm_params = _resolve_llm_params(model, reasoning_effort="high") + llm_params = _resolve_hf_router_params(model) await acompletion( messages=[{"role": "user", "content": "hi"}], max_tokens=1, @@ -304,71 +143,55 @@ async def get_model() -> dict: } -_TITLE_STRIP_CHARS = str.maketrans("", "", "`*_~#[]()") +@router.post("/config/model") +async def set_model(body: dict, user: dict = Depends(get_current_user)) -> dict: + """Set the LLM model. Applies to new conversations.""" + model_id = body.get("model") + if not model_id: + raise HTTPException(status_code=400, detail="Missing 'model' field") + valid_ids = {m["id"] for m in AVAILABLE_MODELS} + if model_id not in valid_ids: + raise HTTPException(status_code=400, detail=f"Unknown model: {model_id}") + session_manager.config.model_name = model_id + logger.info(f"Model changed to {model_id} by {user.get('username', 'unknown')}") + return {"model": model_id} @router.post("/title") async def generate_title( request: SubmitRequest, user: dict = Depends(get_current_user) ) -> dict: - """Generate a short title for a chat session based on the first user message. - - Always uses gpt-oss-120b via Cerebras on the HF router. The tab headline - renders as plain text, so the model is told to avoid markdown and any - stray formatting characters are stripped before returning. gpt-oss is a - reasoning model β€” reasoning_effort=low keeps the reasoning budget small - so the 60-token output budget isn't consumed before the title is written. - """ - api_key = resolve_hf_router_token(_user_hf_token(user)) + """Generate a short title for a chat session based on the first user message.""" + model = session_manager.config.model_name + llm_params = _resolve_hf_router_params(model) try: response = await acompletion( - # Double openai/ prefix: LiteLLM strips the first as its provider - # prefix, leaving the HF model id on the wire for the router. - model="openai/openai/gpt-oss-120b:cerebras", - api_base="https://router.huggingface.co/v1", - api_key=api_key, messages=[ { "role": "system", "content": ( "Generate a very short title (max 6 words) for a chat conversation " "that starts with the following user message. " - "Reply with ONLY the title in plain text. " - "Do NOT use markdown, backticks, asterisks, quotes, brackets, or any " - "formatting characters. No punctuation at the end." + "Reply with ONLY the title, no quotes, no punctuation at the end." 
), }, {"role": "user", "content": request.text[:500]}, ], - max_tokens=60, + max_tokens=20, temperature=0.3, - timeout=10, - reasoning_effort="low", + timeout=8, + **llm_params, ) title = response.choices[0].message.content.strip().strip('"').strip("'") - title = title.translate(_TITLE_STRIP_CHARS).strip() + # Safety: cap at 50 chars if len(title) > 50: title = title[:50].rstrip() + "…" - try: - await _check_session_access(request.session_id, user) - await session_manager.update_session_title(request.session_id, title) - except Exception: - logger.debug( - "Skipping title persistence for missing session %s", request.session_id - ) return {"title": title} except Exception as e: logger.warning(f"Title generation failed: {e}") + # Fallback: truncate the message fallback = request.text.strip() title = fallback[:40].rstrip() + "…" if len(fallback) > 40 else fallback - try: - await _check_session_access(request.session_id, user) - await session_manager.update_session_title(request.session_id, title) - except Exception: - logger.debug( - "Skipping fallback title persistence for missing session %s", - request.session_id, - ) return {"title": title} @@ -382,103 +205,26 @@ async def create_session( and stored in the session so that tools (e.g. hf_jobs) can act on behalf of the user. - Optional body ``{"model"?: }`` selects the session's LLM; unknown - ids are rejected (400). The premium-model quota runs at message-submit - time, not here β€” spinning up a session to look around is free. - Returns 503 if the server or user has reached the session limit. """ # Extract the user's HF token (Bearer header, HttpOnly cookie, or env var) - hf_token = resolve_hf_request_token(request) - - # Optional model override. Empty body falls back to the config default. - model: str | None = None - try: - body = await request.json() - except Exception: - body = None - if isinstance(body, dict): - model = body.get("model") - - valid_ids = {m["id"] for m in AVAILABLE_MODELS} - if model and model not in valid_ids: - raise HTTPException(status_code=400, detail=f"Unknown model: {model}") - - # Explicit premium selections are allowed. If the implicit configured - # default is premium, start the session on a free model instead. - model = await _model_override_for_new_session(request, model) - - try: - session_id = await session_manager.create_session( - user_id=user["user_id"], - hf_username=user.get("username"), - hf_token=hf_token, - model=model, - is_pro=user.get("plan") == "pro", - ) - except SessionCapacityError as e: - raise HTTPException(status_code=503, detail=str(e)) - - return SessionResponse( - session_id=session_id, - ready=True, - model=model or session_manager.config.model_name, - ) - - -@router.post("/session/restore-summary", response_model=SessionResponse) -async def restore_session_summary( - request: Request, body: dict, user: dict = Depends(get_current_user) -) -> SessionResponse: - """Create a new session seeded with a summary of the caller's prior - conversation. The client sends its cached messages; we run the standard - summarization prompt on them and drop the result into the new - session's context as a user-role system note. - - Optional ``"model"`` in the body overrides the session's LLM. The - premium-model quota runs at message-submit time, not here. 
- """ - messages = body.get("messages") - if not isinstance(messages, list) or not messages: - raise HTTPException(status_code=400, detail="Missing 'messages' array") - - hf_token = resolve_hf_request_token(request) - - model = body.get("model") - valid_ids = {m["id"] for m in AVAILABLE_MODELS} - if model and model not in valid_ids: - raise HTTPException(status_code=400, detail=f"Unknown model: {model}") - - model = await _model_override_for_new_session(request, model) + hf_token = None + auth_header = request.headers.get("Authorization", "") + if auth_header.startswith("Bearer "): + hf_token = auth_header[7:] + if not hf_token: + hf_token = request.cookies.get("hf_access_token") + if not hf_token: + hf_token = os.environ.get("HF_TOKEN") try: session_id = await session_manager.create_session( - user_id=user["user_id"], - hf_username=user.get("username"), - hf_token=hf_token, - model=model, - is_pro=user.get("plan") == "pro", + user_id=user["user_id"], hf_token=hf_token ) except SessionCapacityError as e: raise HTTPException(status_code=503, detail=str(e)) - try: - summarized = await session_manager.seed_from_summary(session_id, messages) - except ValueError as e: - raise HTTPException(status_code=500, detail=str(e)) - except Exception as e: - logger.exception("seed_from_summary failed") - raise HTTPException(status_code=500, detail=f"Summary failed: {e}") - - logger.info( - f"Seeded session {session_id} for {user.get('username', 'unknown')} " - f"(summary of {summarized} messages)" - ) - return SessionResponse( - session_id=session_id, - ready=True, - model=model or session_manager.config.model_name, - ) + return SessionResponse(session_id=session_id, ready=True) @router.get("/session/{session_id}", response_model=SessionInfo) @@ -486,142 +232,24 @@ async def get_session( session_id: str, user: dict = Depends(get_current_user) ) -> SessionInfo: """Get session information. Only accessible by the session owner.""" - await _check_session_access(session_id, user) + _check_session_access(session_id, user) info = session_manager.get_session_info(session_id) return SessionInfo(**info) -@router.post("/session/{session_id}/model") -async def set_session_model( - session_id: str, - body: dict, - request: Request, - user: dict = Depends(get_current_user), -) -> dict: - """Switch the active model for a single session (tab-scoped). - - Takes effect on the next LLM call in that session β€” other sessions - (including other browser tabs) are unaffected. Model switches don't - charge quota β€” the premium-model quota only fires at message-submit time. 
- """ - agent_session = await _check_session_access(session_id, user, request) - model_id = body.get("model") - if not model_id: - raise HTTPException(status_code=400, detail="Missing 'model' field") - valid_ids = {m["id"] for m in AVAILABLE_MODELS} - if model_id not in valid_ids: - raise HTTPException(status_code=400, detail=f"Unknown model: {model_id}") - if not agent_session: - raise HTTPException(status_code=404, detail="Session not found") - await session_manager.update_session_model(session_id, model_id) - logger.info( - f"Session {session_id} model β†’ {model_id} " - f"(by {user.get('username', 'unknown')})" - ) - return {"session_id": session_id, "model": model_id} - - -@router.post("/session/{session_id}/notifications") -async def set_session_notifications( - session_id: str, - body: SessionNotificationsRequest, - user: dict = Depends(get_current_user), -) -> dict: - """Replace the session's auto-notification destinations.""" - agent_session = await _check_session_access(session_id, user) - try: - destinations = session_manager.set_notification_destinations( - session_id, body.destinations - ) - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) - await session_manager.persist_session_snapshot(agent_session) - return { - "session_id": session_id, - "notification_destinations": destinations, - } - - -@router.patch("/session/{session_id}/yolo") -async def set_session_yolo( - session_id: str, - body: SessionYoloRequest, - user: dict = Depends(get_current_user), -) -> dict: - """Update the session-scoped auto-approval policy.""" - await _check_session_access(session_id, user) - try: - summary = await session_manager.update_session_auto_approval( - session_id, - enabled=body.enabled, - cost_cap_usd=body.cost_cap_usd, - cap_provided="cost_cap_usd" in body.model_fields_set, - ) - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) - return {"session_id": session_id, **summary} - - -@router.get("/user/quota") -async def get_user_quota(user: dict = Depends(get_current_user)) -> dict: - """Return the user's plan tier and today's premium-model quota state.""" - plan = user.get("plan", "free") - used = await user_quotas.get_claude_used_today(user["user_id"]) - cap = user_quotas.daily_cap_for(plan) - remaining = max(0, cap - used) - return { - "plan": plan, - "premium_used_today": used, - "premium_daily_cap": cap, - "premium_remaining": remaining, - } - - -@router.get("/user/jobs-access") -async def get_jobs_access_info( - request: Request, user: dict = Depends(get_current_user) -) -> dict: - """Return the namespaces the current token can run HF Jobs under. - - Credits are enforced by the HF API at job-creation time, not here β€” - the response only describes which wallets the caller is allowed to - pick from. Pro is irrelevant. 
- """ - token = resolve_hf_request_token(request) - - access = await get_jobs_access(token or "") - return { - "eligible_namespaces": access.eligible_namespaces if access else [], - "default_namespace": access.default_namespace if access else None, - "billing_url": "https://huggingface.co/settings/billing", - } - - @router.get("/sessions", response_model=list[SessionInfo]) async def list_sessions(user: dict = Depends(get_current_user)) -> list[SessionInfo]: """List sessions belonging to the authenticated user.""" - sessions = await session_manager.list_sessions(user_id=user["user_id"]) + sessions = session_manager.list_sessions(user_id=user["user_id"]) return [SessionInfo(**s) for s in sessions] -@router.post("/session/{session_id}/sandbox/teardown") -async def teardown_session_sandbox( - session_id: str, user: dict = Depends(get_current_user) -) -> dict: - """Best-effort sandbox teardown that preserves durable chat history.""" - await _check_session_access(session_id, user, preload_sandbox=False) - task = asyncio.create_task(session_manager.teardown_sandbox(session_id)) - _background_teardown_tasks.add(task) - task.add_done_callback(_background_teardown_tasks.discard) - return {"status": "teardown_requested", "session_id": session_id} - - @router.delete("/session/{session_id}") async def delete_session( session_id: str, user: dict = Depends(get_current_user) ) -> dict: """Delete a session. Only accessible by the session owner.""" - await _check_session_access(session_id, user, preload_sandbox=False) + _check_session_access(session_id, user) success = await session_manager.delete_session(session_id) if not success: raise HTTPException(status_code=404, detail="Session not found") @@ -630,41 +258,14 @@ async def delete_session( @router.post("/submit") async def submit_input( - request: Request, user: dict = Depends(get_current_user) + request: SubmitRequest, user: dict = Depends(get_current_user) ) -> dict: """Submit user input to a session. Only accessible by the session owner.""" - # Parse the body manually so session ownership can be checked before the - # text-length constraints fire β€” otherwise a non-owner sending an empty - # or oversized text gets a 422 leaking the constraint instead of the 404 - # they'd get for any other access to a session they don't own. 
- try: - payload = await request.json() - except (json.JSONDecodeError, TypeError) as exc: - raise HTTPException(status_code=422, detail=str(exc)) - if not isinstance(payload, dict): - raise HTTPException(status_code=422, detail="Body must be a JSON object") - raw_session_id = payload.get("session_id") - if not isinstance(raw_session_id, str) or not raw_session_id: - raise RequestValidationError( - [ - { - "type": "missing", - "loc": ("body", "session_id"), - "msg": "Field required", - "input": payload, - } - ] - ) - agent_session = await _check_session_access(raw_session_id, user) - try: - body = SubmitRequest(**payload) - except ValidationError as exc: - raise RequestValidationError(exc.errors()) from exc - await _enforce_premium_model_quota(user, agent_session) - success = await session_manager.submit_user_input(body.session_id, body.text) + _check_session_access(request.session_id, user) + success = await session_manager.submit_user_input(request.session_id, request.text) if not success: raise HTTPException(status_code=404, detail="Session not found or inactive") - return {"status": "submitted", "session_id": body.session_id} + return {"status": "submitted", "session_id": request.session_id} @router.post("/approve") @@ -672,14 +273,13 @@ async def submit_approval( request: ApprovalRequest, user: dict = Depends(get_current_user) ) -> dict: """Submit tool approvals to a session. Only accessible by the session owner.""" - await _check_session_access(request.session_id, user) + _check_session_access(request.session_id, user) approvals = [ { "tool_call_id": a.tool_call_id, "approved": a.approved, "feedback": a.feedback, "edited_script": a.edited_script, - "namespace": a.namespace, } for a in request.approvals ] @@ -696,7 +296,9 @@ async def chat_sse( user: dict = Depends(get_current_user), ) -> StreamingResponse: """SSE endpoint: submit input or approval, then stream events until turn ends.""" - agent_session = await _check_session_access(session_id, user, request) + _check_session_access(session_id, user) + + agent_session = session_manager.sessions.get(session_id) if not agent_session or not agent_session.is_active: raise HTTPException(status_code=404, detail="Session not found or inactive") @@ -712,16 +314,6 @@ async def chat_sse( text = body.get("text") approvals = body.get("approvals") - # Gate user-message sends against the daily premium-model quota. Approvals are - # continuations of an in-progress turn β€” the session was already charged - # on its first message, so we skip the gate there. 
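# --- editor's note ------------------------------------------------------
# The gate removed below relies on user_quotas.try_increment_claude
# returning None once the daily cap is hit, and on AgentSession.claude_counted
# making the charge at most once per session. A minimal in-memory sketch of
# those counter semantics (the real store is external; this is illustrative
# only):
counts: dict[str, int] = {}


def try_increment(user_id: str, cap: int) -> int | None:
    """Increment the user's daily counter, or return None at the cap."""
    used = counts.get(user_id, 0)
    if used >= cap:
        return None
    counts[user_id] = used + 1
    return counts[user_id]


assert try_increment("alice", cap=2) == 1
assert try_increment("alice", cap=2) == 2
assert try_increment("alice", cap=2) is None  # over the cap: maps to a 429
# ------------------------------------------------------------------------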
- if text is not None and not approvals: - try: - await _enforce_premium_model_quota(user, agent_session) - except HTTPException: - broadcaster.unsubscribe(sub_id) - raise - try: if approvals: formatted = [ @@ -730,7 +322,6 @@ async def chat_sse( "approved": a["approved"], "feedback": a.get("feedback"), "edited_script": a.get("edited_script"), - "namespace": a.get("namespace"), } for a in approvals ] @@ -739,15 +330,12 @@ async def chat_sse( success = await session_manager.submit_user_input(session_id, text) else: broadcaster.unsubscribe(sub_id) - raise HTTPException( - status_code=400, detail="Must provide 'text' or 'approvals'" - ) + raise HTTPException(status_code=400, detail="Must provide 'text' or 'approvals'") if not success: broadcaster.unsubscribe(sub_id) raise HTTPException(status_code=404, detail="Session not found or inactive") except HTTPException: - broadcaster.unsubscribe(sub_id) raise except Exception: broadcaster.unsubscribe(sub_id) @@ -756,91 +344,19 @@ async def chat_sse( return _sse_response(broadcaster, event_queue, sub_id) -@router.post("/pro-click/{session_id}") -async def record_pro_click( - session_id: str, - body: dict, - user: dict = Depends(get_current_user), -) -> dict: - """Record a click on a Pro upgrade CTA shown from inside a session.""" - agent_session = await _check_session_access(session_id, user) - - from agent.core import telemetry - - await telemetry.record_pro_cta_click( - agent_session.session, - source=str(body.get("source") or "unknown"), - target=str(body.get("target") or "pro_pricing"), - ) - if agent_session.session.config.save_sessions: - agent_session.session.save_and_upload_detached( - agent_session.session.config.session_dataset_repo - ) - return {"status": "ok"} - - # --------------------------------------------------------------------------- # Shared SSE helpers # --------------------------------------------------------------------------- -_TERMINAL_EVENTS = { - "turn_complete", - "approval_required", - "error", - "interrupted", - "shutdown", -} +_TERMINAL_EVENTS = {"turn_complete", "approval_required", "error", "interrupted", "shutdown"} _SSE_KEEPALIVE_SECONDS = 15 -def _last_event_seq(request: Request) -> int: - raw = ( - request.headers.get("last-event-id") or request.query_params.get("after") or "0" - ) - try: - return max(0, int(raw)) - except (TypeError, ValueError): - return 0 - - -def _format_sse(msg: dict[str, Any]) -> str: - seq = msg.get("seq") - body = {"event_type": msg.get("event_type"), "data": msg.get("data") or {}} - if seq is not None: - body["seq"] = seq - return f"id: {seq}\ndata: {json.dumps(body)}\n\n" - return f"data: {json.dumps(body)}\n\n" - - -def _event_doc_to_msg(doc: dict[str, Any]) -> dict[str, Any]: - return { - "event_type": doc.get("event_type"), - "data": doc.get("data") or {}, - "seq": doc.get("seq"), - } - - -def _sse_response( - broadcaster, - event_queue, - sub_id, - *, - replay_events: list[dict[str, Any]] | None = None, - after_seq: int = 0, -) -> StreamingResponse: +def _sse_response(broadcaster, event_queue, sub_id) -> StreamingResponse: """Build a StreamingResponse that drains *event_queue* as SSE, sending keepalive comments every 15 s to prevent proxy timeouts.""" async def event_generator(): try: - for doc in replay_events or []: - msg = _event_doc_to_msg(doc) - seq = msg.get("seq") - if isinstance(seq, int) and seq <= after_seq: - continue - yield _format_sse(msg) - if msg.get("event_type", "") in _TERMINAL_EVENTS: - return - while True: try: msg = await asyncio.wait_for( @@ -851,7 +367,7 @@ 
def _sse_response( yield ": keepalive\n\n" continue event_type = msg.get("event_type", "") - yield _format_sse(msg) + yield f"data: {json.dumps(msg)}\n\n" if event_type in _TERMINAL_EVENTS: break finally: @@ -871,7 +387,6 @@ def _sse_response( @router.get("/events/{session_id}") async def subscribe_events( session_id: str, - request: Request, user: dict = Depends(get_current_user), ) -> StreamingResponse: """Subscribe to events for a running session without submitting new input. @@ -879,23 +394,15 @@ async def subscribe_events( Used by the frontend to re-attach after a connection drop (e.g. screen sleep). Returns 404 if the session isn't active or isn't processing. """ - agent_session = await _check_session_access(session_id, user, request) + _check_session_access(session_id, user) + + agent_session = session_manager.sessions.get(session_id) if not agent_session or not agent_session.is_active: raise HTTPException(status_code=404, detail="Session not found or inactive") - after_seq = _last_event_seq(request) - replay_events = await session_manager._store().load_events_after( - session_id, after_seq - ) broadcaster = agent_session.broadcaster sub_id, event_queue = broadcaster.subscribe() - return _sse_response( - broadcaster, - event_queue, - sub_id, - replay_events=replay_events, - after_seq=after_seq, - ) + return _sse_response(broadcaster, event_queue, sub_id) @router.post("/interrupt/{session_id}") @@ -903,7 +410,7 @@ async def interrupt_session( session_id: str, user: dict = Depends(get_current_user) ) -> dict: """Interrupt the current operation in a session.""" - await _check_session_access(session_id, user) + _check_session_access(session_id, user) success = await session_manager.interrupt(session_id) if not success: raise HTTPException(status_code=404, detail="Session not found or inactive") @@ -915,19 +422,17 @@ async def get_session_messages( session_id: str, user: dict = Depends(get_current_user) ) -> list[dict]: """Return the session's message history from memory.""" - agent_session = await _check_session_access(session_id, user) + _check_session_access(session_id, user) + agent_session = session_manager.sessions.get(session_id) if not agent_session or not agent_session.is_active: raise HTTPException(status_code=404, detail="Session not found or inactive") - return [ - msg.model_dump(mode="json") - for msg in agent_session.session.context_manager.items - ] + return [msg.model_dump() for msg in agent_session.session.context_manager.items] @router.post("/undo/{session_id}") async def undo_session(session_id: str, user: dict = Depends(get_current_user)) -> dict: """Undo the last turn in a session.""" - await _check_session_access(session_id, user) + _check_session_access(session_id, user) success = await session_manager.undo(session_id) if not success: raise HTTPException(status_code=404, detail="Session not found or inactive") @@ -936,30 +441,13 @@ async def undo_session(session_id: str, user: dict = Depends(get_current_user)) @router.post("/truncate/{session_id}") async def truncate_session( - session_id: str, - request: Request, - user: dict = Depends(get_current_user), + session_id: str, body: TruncateRequest, user: dict = Depends(get_current_user) ) -> dict: """Truncate conversation to before a specific user message.""" - # Check session ownership before parsing the request body so a 404 on a - # non-existent / non-owned session_id beats the 422 schema-validation error - # (otherwise the response leaks the required field name to non-owners). 
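# --- editor's note ------------------------------------------------------
# Same ordering rationale as the /api/submit handler above: validate the
# body *after* the ownership check, otherwise a 422 for a malformed body
# confirms to a non-owner that the session id exists and leaks required
# field names. A minimal sketch of the decision order (status codes stand
# in for the FastAPI exceptions; the session/user values are illustrative):
owned_sessions = {"sess-1": "alice"}


def handle_truncate(session_id: str, caller: str, body: dict) -> int:
    # 1) ownership first: unknown and foreign sessions look identical (404)
    if owned_sessions.get(session_id) != caller:
        return 404
    # 2) only owners ever see schema errors (422)
    if "user_message_index" not in body:
        return 422
    return 200


assert handle_truncate("sess-1", "mallory", {}) == 404  # no schema oracle
assert handle_truncate("sess-1", "alice", {}) == 422
assert handle_truncate("sess-1", "alice", {"user_message_index": 3}) == 200
# ------------------------------------------------------------------------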
- await _check_session_access(session_id, user) - try: - body = TruncateRequest(**(await request.json())) - except ValidationError as exc: - # Re-raise as RequestValidationError so FastAPI returns its standard - # structured 422 schema (`{"detail": [{"type":..., "loc":..., ...}]}`) - # instead of a string-stringified Pydantic dump. - raise RequestValidationError(exc.errors()) from exc - except (json.JSONDecodeError, TypeError) as exc: - raise HTTPException(status_code=422, detail=str(exc)) + _check_session_access(session_id, user) success = await session_manager.truncate(session_id, body.user_message_index) if not success: - raise HTTPException( - status_code=404, - detail="Session not found, inactive, or message index out of range", - ) + raise HTTPException(status_code=404, detail="Session not found, inactive, or message index out of range") return {"status": "truncated", "session_id": session_id} @@ -968,7 +456,7 @@ async def compact_session( session_id: str, user: dict = Depends(get_current_user) ) -> dict: """Compact the context in a session.""" - await _check_session_access(session_id, user) + _check_session_access(session_id, user) success = await session_manager.compact(session_id) if not success: raise HTTPException(status_code=404, detail="Session not found or inactive") @@ -980,44 +468,10 @@ async def shutdown_session( session_id: str, user: dict = Depends(get_current_user) ) -> dict: """Shutdown a session.""" - await _check_session_access(session_id, user) + _check_session_access(session_id, user) success = await session_manager.shutdown_session(session_id) if not success: raise HTTPException(status_code=404, detail="Session not found or inactive") return {"status": "shutdown_requested", "session_id": session_id} -@router.post("/feedback/{session_id}") -async def submit_feedback( - session_id: str, - body: dict, - user: dict = Depends(get_current_user), -) -> dict: - """Attach a user feedback signal to a session's event log. - - Body: {rating: "up"|"down"|"outcome_success"|"outcome_fail", - turn_index?: int, comment?: str, message_id?: str} - Appended as a `feedback` event and saved with the session trajectory. - """ - agent_session = await _check_session_access(session_id, user) - - rating = body.get("rating") - if rating not in {"up", "down", "outcome_success", "outcome_fail"}: - raise HTTPException(status_code=400, detail="invalid rating") - - from agent.core import telemetry - - await telemetry.record_feedback( - agent_session.session, - rating=rating, - turn_index=body.get("turn_index"), - message_id=body.get("message_id"), - comment=body.get("comment"), - ) - # Fire-and-forget save so feedback reaches the dataset even if the user - # closes the tab right after clicking. - if agent_session.session.config.save_sessions: - agent_session.session.save_and_upload_detached( - agent_session.session.config.session_dataset_repo - ) - return {"status": "ok"} diff --git a/backend/routes/auth.py b/backend/routes/auth.py index d736deff1841dcc89594f2abb8728bc7306741f5..dce10fe2f70f00cb3914e0de81ad900a05c792f7 100644 --- a/backend/routes/auth.py +++ b/backend/routes/auth.py @@ -4,47 +4,28 @@ Handles the OAuth 2.0 authorization code flow with HF as provider. After successful auth, sets an HttpOnly cookie with the access token. 
""" -import logging import os import secrets import time from urllib.parse import urlencode import httpx -from dependencies import ( - AUTH_ENABLED, - OAUTH_SCOPE_COOKIE, - REQUIRED_OAUTH_SCOPES, - configured_oauth_scopes, - get_current_user, - oauth_scope_fingerprint, -) +from dependencies import AUTH_ENABLED, check_org_membership, get_current_user from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.responses import RedirectResponse router = APIRouter(prefix="/auth", tags=["auth"]) -logger = logging.getLogger(__name__) # OAuth configuration from environment OAUTH_CLIENT_ID = os.environ.get("OAUTH_CLIENT_ID", "") OAUTH_CLIENT_SECRET = os.environ.get("OAUTH_CLIENT_SECRET", "") OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co") -OAUTH_SCOPES = configured_oauth_scopes() # In-memory OAuth state store with expiry (5 min TTL) _OAUTH_STATE_TTL = 300 oauth_states: dict[str, dict] = {} -def _missing_required_scopes(token_data: dict) -> set[str]: - raw_scopes = token_data.get("scope") - if not isinstance(raw_scopes, str) or not raw_scopes.strip(): - logger.debug("OAuth token response omitted a usable scope field") - return set() - granted = set(raw_scopes.split()) - return set(REQUIRED_OAUTH_SCOPES) - granted - - def _cleanup_expired_states() -> None: """Remove expired OAuth states to prevent memory growth.""" now = time.time() @@ -82,15 +63,16 @@ async def oauth_login(request: Request) -> RedirectResponse: "expires_at": time.time() + _OAUTH_STATE_TTL, } - # Build authorization URL. We no longer suggest a default `orgIds` β€” - # users no longer need to join the ML Agent Explorers org to use the - # app, and HF Jobs are billed per-namespace via credits. + # Build authorization URL params = { "client_id": OAUTH_CLIENT_ID, "redirect_uri": get_redirect_uri(request), - "scope": " ".join(OAUTH_SCOPES), + "scope": "openid profile read-repos write-repos contribute-repos manage-repos inference-api jobs write-discussions", "response_type": "code", "state": state, + "orgIds": os.environ.get( + "HF_OAUTH_ORG_ID", "698dbf55845d85df163175f1" + ), # ml-agent-explorers } auth_url = f"{OPENID_PROVIDER_URL}/oauth/authorize?{urlencode(params)}" @@ -138,15 +120,6 @@ async def oauth_callback( status_code=500, detail="Token exchange succeeded but no access_token was returned.", ) - missing_scopes = _missing_required_scopes(token_data) - if missing_scopes: - raise HTTPException( - status_code=403, - detail=( - "OAuth token is missing required scopes: " - + ", ".join(sorted(missing_scopes)) - ), - ) # Fetch user info (optional β€” failure is not fatal) async with httpx.AsyncClient() as client: @@ -172,15 +145,6 @@ async def oauth_callback( max_age=3600 * 24 * 7, # 7 days path="/", ) - response.set_cookie( - key=OAUTH_SCOPE_COOKIE, - value=oauth_scope_fingerprint(OAUTH_SCOPES), - httponly=True, - secure=is_production, - samesite="lax", - max_age=3600 * 24 * 7, - path="/", - ) return response @@ -189,7 +153,6 @@ async def logout() -> RedirectResponse: """Log out the user by clearing the auth cookie.""" response = RedirectResponse(url="/") response.delete_cookie(key="hf_access_token", path="/") - response.delete_cookie(key=OAUTH_SCOPE_COOKIE, path="/") return response @@ -205,4 +168,21 @@ async def get_me(user: dict = Depends(get_current_user)) -> dict: Uses the shared auth dependency which handles cookie + Bearer token. 
""" - return {key: value for key, value in user.items() if not key.startswith("_")} + return user + + +ORG_NAME = "ml-agent-explorers" + + +@router.get("/org-membership") +async def org_membership( + request: Request, user: dict = Depends(get_current_user) +) -> dict: + """Check if the authenticated user belongs to the ml-agent-explorers org.""" + if not AUTH_ENABLED: + return {"is_member": True} + token = request.cookies.get("hf_access_token") or "" + if not token: + return {"is_member": False} + is_member = await check_org_membership(token, ORG_NAME) + return {"is_member": is_member} diff --git a/backend/session_manager.py b/backend/session_manager.py index 449ce3a0e06737ec5470d5fafa182c9a92b2eae0..d6a9f5672853969dc8dfb27a512207a536a1d938 100644 --- a/backend/session_manager.py +++ b/backend/session_manager.py @@ -1,9 +1,7 @@ """Session manager for handling multiple concurrent agent sessions.""" import asyncio -import json import logging -import os import uuid from dataclasses import dataclass, field from datetime import datetime @@ -12,15 +10,12 @@ from typing import Any, Optional from agent.config import load_config from agent.core.agent_loop import process_submission -from agent.core.hub_artifacts import start_session_artifact_collection_task from agent.core.session import Event, OpType, Session -from agent.core.session_persistence import get_session_store from agent.core.tools import ToolRouter -from agent.messaging.gateway import NotificationGateway # Get project root (parent of backend directory) PROJECT_ROOT = Path(__file__).parent.parent -DEFAULT_CONFIG_PATH = str(PROJECT_ROOT / "configs" / "frontend_agent_config.json") +DEFAULT_CONFIG_PATH = str(PROJECT_ROOT / "configs" / "main_agent_config.json") # These dataclasses match agent/main.py structure @@ -46,8 +41,9 @@ logger = logging.getLogger(__name__) class EventBroadcaster: """Reads from the agent's event queue and fans out to SSE subscribers. - Events that arrive when no subscribers are listening are discarded by - this in-memory fanout. Durable replay is handled by session_persistence. + Events that arrive when no subscribers are listening are discarded. + With SSE each turn is a separate request, so there is no reconnect + scenario that would need buffered replay. """ def __init__(self, event_queue: asyncio.Queue): @@ -71,11 +67,7 @@ class EventBroadcaster: while True: try: event: Event = await self._source.get() - msg = { - "event_type": event.event_type, - "data": event.data, - "seq": event.seq, - } + msg = {"event_type": event.event_type, "data": event.data} for q in self._subscribers.values(): await q.put(msg) except asyncio.CancelledError: @@ -93,18 +85,12 @@ class AgentSession: tool_router: ToolRouter submission_queue: asyncio.Queue user_id: str = "dev" # Owner of this session - hf_username: str | None = None # HF namespace used for personal trace uploads hf_token: str | None = None # User's HF OAuth token for tool execution task: asyncio.Task | None = None created_at: datetime = field(default_factory=datetime.utcnow) is_active: bool = True is_processing: bool = False # True while a submission is being executed broadcaster: Any = None - title: str | None = None - # True once this session has been counted against the user's daily - # Claude quota. Guards double-counting when the user re-selects an - # Anthropic model mid-session. 
- claude_counted: bool = False class SessionCapacityError(Exception): @@ -116,15 +102,10 @@ class SessionCapacityError(Exception): # ── Capacity limits ───────────────────────────────────────────────── -# Sized for HF Spaces 8 vCPU / 32 GB RAM. -# Each session uses ~10-20 MB (context, tools, queues, task); 200 Γ— 20 MB -# = 4 GB worst case, leaving plenty of headroom for the Python runtime -# and per-request overhead. -MAX_SESSIONS: int = 200 +# Estimated for HF Spaces cpu-basic (2 vCPU, 16 GB RAM). +# Each session uses ~10-20 MB (context, tools, queues, task). +MAX_SESSIONS: int = 50 MAX_SESSIONS_PER_USER: int = 10 -DEFAULT_YOLO_COST_CAP_USD: float = 5.0 -SANDBOX_SHUTDOWN_CLEANUP_CONCURRENCY: int = 10 -SANDBOX_SHUTDOWN_CLEANUP_TIMEOUT_S: float = 60.0 class SessionManager: @@ -132,590 +113,18 @@ class SessionManager: def __init__(self, config_path: str | None = None) -> None: self.config = load_config(config_path or DEFAULT_CONFIG_PATH) - self.messaging_gateway = NotificationGateway(self.config.messaging) self.sessions: dict[str, AgentSession] = {} self._lock = asyncio.Lock() - self.persistence_store = None - self.enable_hub_artifact_collections = True - - async def start(self) -> None: - """Start shared background resources.""" - self.persistence_store = get_session_store() - await self.persistence_store.init() - await self.messaging_gateway.start() - - async def close(self) -> None: - """Flush and close shared background resources.""" - await self._cleanup_all_sandboxes_on_close() - await self.messaging_gateway.close() - if self.persistence_store is not None: - await self.persistence_store.close() - - def _store(self): - if self.persistence_store is None: - self.persistence_store = get_session_store() - return self.persistence_store def _count_user_sessions(self, user_id: str) -> int: """Count active sessions owned by a specific user.""" return sum( - 1 for s in self.sessions.values() if s.user_id == user_id and s.is_active + 1 + for s in self.sessions.values() + if s.user_id == user_id and s.is_active ) - def _create_session_sync( - self, - *, - session_id: str, - user_id: str, - hf_username: str | None, - hf_token: str | None, - model: str | None, - event_queue: asyncio.Queue, - notification_destinations: list[str] | None = None, - ) -> tuple[ToolRouter, Session]: - """Build blocking per-session resources in a worker thread.""" - import time as _time - - t0 = _time.monotonic() - tool_router = ToolRouter(self.config.mcpServers, hf_token=hf_token) - # Deep-copy config so each session's model switches independently β€” - # tab A picking GLM doesn't flip tab B off Claude. 
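# --- editor's note ------------------------------------------------------
# The model_copy(deep=True) below is what keeps model switches tab-scoped.
# A minimal pydantic sketch of the isolation property (the Config class
# here is illustrative, not the app's real config type):
from pydantic import BaseModel


class Config(BaseModel):
    model_name: str = "moonshotai/Kimi-K2.6"


shared = Config()
session_a = shared.model_copy(deep=True)
session_b = shared.model_copy(deep=True)
session_a.model_name = "zai-org/GLM-5.1"  # tab A switches models

assert session_b.model_name == "moonshotai/Kimi-K2.6"  # tab B unaffected
assert shared.model_name == "moonshotai/Kimi-K2.6"     # default unchanged
# ------------------------------------------------------------------------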
- session_config = self.config.model_copy(deep=True) - if model: - session_config.model_name = model - session = Session( - event_queue=event_queue, - config=session_config, - tool_router=tool_router, - hf_token=hf_token, - user_id=user_id, - hf_username=hf_username, - notification_gateway=self.messaging_gateway, - notification_destinations=notification_destinations or [], - session_id=session_id, - persistence_store=self._store(), - ) - t1 = _time.monotonic() - logger.info("Session initialized in %.2fs", t1 - t0) - return tool_router, session - - def _serialize_messages(self, session: Session) -> list[dict[str, Any]]: - return [msg.model_dump(mode="json") for msg in session.context_manager.items] - - def _serialize_pending_approval(self, session: Session) -> list[dict[str, Any]]: - pending = session.pending_approval or {} - tool_calls = pending.get("tool_calls") or [] - serialized: list[dict[str, Any]] = [] - for tc in tool_calls: - if hasattr(tc, "model_dump"): - serialized.append(tc.model_dump(mode="json")) - elif isinstance(tc, dict): - serialized.append(tc) - return serialized - - @staticmethod - def _pending_tools_for_api(session: Session) -> list[dict[str, Any]] | None: - pending = session.pending_approval or {} - tool_calls = pending.get("tool_calls") or [] - if not tool_calls: - return None - result: list[dict[str, Any]] = [] - for tc in tool_calls: - try: - args = json.loads(tc.function.arguments) - except (json.JSONDecodeError, AttributeError, TypeError): - args = {} - result.append( - { - "tool": getattr(tc.function, "name", None), - "tool_call_id": getattr(tc, "id", None), - "arguments": args, - } - ) - return result - - def _restore_pending_approval( - self, session: Session, pending_approval: list[dict[str, Any]] | None - ) -> None: - if not pending_approval: - session.pending_approval = None - return - from litellm import ChatCompletionMessageToolCall as ToolCall - - restored = [] - for raw in pending_approval: - try: - if "function" in raw: - restored.append(ToolCall(**raw)) - else: - restored.append( - ToolCall( - id=raw["tool_call_id"], - type="function", - function={ - "name": raw["tool"], - "arguments": json.dumps(raw.get("arguments") or {}), - }, - ) - ) - except Exception as e: - logger.warning("Dropping malformed pending approval: %s", e) - session.pending_approval = {"tool_calls": restored} if restored else None - - @staticmethod - def _pending_docs_for_api( - pending_approval: list[dict[str, Any]] | None, - ) -> list[dict[str, Any]] | None: - if not pending_approval: - return None - result: list[dict[str, Any]] = [] - for raw in pending_approval: - if "function" in raw: - function = raw.get("function") or {} - try: - args = json.loads(function.get("arguments") or "{}") - except (json.JSONDecodeError, TypeError): - args = {} - result.append( - { - "tool": function.get("name"), - "tool_call_id": raw.get("id"), - "arguments": args, - } - ) - elif {"tool", "tool_call_id"}.issubset(raw): - result.append( - { - "tool": raw.get("tool"), - "tool_call_id": raw.get("tool_call_id"), - "arguments": raw.get("arguments") or {}, - } - ) - return result or None - - @staticmethod - def _runtime_state(agent_session: AgentSession) -> str: - if agent_session.session.pending_approval: - return "waiting_approval" - if agent_session.is_processing: - return "processing" - if not agent_session.is_active: - return "ended" - return "idle" - - @staticmethod - def _auto_approval_summary(session: Session) -> dict[str, Any]: - if hasattr(session, "auto_approval_policy_summary"): - return 
session.auto_approval_policy_summary() - cap = getattr(session, "auto_approval_cost_cap_usd", None) - estimated = float( - getattr(session, "auto_approval_estimated_spend_usd", 0.0) or 0.0 - ) - remaining = None if cap is None else round(max(0.0, float(cap) - estimated), 4) - return { - "enabled": bool(getattr(session, "auto_approval_enabled", False)), - "cost_cap_usd": cap, - "estimated_spend_usd": round(estimated, 4), - "remaining_usd": remaining, - } - - async def _start_agent_session( - self, - *, - agent_session: AgentSession, - event_queue: asyncio.Queue, - tool_router: ToolRouter, - ) -> AgentSession: - async with self._lock: - existing = self.sessions.get(agent_session.session_id) - if existing: - return existing - self.sessions[agent_session.session_id] = agent_session - - task = asyncio.create_task( - self._run_session( - agent_session.session_id, - agent_session.submission_queue, - event_queue, - tool_router, - ) - ) - agent_session.task = task - return agent_session - - @staticmethod - def _start_cpu_sandbox_preload(agent_session: AgentSession) -> None: - """Kick off a best-effort cpu-basic sandbox for the session.""" - try: - from agent.tools.sandbox_tool import start_cpu_sandbox_preload - - start_cpu_sandbox_preload(agent_session.session) - except Exception as e: - logger.warning( - "Failed to start CPU sandbox preload for %s: %s", - agent_session.session_id, - e, - ) - - @staticmethod - def _can_access_session(agent_session: AgentSession, user_id: str) -> bool: - return ( - user_id == "dev" - or agent_session.user_id == "dev" - or agent_session.user_id == user_id - ) - - @staticmethod - def _update_hf_identity( - agent_session: AgentSession, - *, - hf_token: str | None, - hf_username: str | None, - ) -> None: - if hf_token: - agent_session.hf_token = hf_token - agent_session.session.hf_token = hf_token - if hf_username: - agent_session.hf_username = hf_username - agent_session.session.hf_username = hf_username - - @staticmethod - def _has_active_sandbox_preload(agent_session: AgentSession) -> bool: - task = getattr(agent_session.session, "sandbox_preload_task", None) - return bool(task and not task.done()) - - @staticmethod - def _preload_failed_for_missing_hf_token(agent_session: AgentSession) -> bool: - error = getattr(agent_session.session, "sandbox_preload_error", None) - return isinstance(error, str) and error.startswith("No HF token available") - - def _restart_cpu_preload_if_token_recovered( - self, - agent_session: AgentSession, - *, - preload_sandbox: bool, - ) -> None: - if not preload_sandbox: - return - session = agent_session.session - if getattr(session, "sandbox", None): - return - if self._has_active_sandbox_preload(agent_session): - return - if not (agent_session.hf_token or getattr(session, "hf_token", None)): - return - - if not self._preload_failed_for_missing_hf_token(agent_session): - return - - session.sandbox_preload_error = None - session.sandbox_preload_task = None - session.sandbox_preload_cancel_event = None - self._start_cpu_sandbox_preload(agent_session) - - def _start_hub_artifact_collection(self, agent_session: AgentSession) -> None: - """Kick off best-effort Hub collection creation for the session.""" - if not getattr(self, "enable_hub_artifact_collections", False): - return - session = agent_session.session - if not getattr(session, "session_id", None): - try: - session.session_id = agent_session.session_id - except Exception: - logger.debug("Could not attach session id for Hub artifact collection") - token = agent_session.hf_token or 
getattr(session, "hf_token", None) - if not token: - return - try: - start_session_artifact_collection_task(session, token=token) - except Exception as e: - logger.debug( - "Failed to schedule Hub artifact collection for %s: %s", - agent_session.session_id, - e, - ) - - async def _clear_persisted_sandbox_metadata(self, session_id: str) -> None: - try: - await self._store().update_session_fields( - session_id, - sandbox_space_id=None, - sandbox_hardware=None, - sandbox_owner=None, - sandbox_created_at=None, - sandbox_status="destroyed", - ) - except Exception as e: - logger.warning("Failed to clear sandbox metadata for %s: %s", session_id, e) - - async def _cleanup_persisted_sandbox( - self, - session_id: str, - metadata: dict[str, Any], - *, - hf_token: str | None, - ) -> None: - """Delete a sandbox recorded by a previous backend process, if any.""" - space_id = metadata.get("sandbox_space_id") - if not isinstance(space_id, str) or not space_id: - return - if metadata.get("sandbox_status") == "destroyed": - return - - tokens: list[tuple[str, str]] = [] - seen: set[str] = set() - for label, token in ( - ("user", hf_token), - ("admin", os.environ.get("HF_ADMIN_TOKEN")), - ): - if token and token not in seen: - tokens.append((label, token)) - seen.add(token) - - if not tokens: - logger.warning( - "Cannot clean persisted sandbox %s for session %s: no HF token available", - space_id, - session_id, - ) - return - - last_err: Exception | None = None - for label, token in tokens: - try: - from huggingface_hub import HfApi - - api = HfApi(token=token) - await asyncio.to_thread( - api.delete_repo, - repo_id=space_id, - repo_type="space", - ) - logger.info( - "Deleted persisted sandbox %s for session %s with %s token", - space_id, - session_id, - label, - ) - await self._clear_persisted_sandbox_metadata(session_id) - return - except Exception as e: - status_code = getattr(getattr(e, "response", None), "status_code", None) - if status_code == 404: - logger.info( - "Persisted sandbox %s for session %s is already gone", - space_id, - session_id, - ) - await self._clear_persisted_sandbox_metadata(session_id) - return - last_err = e - - logger.warning( - "Failed to delete persisted sandbox %s for session %s: %s", - space_id, - session_id, - last_err, - ) - - async def persist_session_snapshot( - self, - agent_session: AgentSession, - *, - runtime_state: str | None = None, - status: str = "active", - ) -> None: - """Persist the current runtime context snapshot.""" - store = self._store() - if not getattr(store, "enabled", False): - return - try: - await store.save_snapshot( - session_id=agent_session.session_id, - user_id=agent_session.user_id, - model=agent_session.session.config.model_name, - title=agent_session.title, - messages=self._serialize_messages(agent_session.session), - runtime_state=runtime_state or self._runtime_state(agent_session), - status=status, - turn_count=agent_session.session.turn_count, - pending_approval=self._serialize_pending_approval( - agent_session.session - ), - claude_counted=agent_session.claude_counted, - created_at=agent_session.created_at, - notification_destinations=list( - agent_session.session.notification_destinations - ), - auto_approval_enabled=bool( - getattr(agent_session.session, "auto_approval_enabled", False) - ), - auto_approval_cost_cap_usd=getattr( - agent_session.session, "auto_approval_cost_cap_usd", None - ), - auto_approval_estimated_spend_usd=float( - getattr( - agent_session.session, - "auto_approval_estimated_spend_usd", - 0.0, - ) - or 0.0 - ), - ) 
- except Exception as e: - logger.warning( - "Failed to persist snapshot for %s: %s", - agent_session.session_id, - e, - ) - - async def ensure_session_loaded( - self, - session_id: str, - user_id: str, - hf_token: str | None = None, - hf_username: str | None = None, - preload_sandbox: bool = True, - ) -> AgentSession | None: - """Return a live runtime session, lazily restoring it from Mongo.""" - async with self._lock: - existing = self.sessions.get(session_id) - if existing: - if self._can_access_session(existing, user_id): - self._update_hf_identity( - existing, - hf_token=hf_token, - hf_username=hf_username, - ) - self._restart_cpu_preload_if_token_recovered( - existing, - preload_sandbox=preload_sandbox, - ) - self._start_hub_artifact_collection(existing) - return existing - return None - - store = self._store() - loaded = await store.load_session(session_id) - if not loaded: - return None - - async with self._lock: - existing = self.sessions.get(session_id) - if existing: - if self._can_access_session(existing, user_id): - self._update_hf_identity( - existing, - hf_token=hf_token, - hf_username=hf_username, - ) - self._restart_cpu_preload_if_token_recovered( - existing, - preload_sandbox=preload_sandbox, - ) - self._start_hub_artifact_collection(existing) - return existing - return None - - meta = loaded.get("metadata") or {} - owner = str(meta.get("user_id") or "") - if user_id != "dev" and owner != "dev" and owner != user_id: - return None - - await self._cleanup_persisted_sandbox( - session_id, - meta, - hf_token=hf_token, - ) - - from litellm import Message - - model = meta.get("model") or self.config.model_name - event_queue: asyncio.Queue = asyncio.Queue() - submission_queue: asyncio.Queue = asyncio.Queue() - tool_router, session = await asyncio.to_thread( - self._create_session_sync, - session_id=session_id, - user_id=owner or user_id, - hf_username=hf_username, - hf_token=hf_token, - model=model, - event_queue=event_queue, - notification_destinations=meta.get("notification_destinations") or [], - ) - - restored_messages: list[Message] = [] - for raw in loaded.get("messages") or []: - if not isinstance(raw, dict) or raw.get("role") == "system": - continue - try: - restored_messages.append(Message.model_validate(raw)) - except Exception as e: - logger.warning("Dropping malformed restored message: %s", e) - if restored_messages: - # Keep the freshly-rendered system prompt, then attach the durable - # non-system context so tools/date/user context stay current. 
- session.context_manager.items = [ - session.context_manager.items[0], - *restored_messages, - ] - - self._restore_pending_approval(session, meta.get("pending_approval") or []) - session.turn_count = int(meta.get("turn_count") or 0) - session.auto_approval_enabled = bool(meta.get("auto_approval_enabled", False)) - raw_cap = meta.get("auto_approval_cost_cap_usd") - session.auto_approval_cost_cap_usd = ( - float(raw_cap) if isinstance(raw_cap, int | float) else None - ) - session.auto_approval_estimated_spend_usd = float( - meta.get("auto_approval_estimated_spend_usd") or 0.0 - ) - - created_at = meta.get("created_at") - if not isinstance(created_at, datetime): - created_at = datetime.utcnow() - - agent_session = AgentSession( - session_id=session_id, - session=session, - tool_router=tool_router, - submission_queue=submission_queue, - user_id=owner or user_id, - hf_username=hf_username, - hf_token=hf_token, - created_at=created_at, - is_active=True, - is_processing=False, - claude_counted=bool(meta.get("claude_counted")), - title=meta.get("title"), - ) - started = await self._start_agent_session( - agent_session=agent_session, - event_queue=event_queue, - tool_router=tool_router, - ) - if started is not agent_session: - self._update_hf_identity( - started, - hf_token=hf_token, - hf_username=hf_username, - ) - self._start_hub_artifact_collection(started) - return started - self._start_hub_artifact_collection(agent_session) - if preload_sandbox: - self._start_cpu_sandbox_preload(agent_session) - logger.info("Restored session %s for user %s", session_id, owner or user_id) - return agent_session - - async def create_session( - self, - user_id: str = "dev", - hf_username: str | None = None, - hf_token: str | None = None, - model: str | None = None, - is_pro: bool | None = None, - ) -> str: + async def create_session(self, user_id: str = "dev", hf_token: str | None = None) -> str: """Create a new agent session and return its ID. Session() and ToolRouter() constructors contain blocking I/O @@ -724,11 +133,6 @@ class SessionManager: Args: user_id: The ID of the user who owns this session. - hf_username: The HF username/namespace used for personal trace uploads. - hf_token: The user's HF OAuth token, stored for tool execution. - model: Optional model override. When set, replaces ``model_name`` - on the per-session config clone. None falls back to the - config default. Raises: SessionCapacityError: If the server or user has reached the @@ -759,15 +163,22 @@ class SessionManager: event_queue: asyncio.Queue = asyncio.Queue() # Run blocking constructors in a thread to keep the event loop responsive. - tool_router, session = await asyncio.to_thread( - self._create_session_sync, - session_id=session_id, - user_id=user_id, - hf_username=hf_username, - hf_token=hf_token, - model=model, - event_queue=event_queue, - ) + # Without this, Session.__init__ β†’ ContextManager β†’ litellm.get_max_tokens() + # blocks all HTTP/SSE handling. 
+ import time as _time + + def _create_session_sync(): + t0 = _time.monotonic() + tool_router = ToolRouter(self.config.mcpServers, hf_token=hf_token) + session = Session( + event_queue, config=self.config, tool_router=tool_router, + hf_token=hf_token, + ) + t1 = _time.monotonic() + logger.info(f"Session initialized in {t1 - t0:.2f}s") + return tool_router, session + + tool_router, session = await asyncio.to_thread(_create_session_sync) # Create wrapper agent_session = AgentSession( @@ -776,165 +187,31 @@ class SessionManager: tool_router=tool_router, submission_queue=submission_queue, user_id=user_id, - hf_username=hf_username, hf_token=hf_token, ) - await self._start_agent_session( - agent_session=agent_session, - event_queue=event_queue, - tool_router=tool_router, - ) - self._start_hub_artifact_collection(agent_session) - await self.persist_session_snapshot(agent_session, runtime_state="idle") - self._start_cpu_sandbox_preload(agent_session) + async with self._lock: + self.sessions[session_id] = agent_session - if is_pro is not None and user_id and user_id != "dev": - await self._track_pro_status(agent_session, is_pro=is_pro) + # Start the agent loop task + task = asyncio.create_task( + self._run_session(session_id, submission_queue, event_queue, tool_router) + ) + agent_session.task = task logger.info(f"Created session {session_id} for user {user_id}") return session_id - async def _track_pro_status( - self, agent_session: AgentSession, *, is_pro: bool - ) -> None: - """Update Mongo per-user Pro state and emit a one-shot conversion - event if the store reports a freeβ†’Pro transition. Best-effort: any - Mongo failure is swallowed so we never fail session creation on - telemetry.""" - store = self._store() - if not getattr(store, "enabled", False): - return - try: - result = await store.mark_pro_seen(agent_session.user_id, is_pro=is_pro) - except Exception as e: - logger.debug("mark_pro_seen failed: %s", e) - return - if not result or not result.get("converted"): - return - try: - from agent.core import telemetry - - await telemetry.record_pro_conversion( - agent_session.session, - first_seen_at=result.get("first_seen_at"), - ) - except Exception as e: - logger.debug("record_pro_conversion failed: %s", e) - - async def seed_from_summary(self, session_id: str, messages: list[dict]) -> int: - """Rehydrate a session from cached prior messages via summarization. - - Runs the standard summarization prompt (same one compaction uses) - over the provided messages, then seeds the new session's context - with that summary. Tool-call pairing concerns disappear because the - output is plain text. Returns the number of messages summarized. - """ - from litellm import Message - - from agent.context_manager.manager import _RESTORE_PROMPT, summarize_messages - - agent_session = self.sessions.get(session_id) - if not agent_session: - raise ValueError(f"Session {session_id} not found") - - # Parse into Message objects, tolerating malformed entries. 
- parsed: list[Message] = [] - for raw in messages: - if raw.get("role") == "system": - continue # the new session has its own system prompt - try: - parsed.append(Message.model_validate(raw)) - except Exception as e: - logger.warning("Dropping malformed message during seed: %s", e) - - if not parsed: - return 0 - - session = agent_session.session - # Pass the real tool specs so the summarizer sees what the agent - # actually has β€” otherwise Anthropic's modify_params injects a - # dummy tool and the summarizer editorializes that the original - # tool calls were fabricated. - tool_specs = None - try: - tool_specs = agent_session.tool_router.get_tool_specs_for_llm() - except Exception: - pass - try: - summary, _ = await summarize_messages( - parsed, - model_name=session.config.model_name, - hf_token=session.hf_token, - max_tokens=4000, - prompt=_RESTORE_PROMPT, - tool_specs=tool_specs, - session=session, - kind="restore", - ) - except Exception as e: - logger.error("Summary call failed during seed: %s", e) - raise - - seed = Message( - role="user", - content=( - "[SYSTEM: Your prior memory of this conversation β€” written " - "in your own voice right before restart. Continue from here.]\n\n" - + (summary or "(no summary returned)") - ), - ) - session.context_manager.items.append(seed) - await self.persist_session_snapshot(agent_session, runtime_state="idle") - return len(parsed) - @staticmethod async def _cleanup_sandbox(session: Session) -> None: - """Delete the sandbox Space if one was created for this session. - - Retries on transient failures (HF API 5xx, rate-limit, network blips) - with exponential backoff. A single missed delete = a permanently - orphaned Space, so the cost of an extra retry beats the alternative. - """ - from agent.tools.sandbox_tool import teardown_session_sandbox - - await teardown_session_sandbox(session) - - async def _cleanup_all_sandboxes_on_close(self) -> None: - """Best-effort sandbox cleanup for graceful backend shutdown.""" - async with self._lock: - agent_sessions = list(self.sessions.values()) - if not agent_sessions: - return - - semaphore = asyncio.Semaphore(SANDBOX_SHUTDOWN_CLEANUP_CONCURRENCY) - - async def _cleanup_one(agent_session: AgentSession) -> None: - async with semaphore: - try: - await self._cleanup_sandbox(agent_session.session) - except Exception as e: - logger.warning( - "Shutdown sandbox cleanup failed for %s: %s", - agent_session.session_id, - e, - ) - - tasks = [ - asyncio.create_task(_cleanup_one(agent_session)) - for agent_session in agent_sessions - ] - try: - await asyncio.wait_for( - asyncio.gather(*tasks, return_exceptions=True), - timeout=SANDBOX_SHUTDOWN_CLEANUP_TIMEOUT_S, - ) - except asyncio.TimeoutError: - logger.warning( - "Timed out after %.0fs cleaning up sandboxes on shutdown; " - "orphan sweeper will handle any stragglers", - SANDBOX_SHUTDOWN_CLEANUP_TIMEOUT_S, - ) + """Delete the sandbox Space if one was created for this session.""" + sandbox = getattr(session, "sandbox", None) + if sandbox and getattr(sandbox, "_owns_space", False): + try: + logger.info(f"Deleting sandbox {sandbox.space_id}...") + await asyncio.to_thread(sandbox.delete) + except Exception as e: + logger.warning(f"Failed to delete sandbox {sandbox.space_id}: {e}") async def _run_session( self, @@ -971,12 +248,9 @@ class SessionManager: ) agent_session.is_processing = True try: - should_continue = await process_submission( - session, submission - ) + should_continue = await process_submission(session, submission) finally: agent_session.is_processing = 
False - await self.persist_session_snapshot(agent_session) if not should_continue: break except asyncio.TimeoutError: @@ -999,25 +273,9 @@ class SessionManager: await self._cleanup_sandbox(session) - # Final-flush: always save on session death so we capture ended - # sessions even if the client disconnects without /shutdown. - # Idempotent via session_id key; detached subprocess. - if session.config.save_sessions: - try: - session.save_and_upload_detached( - session.config.session_dataset_repo - ) - except Exception as e: - logger.warning(f"Final-flush failed for {session_id}: {e}") - async with self._lock: if session_id in self.sessions: self.sessions[session_id].is_active = False - await self.persist_session_snapshot( - self.sessions[session_id], - runtime_state="ended", - status="ended", - ) logger.info(f"Session {session_id} ended") @@ -1067,12 +325,7 @@ class SessionManager: agent_session = self.sessions.get(session_id) if not agent_session or not agent_session.is_active: return False - success = agent_session.session.context_manager.truncate_to_user_message( - user_message_index - ) - if success: - await self.persist_session_snapshot(agent_session, runtime_state="idle") - return success + return agent_session.session.context_manager.truncate_to_user_message(user_message_index) async def compact(self, session_id: str) -> bool: """Compact context in a session.""" @@ -1097,15 +350,12 @@ class SessionManager: return success async def delete_session(self, session_id: str) -> bool: - """Soft-delete a session and stop its runtime resources.""" + """Delete a session entirely.""" async with self._lock: agent_session = self.sessions.pop(session_id, None) if not agent_session: - await self._store().soft_delete_session(session_id) - return True - - await self._store().soft_delete_session(session_id) + return False # Clean up sandbox Space before cancelling the task await self._cleanup_sandbox(agent_session.session) @@ -1120,68 +370,6 @@ class SessionManager: return True - async def teardown_sandbox(self, session_id: str) -> bool: - """Delete only this session's sandbox runtime, preserving chat state.""" - async with self._lock: - agent_session = self.sessions.get(session_id) - - if not agent_session or not agent_session.is_active: - return False - - await self._cleanup_sandbox(agent_session.session) - await self.persist_session_snapshot(agent_session, runtime_state="idle") - return True - - async def update_session_title(self, session_id: str, title: str | None) -> None: - """Persist a user-visible title for sidebar rehydration.""" - agent_session = self.sessions.get(session_id) - if agent_session: - agent_session.title = title - await self._store().update_session_fields(session_id, title=title) - - async def update_session_model(self, session_id: str, model_id: str) -> bool: - agent_session = self.sessions.get(session_id) - if not agent_session or not agent_session.is_active: - return False - agent_session.session.update_model(model_id) - await self.persist_session_snapshot(agent_session, runtime_state="idle") - return True - - async def update_session_auto_approval( - self, - session_id: str, - *, - enabled: bool, - cost_cap_usd: float | None, - cap_provided: bool = False, - ) -> dict[str, Any]: - agent_session = self.sessions.get(session_id) - if not agent_session or not agent_session.is_active: - raise ValueError("Session not found or inactive") - - session = agent_session.session - if enabled: - if not cap_provided and cost_cap_usd is None: - cost_cap_usd = getattr(session, 
"auto_approval_cost_cap_usd", None) - if cost_cap_usd is None: - cost_cap_usd = DEFAULT_YOLO_COST_CAP_USD - elif cost_cap_usd is None: - cost_cap_usd = DEFAULT_YOLO_COST_CAP_USD - else: - if not cap_provided: - cost_cap_usd = getattr(session, "auto_approval_cost_cap_usd", None) - - if hasattr(session, "set_auto_approval_policy"): - session.set_auto_approval_policy( - enabled=enabled, - cost_cap_usd=cost_cap_usd, - ) - else: - session.auto_approval_enabled = bool(enabled) - session.auto_approval_cost_cap_usd = cost_cap_usd - await self.persist_session_snapshot(agent_session) - return self._auto_approval_summary(session) - def get_session_owner(self, session_id: str) -> str | None: """Get the user_id that owns a session, or None if session doesn't exist.""" agent_session = self.sessions.get(session_id) @@ -1209,7 +397,22 @@ class SessionManager: if not agent_session: return None - pending_approval = self._pending_tools_for_api(agent_session.session) + # Extract pending approval tools if any + pending_approval = None + pa = agent_session.session.pending_approval + if pa and pa.get("tool_calls"): + pending_approval = [] + for tc in pa["tool_calls"]: + import json + try: + args = json.loads(tc.function.arguments) + except (json.JSONDecodeError, AttributeError): + args = {} + pending_approval.append({ + "tool": tc.function.name, + "tool_call_id": tc.id, + "arguments": args, + }) return { "session_id": session_id, @@ -1219,107 +422,16 @@ class SessionManager: "message_count": len(agent_session.session.context_manager.items), "user_id": agent_session.user_id, "pending_approval": pending_approval, - "model": agent_session.session.config.model_name, - "title": agent_session.title, - "notification_destinations": list( - agent_session.session.notification_destinations - ), - "auto_approval": self._auto_approval_summary(agent_session.session), } - def set_notification_destinations( - self, session_id: str, destinations: list[str] - ) -> list[str]: - """Replace the session's opted-in auto-notification destinations.""" - agent_session = self.sessions.get(session_id) - if not agent_session or not agent_session.is_active: - raise ValueError("Session not found or inactive") - - normalized: list[str] = [] - seen: set[str] = set() - for raw_name in destinations: - name = raw_name.strip() - if not name: - raise ValueError("Destination names must not be empty") - destination = self.config.messaging.get_destination(name) - if destination is None: - raise ValueError(f"Unknown destination '{name}'") - if not destination.allow_auto_events: - raise ValueError(f"Destination '{name}' is not enabled for auto events") - if name not in seen: - normalized.append(name) - seen.add(name) - - agent_session.session.set_notification_destinations(normalized) - return normalized - - async def list_sessions(self, user_id: str | None = None) -> list[dict[str, Any]]: + def list_sessions(self, user_id: str | None = None) -> list[dict[str, Any]]: """List sessions, optionally filtered by user. Args: user_id: If provided, only return sessions owned by this user. If "dev", return all sessions (dev mode). 
""" - results: list[dict[str, Any]] = [] - store = self._store() - if getattr(store, "enabled", False): - for row in await store.list_sessions(user_id or "dev"): - sid = row.get("session_id") or row.get("_id") - if not sid: - continue - runtime_info = self.get_session_info(str(sid)) - if runtime_info: - results.append(runtime_info) - continue - created_at = row.get("created_at") - if isinstance(created_at, datetime): - created_at_str = created_at.isoformat() - else: - created_at_str = str(created_at or datetime.utcnow().isoformat()) - pending = self._pending_docs_for_api(row.get("pending_approval") or []) - results.append( - { - "session_id": str(sid), - "created_at": created_at_str, - "is_active": row.get("status") != "ended", - "is_processing": row.get("runtime_state") == "processing", - "message_count": int(row.get("message_count") or 0), - "user_id": row.get("user_id") or "dev", - "pending_approval": pending or None, - "model": row.get("model"), - "title": row.get("title"), - "notification_destinations": row.get( - "notification_destinations" - ) - or [], - "auto_approval": { - "enabled": bool(row.get("auto_approval_enabled", False)), - "cost_cap_usd": row.get("auto_approval_cost_cap_usd"), - "estimated_spend_usd": float( - row.get("auto_approval_estimated_spend_usd") or 0.0 - ), - "remaining_usd": ( - None - if row.get("auto_approval_cost_cap_usd") is None - else round( - max( - 0.0, - float( - row.get("auto_approval_cost_cap_usd") or 0.0 - ) - - float( - row.get("auto_approval_estimated_spend_usd") - or 0.0 - ), - ), - 4, - ) - ), - }, - } - ) - return results - + results = [] for sid in self.sessions: info = self.get_session_info(sid) if not info: diff --git a/backend/start.sh b/backend/start.sh index 72b35198f89ef41a73c3119843d2ac21a9cf0a42..1fa7a8bf74bf73f4db750d52a99882b33523212a 100755 --- a/backend/start.sh +++ b/backend/start.sh @@ -1,15 +1,31 @@ #!/bin/bash # Entrypoint for HF Spaces dev mode compatibility. -# Dev mode spawns CMD multiple times simultaneously on restart. -# Only the first instance can bind port 7860 β€” the rest must exit -# with code 0 so the dev mode daemon doesn't mark the app as crashed. - -# Run uvicorn; if it fails due to port conflict, exit cleanly. -uvicorn main:app --host 0.0.0.0 --port 7860 -EXIT_CODE=$? - -if [ $EXIT_CODE -ne 0 ]; then - # Check if this was a port-in-use failure (another instance already running) - echo "uvicorn exited with code $EXIT_CODE, exiting gracefully." - exit 0 -fi +# Dev mode spawns multiple CMD instances simultaneously on restart. +# The old process may still hold port 7860 briefly, so we retry +# with backoff until the port is free. + +MAX_RETRIES=5 +RETRY_DELAY=2 + +for i in $(seq 1 $MAX_RETRIES); do + uvicorn main:app --host 0.0.0.0 --port 7860 + EXIT_CODE=$? + + if [ $EXIT_CODE -eq 0 ]; then + exit 0 + fi + + # Check if another instance from this restart batch is already running + if ss -tlnp 2>/dev/null | grep -q ":7860 "; then + echo "Port 7860 already bound by another instance, exiting." + exit 0 + fi + + if [ $i -lt $MAX_RETRIES ]; then + echo "uvicorn exited ($EXIT_CODE), retrying in ${RETRY_DELAY}s (attempt $i/$MAX_RETRIES)..." + sleep $RETRY_DELAY + fi +done + +echo "Failed to bind port 7860 after $MAX_RETRIES attempts." 
+exit 1 diff --git a/backend/user_quotas.py b/backend/user_quotas.py deleted file mode 100644 index 4da4a8d91b755b64e165f82ba8940b0e1b19ae38..0000000000000000000000000000000000000000 --- a/backend/user_quotas.py +++ /dev/null @@ -1,128 +0,0 @@ -"""Daily quota for premium model session creations. - -Tracks per-user premium model session starts against a daily cap derived from -the user's HF plan. MongoDB is the source of truth when configured; the -in-process dict remains the fallback for local/dev/test runs. - -The public names still say ``claude`` because this quota bucket originally -only covered Claude and the persisted session field uses that name. - -Unit: session *creations*, not messages. A user who sends with a premium model -in a new session consumes one quota point; switching an already-counted session -back to a premium model doesn't (`AgentSession.claude_counted` guards that). - -Cap tiers: - free user β†’ CLAUDE_FREE_DAILY (1) - pro user β†’ CLAUDE_PRO_DAILY (20) -""" - -import asyncio -import os -from datetime import UTC, datetime - -from agent.core.session_persistence import ( - NoopSessionStore, - get_session_store, - _reset_store_for_tests, -) - -CLAUDE_FREE_DAILY: int = int(os.environ.get("CLAUDE_FREE_DAILY", "1")) -CLAUDE_PRO_DAILY: int = int(os.environ.get("CLAUDE_PRO_DAILY", "20")) - -# user_id -> (day_utc_iso, count_for_that_day) -_claude_counts: dict[str, tuple[str, int]] = {} -_lock = asyncio.Lock() - - -def _today() -> str: - return datetime.now(UTC).date().isoformat() - - -def daily_cap_for(plan: str | None) -> int: - """Return the daily Claude-session cap for the given plan.""" - return CLAUDE_PRO_DAILY if plan == "pro" else CLAUDE_FREE_DAILY - - -async def get_claude_used_today(user_id: str) -> int: - """Return today's Claude session count for the user (0 if none / stale day).""" - store = get_session_store() - if getattr(store, "enabled", False): - db_count = await store.get_quota(user_id, _today()) - return db_count or 0 - - async with _lock: - entry = _claude_counts.get(user_id) - if entry is None: - return 0 - day, count = entry - if day != _today(): - # Stale day β€” drop the entry so the first increment starts fresh. - _claude_counts.pop(user_id, None) - return 0 - return count - - -async def increment_claude(user_id: str) -> int: - """Bump today's Claude session count for the user. Returns the new value.""" - store = get_session_store() - if getattr(store, "enabled", False): - db_count = await store.try_increment_quota(user_id, _today(), cap=10**9) - return db_count or 0 - - async with _lock: - today = _today() - day, count = _claude_counts.get(user_id, (today, 0)) - if day != today: - count = 0 - count += 1 - _claude_counts[user_id] = (today, count) - return count - - -async def try_increment_claude(user_id: str, cap: int) -> int | None: - """Atomically bump today's count if below *cap*. - - Returns the new count, or None when the user is already at the cap. 
- """ - store = get_session_store() - if getattr(store, "enabled", False): - return await store.try_increment_quota(user_id, _today(), cap) - - async with _lock: - today = _today() - day, count = _claude_counts.get(user_id, (today, 0)) - if day != today: - count = 0 - if count >= cap: - return None - count += 1 - _claude_counts[user_id] = (today, count) - return count - - -async def refund_claude(user_id: str) -> None: - """Decrement today's count β€” used when session creation fails after a successful gate.""" - store = get_session_store() - if getattr(store, "enabled", False): - await store.refund_quota(user_id, _today()) - return - - async with _lock: - entry = _claude_counts.get(user_id) - if entry is None: - return - day, count = entry - if day != _today(): - _claude_counts.pop(user_id, None) - return - new_count = max(0, count - 1) - if new_count == 0: - _claude_counts.pop(user_id, None) - else: - _claude_counts[user_id] = (day, new_count) - - -def _reset_for_tests() -> None: - """Test-only: clear the in-memory store.""" - _claude_counts.clear() - _reset_store_for_tests(NoopSessionStore()) diff --git a/configs/cli_agent_config.json b/configs/cli_agent_config.json deleted file mode 100644 index ed247998688a102f143b22b1a76538d0aa02520b..0000000000000000000000000000000000000000 --- a/configs/cli_agent_config.json +++ /dev/null @@ -1,21 +0,0 @@ -{ - "model_name": "anthropic/claude-opus-4-6", - "save_sessions": true, - "session_dataset_repo": "smolagents/ml-intern-sessions", - "share_traces": true, - "personal_trace_repo_template": "{hf_user}/ml-intern-sessions", - "yolo_mode": false, - "confirm_cpu_jobs": true, - "auto_file_upload": true, - "messaging": { - "enabled": false, - "auto_event_types": ["approval_required", "error", "turn_complete"], - "destinations": {} - }, - "mcpServers": { - "hf-mcp-server": { - "transport": "http", - "url": "https://huggingface.co/mcp?login" - } - } -} diff --git a/configs/frontend_agent_config.json b/configs/frontend_agent_config.json deleted file mode 100644 index c674a223b018967b7ab4482f3228b0b58d054dd3..0000000000000000000000000000000000000000 --- a/configs/frontend_agent_config.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "model_name": "${ML_INTERN_CLAUDE_MODEL_ID:-bedrock/us.anthropic.claude-opus-4-6-v1}", - "save_sessions": true, - "session_dataset_repo": "smolagents/ml-intern-sessions", - "share_traces": true, - "personal_trace_repo_template": "{hf_user}/ml-intern-sessions", - "yolo_mode": false, - "confirm_cpu_jobs": true, - "auto_file_upload": true, - "mcpServers": { - "hf-mcp-server": { - "transport": "http", - "url": "https://huggingface.co/mcp?login" - } - } -} diff --git a/configs/main_agent_config.json b/configs/main_agent_config.json new file mode 100644 index 0000000000000000000000000000000000000000..5e90e42498c7e3d6d5b314fb0899e33e22db40fe --- /dev/null +++ b/configs/main_agent_config.json @@ -0,0 +1,14 @@ +{ + "model_name": "anthropic/claude-opus-4-6", + "save_sessions": true, + "session_dataset_repo": "akseljoonas/hf-agent-sessions", + "yolo_mode": false, + "confirm_cpu_jobs": true, + "auto_file_upload": true, + "mcpServers": { + "hf-mcp-server": { + "transport": "http", + "url": "https://huggingface.co/mcp?login" + } + } +} diff --git a/create-pr.sh b/create-pr.sh new file mode 100644 index 0000000000000000000000000000000000000000..24bbc1ed58e8b6ef72143ef1f470c773f5775b27 --- /dev/null +++ b/create-pr.sh @@ -0,0 +1,123 @@ +#!/bin/bash +set -e + +# Colors for output +GREEN='\033[0;32m' +BLUE='\033[0;34m' +RED='\033[0;31m' +NC='\033[0m' # 
No Color
+
+# Check arguments
+if [ $# -lt 1 ]; then
+    echo -e "${RED}Usage: ./create-pr.sh \"PR Title\" [\"Optional description\"]${NC}"
+    echo ""
+    echo "Example:"
+    echo "  ./create-pr.sh \"Fix authentication bug\" \"This fixes the dev mode auth issue\""
+    exit 1
+fi
+
+TITLE="$1"
+DESCRIPTION="${2:-}"
+
+# Get current branch
+BRANCH=$(git rev-parse --abbrev-ref HEAD)
+
+if [ "$BRANCH" = "main" ]; then
+    echo -e "${RED}Error: You're on the main branch. Please create a feature branch first.${NC}"
+    exit 1
+fi
+
+echo -e "${BLUE}Creating PR for branch: ${GREEN}$BRANCH${NC}"
+echo -e "${BLUE}Title: ${GREEN}$TITLE${NC}"
+
+# Get HF_TOKEN from .env
+if [ ! -f .env ]; then
+    echo -e "${RED}Error: .env file not found${NC}"
+    exit 1
+fi
+
+HF_TOKEN=$(grep HF_TOKEN .env | cut -d '=' -f2)
+
+if [ -z "$HF_TOKEN" ]; then
+    echo -e "${RED}Error: HF_TOKEN not found in .env${NC}"
+    exit 1
+fi
+
+# Get list of changed files
+echo -e "${BLUE}Detecting changed files...${NC}"
+CHANGED_FILES=$(git diff --name-only main.."$BRANCH")
+
+if [ -z "$CHANGED_FILES" ]; then
+    echo -e "${RED}Error: No changes detected between main and $BRANCH${NC}"
+    exit 1
+fi
+
+echo -e "${BLUE}Changed files:${NC}"
+echo "$CHANGED_FILES" | while read -r file; do
+    echo -e "  ${GREEN}$file${NC}"
+done
+
+# Create PR using HuggingFace API with actual file operations
+echo -e "${BLUE}Creating pull request with file changes...${NC}"
+
+PR_URL=$(HF_TOKEN="$HF_TOKEN" uv run python - <` to swap them in without touching the task.
+- `eval/task.py` registers `hf-benchmark-with-rubrics`, which wires
+  the dataset, solver, and rubric scorer into a single Inspect task and runs the evaluation.
+
+### Running the hf-agent (implemented in `agent/`; all `-T` args are optional)
+```bash
+uv run inspect eval eval/task.py@hf-benchmark-with-rubrics \
+  -T dataset_name=akseljoonas/hf-agent-rubrics \
+  -T dataset_split=train \
+  -T limit=25 \
+  -T solver_name=hf_agent \
+  -T solver_kwargs='{"config_path":"agent/config_mcp_example.json","max_iterations":10}' \
+  --log-dir logs/inspect
+```
+
+Different benchmarks can be run by adding a new task in `eval/task.py` and invoking it the same way.
+
+### Running Claude Code headlessly
+
+The `claude_code` solver shells out to the `claude` CLI (`claude -p ... --output-format json`)
+so you can benchmark Claude Code without any interactive UI. Example (kwargs are optional):
+
+```bash
+uv run inspect eval eval/task.py@hf-benchmark-with-rubrics \
+  -T solver_name=claude_code \
+  -T solver_kwargs='{"allowed_tools":"Bash,Read","output_format":"json"}'
+```
+
+### Leaderboard
+
+Scores can be pushed to a Hugging Face dataset automatically by wrapping the run
+with `eval/run_eval_with_leaderboard.py` (it executes `inspect eval ...` under the hood
+and only appends results when the command succeeds):
+
+```bash
+uv run python eval/run_eval_with_leaderboard.py \
+  --hf-dataset akseljoonas/hf-agent-leaderboard \
+  --hf-token $HF_TOKEN \
+  --solver-name hf_agent \
+  --solver-kwargs '{"config_path":"agent/config_mcp_example.json","max_iterations":10}' \
+  --dataset akseljoonas/hf-agent-rubrics@train \
+  --limit 25
+```
+
+## Scoring (implemented in `eval/rubric_eval.py`)
+
+The scoring is based on the RaR-Explicit formula: `score = Σ(weight × satisfied) / Σ(positive_weights)`.
+
+The score is normalized to [0, 1] and clipped if pitfalls make it negative.
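+
+For reference, here is a minimal sketch of that aggregation. It is illustrative
+only, not the code in `eval/rubric_eval.py`; the `weight`/`satisfied` field names
+and the pitfall-as-negative-weight encoding are assumptions:
+
+```python
+def rar_explicit_score(criteria: list[dict]) -> float:
+    # Weighted sum over satisfied criteria; pitfalls carry negative weights.
+    total = sum(c["weight"] * bool(c["satisfied"]) for c in criteria)
+    # Normalize by the positive weights only, per the formula above.
+    positive = sum(c["weight"] for c in criteria if c["weight"] > 0)
+    if positive == 0:
+        return 0.0
+    # Clip: triggered pitfalls can push the raw ratio below zero.
+    return max(0.0, min(1.0, total / positive))
+
+# Two satisfied criteria (weights 2 and 1) plus one triggered pitfall (-1):
+# (2 + 1 - 1) / (2 + 1) = 0.667
+example = [
+    {"weight": 2.0, "satisfied": True},
+    {"weight": 1.0, "satisfied": True},
+    {"weight": -1.0, "satisfied": True},
+]
+print(round(rar_explicit_score(example), 3))  # 0.667
+```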
diff --git a/eval/__init__.py b/eval/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c661b764c811bd99adc8cdabbc29e8275774c97b
--- /dev/null
+++ b/eval/__init__.py
@@ -0,0 +1,3 @@
+from eval.task import hf_benchmark_with_rubrics
+
+__all__ = ["hf_benchmark_with_rubrics"]
diff --git a/eval/check_completeness.py b/eval/check_completeness.py
new file mode 100644
index 0000000000000000000000000000000000000000..94790bce8c28a54da6b927ecff284bc6966dd3b7
--- /dev/null
+++ b/eval/check_completeness.py
@@ -0,0 +1,164 @@
+#!/usr/bin/env python3
+"""
+Minimal script to check if tasks in solved_tasks.jsonl were fully completed and verified.
+Uses an LLM to assess completion status and adds the result to each row.
+"""
+
+import argparse
+import json
+import sys
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+import litellm
+from dotenv import load_dotenv
+from pydantic import BaseModel
+
+load_dotenv()
+
+
+class CompletionCheck(BaseModel):
+    reasoning: str
+    completed: bool
+    verified: bool
+
+
+PROMPT = """You are evaluating whether an AI agent fully completed a task AND verified its completion.
+
+Task: {question}
+
+Agent's final answer: {solution}
+
+Agent's trace (tool calls and responses):
+{trace}
+
+Evaluate:
+1. **completed**: Did the agent actually complete the task? (not just explain what could be done, but actually do it)
+2. **verified**: Did the agent verify/confirm that the task was completed correctly? (e.g., checked output, validated results, confirmed success)
+
+Be strict:
+- If the agent asked for more information or said "please provide...", it's NOT completed.
+- If the agent only explained how to do something but didn't do it, it's NOT completed.
+- If the agent just made a plan of how to complete it but didn't do it, it's NOT completed.
+- If there's an error in the trace and no recovery, it's NOT completed.
+- If the agent didn't somehow check/confirm that the code/command completed successfully or that the result is correct, it's NOT verified.
+
+Return JSON with: completed (bool), verified (bool), reasoning (brief explanation)."""
+
+
+def format_trace(messages: list) -> str:
+    """Format messages trace for the prompt."""
+    if not messages:
+        return "(No trace)"
+
+    parts = []
+    for msg in messages:
+        role = msg.get("role", "unknown")
+        if role == "system":
+            continue
+
+        content = msg.get("content", "")
+        tool_calls = msg.get("tool_calls", [])
+
+        if tool_calls:
+            for tc in tool_calls:
+                if isinstance(tc, dict) and "function" in tc:
+                    name = tc["function"].get("name", "?")
+                    parts.append(f"[TOOL CALL] {name}")
+
+        if content:
+            # Truncate long content
+            if len(content) > 5000:
+                content = content[:4000] + "..."
+ content[-1000:] + parts.append(f"[{role.upper()}] {content}") + + return "\n".join(parts) if parts else "(Empty trace)" + + +def check_row(row: dict, model: str) -> CompletionCheck | None: + """Check if a single task was completed and verified.""" + prompt = PROMPT.format( + question=row["question"], + solution=row.get("solution", "(No solution)"), + trace=format_trace(row.get("messages", [])), + ) + + try: + response = litellm.completion( + model=model, + messages=[{"role": "user", "content": prompt}], + response_format=CompletionCheck, + timeout=60, + ) + return CompletionCheck.model_validate_json(response.choices[0].message.content) + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + return None + + +def main(): + parser = argparse.ArgumentParser(description="Check task completion status") + parser.add_argument("--infile", type=str, default="eval/solved_tasks.jsonl") + parser.add_argument( + "--outfile", type=str, default="eval/solved_tasks_checked.jsonl" + ) + parser.add_argument( + "--model", type=str, default="anthropic/claude-sonnet-4-5-20250929" + ) + parser.add_argument("--max-concurrent", type=int, default=30) + args = parser.parse_args() + + # Load data + print(f"Loading {args.infile}...") + rows = [] + with open(args.infile) as f: + for line in f: + rows.append(json.loads(line)) + print(f"Loaded {len(rows)} rows") + + # Process in parallel + print(f"Checking completion with {args.model}...") + with ThreadPoolExecutor(max_workers=args.max_concurrent) as executor: + futures = { + executor.submit(check_row, row, args.model): i for i, row in enumerate(rows) + } + results = [None] * len(rows) + + for future in as_completed(futures): + idx = futures[future] + results[idx] = future.result() + print( + f"Done: {sum(1 for r in results if r is not None)}/{len(rows)}", + end="\r", + ) + + print() + + # Merge results + output_rows = [] + for row, result in zip(rows, results): + if result: + row["task_completed"] = result.completed + row["task_verified"] = result.verified + row["completion_reasoning"] = result.reasoning + else: + row["task_completed"] = None + row["task_verified"] = None + row["completion_reasoning"] = "Error during check" + output_rows.append(row) + + # Write output + print(f"Writing to {args.outfile}...") + with open(args.outfile, "w") as f: + for row in output_rows: + f.write(json.dumps(row, default=str) + "\n") + + # Summary + completed = sum(1 for r in results if r and r.completed) + verified = sum(1 for r in results if r and r.verified) + print("\nSummary:") + print(f" Completed: {completed}/{len(rows)}") + print(f" Verified: {verified}/{len(rows)}") + + +if __name__ == "__main__": + main() diff --git a/eval/claude_batch_solve.py b/eval/claude_batch_solve.py new file mode 100644 index 0000000000000000000000000000000000000000..154e23fd3b8a7f27e6b7559eaf3c7933e04cae39 --- /dev/null +++ b/eval/claude_batch_solve.py @@ -0,0 +1,230 @@ +import asyncio +import json +import os +import threading +from pathlib import Path +from typing import Any + +from claude_agent_sdk import ( + AssistantMessage, + ClaudeAgentOptions, + ResultMessage, + SystemMessage, + TextBlock, + ToolResultBlock, + ToolUseBlock, + UserMessage, + query, +) +from dotenv import load_dotenv + +load_dotenv() + +# Thread-safe file writing +file_lock = threading.Lock() + + +def convert_message_to_chat_format(message: Any) -> dict | None: + """Convert SDK message to standard chat format with role/content/tool_calls.""" + + if isinstance(message, SystemMessage): + # Extract tools list from init 
data for system message
+        if message.subtype == "init":
+            tools = message.data.get("tools", [])
+            tools_desc = "\n".join(f"- {tool}" for tool in tools)
+            return {
+                "role": "system",
+                "content": f"You are a helpful assistant with access to the following tools:\n{tools_desc}",
+            }
+        return None
+
+    elif isinstance(message, AssistantMessage):
+        text_content = ""
+        tool_calls = []
+
+        for block in message.content:
+            if isinstance(block, TextBlock):
+                text_content += block.text
+            elif isinstance(block, ToolUseBlock):
+                tool_calls.append(
+                    {
+                        "id": block.id,
+                        "function": {
+                            "name": block.name,
+                            "arguments": block.input,
+                        },
+                    }
+                )
+
+        result = {"role": "assistant", "content": text_content}
+        if tool_calls:
+            result["tool_calls"] = tool_calls
+        return result
+
+    elif isinstance(message, UserMessage):
+        # UserMessage can contain tool results or text
+        if isinstance(message.content, str):
+            return {"role": "user", "content": message.content}
+        elif isinstance(message.content, list):
+            # Check for tool results
+            tool_results = []
+            text_content = ""
+            for block in message.content:
+                if isinstance(block, ToolResultBlock):
+                    # Format tool result content
+                    if isinstance(block.content, str):
+                        content = block.content
+                    elif isinstance(block.content, list):
+                        content = json.dumps(block.content)
+                    else:
+                        content = str(block.content) if block.content else ""
+
+                    tool_results.append(
+                        {
+                            "tool_use_id": block.tool_use_id,
+                            "content": content,
+                            "is_error": block.is_error,
+                        }
+                    )
+                elif isinstance(block, TextBlock):
+                    text_content += block.text
+
+            if tool_results:
+                return {
+                    "role": "user",
+                    "content": f"\n{json.dumps(tool_results, indent=2)}\n",
+                }
+            else:
+                return {"role": "user", "content": text_content}
+        return None
+
+    elif isinstance(message, ResultMessage):
+        # ResultMessage is metadata, not a conversation message
+        return None
+
+    return None
+
+
+async def solve_task(
+    question: str,
+    difficulty: str,
+    task_idx: int,
+    total: int,
+    semaphore: asyncio.Semaphore,
+) -> dict:
+    """Solve a single task using Claude Agent SDK."""
+    async with semaphore:
+        print(f"[{task_idx}/{total}] Starting: {question[:60]}...")
+
+        messages = []
+        solution = None
+
+        try:
+            async for message in query(
+                prompt=question,
+                options=ClaudeAgentOptions(
+                    cwd=os.getcwd(),
+                    permission_mode="bypassPermissions",
+                    disallowed_tools=["Write", "Edit", "Bash", "Glob", "Grep"],
+                    mcp_servers={
+                        "huggingface": {
+                            "type": "http",
+                            "url": "https://huggingface.co/mcp",
+                            "headers": {
+                                "Authorization": f"Bearer {os.environ['HF_TOKEN']}"
+                            },
+                        }
+                    },
+                ),
+            ):
+                # Convert to chat format and append if valid
+                chat_msg = convert_message_to_chat_format(message)
+                if chat_msg:
+                    messages.append(chat_msg)
+
+                # Extract text from assistant messages
+                if isinstance(message, AssistantMessage):
+                    for block in message.content:
+                        if isinstance(block, TextBlock):
+                            solution = block.text
+                # Check for result messages
+                elif isinstance(message, ResultMessage):
+                    if message.is_error:
+                        print(f"[{task_idx}/{total}] ✗ Agent error: {message.subtype}")
+                        return {
+                            "question": question,
+                            "difficulty": difficulty,
+                            "solution": None,
+                            "messages": messages,
+                            "error": f"Agent error: {message.subtype}",
+                        }
+                    elif message.result:
+                        solution = message.result
+
+            print(f"[{task_idx}/{total}] ✓ Done: {question[:60]}...")
+            return {
+                "question": question,
+                "difficulty": difficulty,
+                "solution": solution,
+                "messages": messages,
+                "error": None,
+            }
+        except Exception as e:
+            print(f"[{task_idx}/{total}] ✗ Error: {e}")
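+            # Return the partial trace as well so failed tasks can still be inspected.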
return { + "question": question, + "difficulty": difficulty, + "solution": None, + "messages": messages, + "error": str(e), + } + + +def write_result(output_path: Path, result: dict): + """Thread-safe write to output file.""" + with file_lock: + with open(output_path, "a") as f: + f.write(json.dumps(result) + "\n") + + +async def main(): + # Load tasks from filled_tasks.jsonl + tasks_path = Path(__file__).parent / "filled_tasks.jsonl" + tasks = [] + with open(tasks_path) as f: + for line in f: + tasks.append(json.loads(line)) + + # Output file - clear it first + output_path = Path(__file__).parent / "solved_tasks.jsonl" + output_path.write_text("") + + # Semaphore to limit concurrency + max_concurrent = 5 + semaphore = asyncio.Semaphore(max_concurrent) + + total = len(tasks) + print(f"Processing {total} tasks with {max_concurrent} concurrent agents...") + + async def process_and_save(task: dict, idx: int): + result = await solve_task( + task["question"], task["difficulty"], idx, total, semaphore + ) + write_result(output_path, result) + return result + + # Create all tasks + coroutines = [process_and_save(task, i + 1) for i, task in enumerate(tasks)] + + # Run all concurrently (semaphore limits actual parallelism) + results = await asyncio.gather(*coroutines, return_exceptions=True) + + successful = sum( + 1 for r in results if isinstance(r, dict) and r.get("error") is None + ) + print(f"\nCompleted: {successful}/{total} successful") + print(f"Results saved to {output_path}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/eval/create_eval_dataset.py b/eval/create_eval_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..7a56c3d83914034aa1c0b070d6cece13327ad7fc --- /dev/null +++ b/eval/create_eval_dataset.py @@ -0,0 +1,160 @@ +from itertools import product + +from datasets import Dataset + +# Task templates (excluding Very hard difficulty) +tasks = [ + { + "task": "Evaluate models {M} on benchmarks {B}", + "difficulty": "Easy", + "category": "Evaluation", + "params": ["M", "B"], + }, + { + "task": "Train models {M} on datasets {D} evaluating them on benchmarks {B}", + "difficulty": "Medium", + "category": "Training", + "params": ["M", "D", "B"], + }, + { + "task": "Run an ablation for hyperparameter {P} for model {M} on dataset {D}", + "difficulty": "Hard", + "category": "Ablation", + "params": ["P", "M", "D"], + }, + { + "task": "Generate completions with model {M} on benchmarks {B} using engine {E}", + "difficulty": "Medium", + "category": "Generation", + "params": ["M", "B", "E"], + }, + # { + # "task": "Merge models {M} using linear averaging to find the best result on benchmarks {B}", + # "difficulty": "Hard", + # "category": "Model Merging", + # "params": ["M", "B"], + # }, + { + "task": "Decontaminate dataset {D} against benchmarks {B}", + "difficulty": "Hard", + "category": "Data Processing", + "params": ["D", "B"], + }, + { + "task": "Format dataset {D} for compatibility with framework {F} on task {T}", + "difficulty": "Easy", + "category": "Data Formatting", + "params": ["D", "F", "T"], + }, +] + +# Parameter values +values = { + "M": [ + "Qwen/Qwen3-4B-Instruct-2507", + "openai/gpt-oss-20b", + "gpt-4o-mini", + "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", + "anthropic's latest model", + ], + "B": [ + "Idavidrein/gpqa", + "HuggingFaceH4/MATH-500", + "lighteval/SimpleQA", + "TIGER-Lab/MMLU-Pro", + ], + "D": [ + "HuggingFaceH4/multi_turn_if", + "HuggingFaceH4/ultrachat_200k", + "HuggingFaceH4/AceReason-1.1-SFT config: math_no_think", + 
], + "E": [ + "vllm", + "sglang", + ], + "F": [ + "trl", + "axolotl", + "verl", + ], + "P": [ + "learning_rate", + "batch_size", + "num_epochs", + ], + "T": [ + "SFT", + "GRPO", + ], +} + +# Task-specific instance limits +# For each task, specify which parameter(s) to pivot on and how many instances per pivot combination +# pivot can be a single parameter string or a list of parameters +task_limits = [ + {"pivot": "B", "instances_per_pivot": 1}, # Task 0: 1 instance per + {"pivot": ["M", "B"], "instances_per_pivot": 3}, # Task 1: 3 instances per model + {"pivot": ["P", "D"], "instances_per_pivot": 3}, # Task 2: + {"pivot": "E", "instances_per_pivot": 2}, # Task 3: 2 instances per benchmark + # {"pivot": "M", "instances_per_pivot": 2}, # Task 4 + {"pivot": "D", "instances_per_pivot": 2}, # Task 5: 2 instances per dataset + {"pivot": ["D", "F", "T"], "instances_per_pivot": 2}, # Task 6: +] + + +def main(): + eval_data = [] + + for task_idx, task_dict in enumerate(tasks): + template = task_dict["task"] + params = task_dict["params"] + limit_config = task_limits[task_idx] + + pivot_params = limit_config["pivot"] + instances_per_pivot = limit_config["instances_per_pivot"] + + # Normalize pivot to list + if isinstance(pivot_params, str): + pivot_params = [pivot_params] + + # Get all combinations of pivot values + pivot_param_values = [values[p] for p in pivot_params] + pivot_combinations = product(*pivot_param_values) + + # For each pivot combination, generate limited instances + for pivot_combo in pivot_combinations: + # Get combinations of other (non-pivot) parameters + other_params = [p for p in params if p not in pivot_params] + other_param_values = [values[p] for p in other_params] + other_combinations = list(product(*other_param_values)) + + # Limit to specified number of instances per pivot combination + limited_combinations = other_combinations[:instances_per_pivot] + + # Generate instances + for combo in limited_combinations: + # Build kwargs with pivot values and other values + kwargs = dict(zip(pivot_params, pivot_combo)) + kwargs.update(dict(zip(other_params, combo))) + + concrete_task = template.format(**kwargs) + eval_data.append( + { + "task": concrete_task, + "difficulty": task_dict["difficulty"], + "category": task_dict["category"], + } + ) + + print(f"Generated {len(eval_data)} instances from {len(tasks)} templates") + + dataset = Dataset.from_list(eval_data) + print(f"\nDataset: {len(dataset)} rows") + print(f"Sample: {dataset[0]['task']}") + + dataset.push_to_hub("akseljoonas/qyestions", private=False) + print("\nβœ“ Pushed to akseljoonas/qyestions") + + +if __name__ == "__main__": + main() diff --git a/eval/eval_set.ipynb b/eval/eval_set.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..3515f34727e3ff64b71b07fbc9cfd036aacf4995 --- /dev/null +++ b/eval/eval_set.ipynb @@ -0,0 +1,755 @@ +{ + "cells": [ + { + "cell_type": "code", + "id": "febne6uj10o", + "source": "#!/usr/bin/env python3\n\"\"\"Script to create HuggingFace 401 error fix documentation\"\"\"\n\nimport os\nfrom pathlib import Path\n\n# The full content of the documentation\ndocumentation_content = \"\"\"# HuggingFace 401 Unauthorized Error Fix - Dataset Push to HuggingFaceFW/fineweb-edu\n\n## Problem Summary\n\nWhen attempting to push to the HuggingFace dataset repository `HuggingFaceFW/fineweb-edu`, users may encounter a **401 Unauthorized** error. 
This is a large-scale educational dataset (1.3T tokens, 5.4TB) that requires proper authentication, token permissions, and git-lfs configuration for successful uploads.\n\n**Authenticated User:** akseljoonas \n**Repository:** HuggingFaceFW/fineweb-edu (dataset) \n**Repository Stats:** 5.3M downloads | 873 likes | Last updated: July 11, 2025\n\n---\n\n## Root Causes of 401 Errors\n\nBased on recent issues (2025) and HuggingFace documentation, 401 errors typically stem from:\n\n### 1. **Insufficient Token Permissions**\n- Token lacks **write** permission (only has read access)\n- Token is expired or invalid\n- Using organization token instead of personal access token\n\n### 2. **Git Credential Configuration Issues**\n- Token not saved to git credential helper\n- Git attempting to use cached incorrect credentials\n- Missing `--add-to-git-credential` flag during login\n\n### 3. **Git-LFS Authentication Failures**\n- Git-LFS not properly configured\n- LFS files not tracked correctly (threshold issues)\n- Token not being passed to git-lfs operations\n- CAS (Content Addressable Storage) service authentication failures (new in 2025)\n\n### 4. **API Version Compatibility (2025 Issue)**\n- Modern access tokens only work with API v2 endpoints\n- `huggingface_hub` may internally use API v1 endpoints causing 401 errors\n- Reported as recently as October 2025\n\n### 5. **Large File Upload Issues**\n- Authorization errors when uploading many files (~1000+ files, 300GB+)\n- Timeout issues with LFS authentication on large batches\n\n---\n\n## Diagnostic Steps\n\n### Step 1: Verify Authentication Status\n\n```bash\n# Check who you're authenticated as\nhuggingface-cli whoami\n\n# Or using Python\npython3 -c \"from huggingface_hub import whoami; print(whoami())\"\n```\n\n**Expected Output:** Should show username `akseljoonas` and token permissions\n\n### Step 2: Check Token Permissions\n\n```bash\n# Login and verify token has WRITE permission\nhuggingface-cli login --token YOUR_TOKEN\n\n# Look for this line in output:\n# Token is valid (permission: write).\n```\n\n**Important:** If you see `(permission: read)`, your token is insufficient for pushing!\n\n### Step 3: Verify Git Configuration\n\n```bash\n# Check git credential configuration\ngit config --global --list | grep credential\n\n# Check for git-lfs installation\ngit lfs version\n\n# Check git-lfs environment\ngit lfs env\n```\n\n### Step 4: Check Repository Access\n\n```python\nfrom huggingface_hub import HfApi, auth_check\n\ntry:\n # Verify you have access to the repository\n auth_check(\"HuggingFaceFW/fineweb-edu\", repo_type=\"dataset\")\n print(\"βœ“ Access granted to repository\")\nexcept Exception as e:\n print(f\"βœ— Access denied: {e}\")\n```\n\n### Step 5: Inspect Local Repository (if cloned)\n\n```bash\n# Navigate to your local repo\ncd /path/to/fineweb-edu\n\n# Check git remote\ngit remote -v\n\n# Check git-lfs tracking\ngit lfs track\n\n# Check .gitattributes file\ncat .gitattributes\n```\n\n---\n\n## Complete Fix Solutions\n\n### Solution 1: Re-authenticate with Correct Token Scope βœ… RECOMMENDED\n\nThis is the most common fix for 401 errors.\n\n```bash\n# Step 1: Create a new token with WRITE permissions\n# Go to: https://huggingface.co/settings/tokens\n# Click \"New token\"\n# Select role: \"write\" (NOT \"read\")\n# Give it a name like \"dataset-push-token\"\n# Copy the token (starts with hf_...)\n\n# Step 2: Login with the token AND add to git credentials\nhuggingface-cli login --token YOUR_WRITE_TOKEN --add-to-git-credential\n\n# 
Step 3: Verify the login\nhuggingface-cli whoami\n```\n\n**Expected Output:**\n```\nToken is valid (permission: write).\nYour token has been saved in your configured git credential helpers (store).\nYour token has been saved to /home/username/.cache/huggingface/token\nLogin successful\n```\n\n**Python Alternative:**\n```python\nfrom huggingface_hub import login\n\n# Login with write token and save to git credentials\nlogin(token=\"hf_YOUR_WRITE_TOKEN\", add_to_git_credential=True)\n```\n\n---\n\n### Solution 2: Configure Git Credentials Manually\n\nIf `--add-to-git-credential` doesn't work automatically:\n\n```bash\n# Step 1: Configure git credential store\ngit config --global credential.helper store\n\n# Step 2: Create/edit the credentials file\n# Location: ~/.git-credentials (Linux/Mac) or C:\\\\Users\\\\\\\\.git-credentials (Windows)\necho \"https://YOUR_USERNAME:YOUR_HF_TOKEN@huggingface.co\" >> ~/.git-credentials\n\n# Step 3: Verify\ncat ~/.git-credentials | grep huggingface\n```\n\n**Format for credentials file:**\n```\nhttps://akseljoonas:hf_YOUR_TOKEN@huggingface.co\n```\n\n---\n\n### Solution 3: Fix Git-LFS Configuration\n\nFor large datasets like fineweb-edu, git-lfs is essential:\n\n```bash\n# Step 1: Install git-lfs (if not installed)\n# Ubuntu/Debian:\nsudo apt-get install git-lfs\n\n# macOS:\nbrew install git-lfs\n\n# Windows: Download from https://git-lfs.github.com/\n\n# Step 2: Initialize git-lfs globally\ngit lfs install\n\n# Step 3: In your repository, track large files\ncd /path/to/fineweb-edu\n\n# Track common large file types for datasets\ngit lfs track \"*.parquet\"\ngit lfs track \"*.arrow\"\ngit lfs track \"*.bin\"\ngit lfs track \"*.safetensors\"\ngit lfs track \"*.h5\"\ngit lfs track \"*.json.gz\"\n\n# Step 4: Verify tracking\ngit lfs track\n\n# Step 5: Check .gitattributes was updated\ncat .gitattributes\n```\n\n**Default Large File Threshold:**\n- HuggingFace automatically uses LFS for files > 10MB\n- Files under 10MB are stored as regular git objects\n\n---\n\n### Solution 4: Use HuggingFace Hub API Instead of Git (RECOMMENDED for Large Datasets)\n\nFor very large datasets like fineweb-edu, using the Python API is more reliable than git push:\n\n```python\nfrom huggingface_hub import HfApi, login\nfrom pathlib import Path\n\n# Step 1: Authenticate\nlogin(token=\"hf_YOUR_WRITE_TOKEN\", add_to_git_credential=True)\n\n# Step 2: Initialize API client\napi = HfApi()\n\n# Step 3: Upload files to the dataset repository\n# For a single file:\napi.upload_file(\n path_or_fileobj=\"/path/to/local/file.parquet\",\n path_in_repo=\"data/file.parquet\",\n repo_id=\"HuggingFaceFW/fineweb-edu\",\n repo_type=\"dataset\",\n)\n\n# For multiple files in a folder:\napi.upload_folder(\n folder_path=\"/path/to/local/folder\",\n repo_id=\"HuggingFaceFW/fineweb-edu\",\n repo_type=\"dataset\",\n commit_message=\"Add new data files\",\n)\n\n# For very large uploads, use multi_commits=True:\napi.upload_large_folder(\n folder_path=\"/path/to/large/dataset\",\n repo_id=\"HuggingFaceFW/fineweb-edu\",\n repo_type=\"dataset\",\n multi_commits=True,\n commit_message=\"Upload large dataset batch\",\n)\n```\n\n**Benefits over git push:**\n- Better handling of large files (no LFS authentication issues)\n- Automatic retry on failures\n- Progress tracking\n- No credential caching problems\n- Works around 2025 API v1/v2 compatibility issues\n\n---\n\n### Solution 5: Handle CAS Service Errors (2025 Issue)\n\nIf you see errors mentioning \"CAS service\" or \"Content Addressable 
Storage\":\n\n```python\nfrom huggingface_hub import HfApi\nimport time\n\napi = HfApi()\n\n# Use smaller batch sizes with delays\nfiles_to_upload = list(Path(\"/your/dataset\").glob(\"*.parquet\"))\n\nfor file_path in files_to_upload:\n try:\n api.upload_file(\n path_or_fileobj=str(file_path),\n path_in_repo=f\"data/{file_path.name}\",\n repo_id=\"HuggingFaceFW/fineweb-edu\",\n repo_type=\"dataset\",\n )\n print(f\"βœ“ Uploaded {file_path.name}\")\n time.sleep(2) # Small delay to avoid overwhelming CAS service\n except Exception as e:\n print(f\"βœ— Failed to upload {file_path.name}: {e}\")\n```\n\n---\n\n### Solution 6: Check Repository Permissions\n\nVerify you have write access to the repository:\n\n```python\nfrom huggingface_hub import HfApi, whoami\n\napi = HfApi()\n\n# Check your user info\nuser_info = whoami()\nprint(f\"Username: {user_info['name']}\")\nprint(f\"Organizations: {user_info.get('orgs', [])}\")\n\n# Check if you're part of HuggingFaceFW organization\norgs = user_info.get('orgs', [])\nhas_access = any(org.get('name') == 'HuggingFaceFW' for org in orgs)\n\nif has_access:\n print(\"βœ“ You are a member of HuggingFaceFW organization\")\nelse:\n print(\"βœ— You are NOT a member of HuggingFaceFW organization\")\n print(\" You may need to request access or use a PR instead\")\n```\n\n**If you don't have write access:**\n```bash\n# Create a pull request instead of pushing directly\nhuggingface-cli upload HuggingFaceFW/fineweb-edu /path/to/file --create-pr\n```\n\nOr with Python:\n```python\napi.upload_file(\n path_or_fileobj=\"/path/to/file\",\n path_in_repo=\"data/file.parquet\",\n repo_id=\"HuggingFaceFW/fineweb-edu\",\n repo_type=\"dataset\",\n create_pr=True, # Creates a PR instead of direct push\n)\n```\n\n---\n\n## Git-LFS Configuration Details\n\n### File Size Thresholds\n\n| File Size | Storage Method | Configuration |\n|-----------|---------------|---------------|\n| < 10 MB | Regular Git | No special config needed |\n| > 10 MB | Git-LFS | Automatically tracked by HF |\n| > 5 GB | Git-LFS + Special handling | Use API upload methods |\n\n### Common .gitattributes for Datasets\n\n```gitattributes\n# Large data files\n*.parquet filter=lfs diff=lfs merge=lfs -text\n*.arrow filter=lfs diff=lfs merge=lfs -text\n*.bin filter=lfs diff=lfs merge=lfs -text\n*.safetensors filter=lfs diff=lfs merge=lfs -text\n*.h5 filter=lfs diff=lfs merge=lfs -text\n*.hdf5 filter=lfs diff=lfs merge=lfs -text\n\n# Compressed files\n*.tar.gz filter=lfs diff=lfs merge=lfs -text\n*.zip filter=lfs diff=lfs merge=lfs -text\n*.json.gz filter=lfs diff=lfs merge=lfs -text\n\n# Model files\n*.onnx filter=lfs diff=lfs merge=lfs -text\n*.pb filter=lfs diff=lfs merge=lfs -text\n*.pt filter=lfs diff=lfs merge=lfs -text\n*.pth filter=lfs diff=lfs merge=lfs -text\n```\n\n### Verify LFS is Working\n\n```bash\n# Check which files are tracked by LFS\ngit lfs ls-files\n\n# Check LFS status\ngit lfs status\n\n# Verify a specific file is using LFS\ngit lfs ls-files | grep \"your-file.parquet\"\n\n# See LFS configuration\ngit lfs env\n```\n\n---\n\n## Environment Variables\n\nUseful environment variables for debugging:\n\n```bash\n# Set HuggingFace token via environment variable\nexport HF_TOKEN=\"hf_YOUR_TOKEN\"\n\n# Disable implicit token sending (for debugging)\nexport HF_HUB_DISABLE_IMPLICIT_TOKEN=1\n\n# Enable verbose git LFS output\nexport GIT_TRACE=1\nexport GIT_CURL_VERBOSE=1\nexport GIT_LFS_TRACE=1\n\n# Set custom cache directory\nexport HF_HOME=\"/path/to/custom/cache\"\n```\n\n---\n\n## Testing the 
Fix\n\nAfter applying the fixes, test with a small file first:\n\n```python\nfrom huggingface_hub import HfApi\nimport tempfile\nfrom pathlib import Path\n\napi = HfApi()\n\n# Create a small test file\nwith tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:\n f.write(\"Test file for authentication verification\")\n test_file = f.name\n\ntry:\n # Try uploading to a test repository you own\n # DO NOT test on fineweb-edu directly!\n result = api.upload_file(\n path_or_fileobj=test_file,\n path_in_repo=\"test_auth.txt\",\n repo_id=\"YOUR_USERNAME/test-repo\", # Use your own test repo\n repo_type=\"dataset\",\n )\n print(f\"βœ“ Authentication working! File uploaded to: {result}\")\nexcept Exception as e:\n print(f\"βœ— Authentication failed: {e}\")\nfinally:\n Path(test_file).unlink() # Clean up test file\n```\n\n---\n\n## Quick Reference - Commands Checklist\n\n```bash\n# 1. Check current authentication\nhuggingface-cli whoami\n\n# 2. Re-login with write token\nhuggingface-cli login --token YOUR_WRITE_TOKEN --add-to-git-credential\n\n# 3. Verify git credentials\ngit config --global credential.helper store\ncat ~/.git-credentials | grep huggingface\n\n# 4. Check git-lfs\ngit lfs version\ngit lfs install\n\n# 5. In your repo, verify LFS tracking\ncd /path/to/repo\ngit lfs track\ncat .gitattributes\n\n# 6. Test authentication with Python\npython3 -c \"from huggingface_hub import whoami; print(whoami())\"\n```\n\n---\n\n## Common Error Messages and Solutions\n\n| Error Message | Cause | Solution |\n|---------------|-------|----------|\n| `401 Unauthorized` | Invalid or read-only token | Use Solution 1: Re-authenticate with write token |\n| `403 Forbidden` | No access to repository | Check repository permissions (Solution 6) |\n| `Repository not found` | Wrong repo ID or private repo without access | Verify repo exists and you have access |\n| `LFS authentication failed` | Git credentials not configured | Use Solution 2: Configure git credentials |\n| `CAS service error` | 2025 API issue | Use Solution 5: Smaller batches with delays |\n| `This repository requires LFS` | Missing git-lfs | Use Solution 3: Install and configure git-lfs |\n| `batch response: This repository is over its data limit` | Repository quota exceeded | Contact repository owner |\n\n---\n\n## Best Practices for Large Datasets\n\nFor datasets like fineweb-edu (1.3T tokens):\n\n1. **Use the HuggingFace Hub API** instead of git push\n2. **Upload in batches** rather than all at once\n3. **Use `upload_large_folder()`** (it splits huge uploads into multiple commits and can resume)\n4. **Monitor upload progress** and implement retry logic (a minimal sketch follows this list)\n5. **Test with small files first** before uploading large batches\n6. **Use fine-grained tokens** for production environments\n7. 
**Keep tokens secure** - use environment variables or secure vaults\n
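\nAs a minimal sketch of points 2 and 4 above (batching plus retry logic), reusing the same `upload_file` call from Solution 5; the target repo, file glob, and backoff schedule are placeholders to adapt:\n\n```python\nimport time\nfrom pathlib import Path\n\nfrom huggingface_hub import HfApi\n\napi = HfApi()\n\ndef upload_with_retries(file_path: Path, max_retries: int = 3) -> bool:\n    \"\"\"Upload one file, retrying with exponential backoff on transient errors.\"\"\"\n    for attempt in range(1, max_retries + 1):\n        try:\n            api.upload_file(\n                path_or_fileobj=str(file_path),\n                path_in_repo=f\"data/{file_path.name}\",\n                repo_id=\"YOUR_USERNAME/your-dataset\",  # placeholder - point at your own repo\n                repo_type=\"dataset\",\n            )\n            return True\n        except Exception as e:\n            if attempt == max_retries:\n                print(f\"βœ— Giving up on {file_path.name}: {e}\")\n                return False\n            wait = 2 ** attempt  # exponential backoff: 2s, 4s, 8s, ...\n            print(f\"Retry {attempt}/{max_retries} for {file_path.name} in {wait}s\")\n            time.sleep(wait)\n    return False\n\n# Upload a batch, tracking failures so a later re-run only repeats what broke\nfailed = []\nfor path in sorted(Path(\"/your/dataset\").glob(\"*.parquet\")):\n    if not upload_with_retries(path):\n        failed.append(path.name)\nprint(f\"Finished with {len(failed)} failures: {failed}\")\n```\n\n---\n\n## Additional Resources\n\n- [HuggingFace Hub Python Library](https://huggingface.co/docs/huggingface_hub)\n- [Security Tokens Documentation](https://huggingface.co/docs/hub/security-tokens)\n- [Git-LFS Documentation](https://git-lfs.github.com/)\n- [HuggingFace CLI Guide](https://huggingface.co/docs/huggingface_hub/guides/cli)\n\n---\n\n## Document Version\n\n- **Created:** December 18, 2025\n- **Last Updated:** December 18, 2025\n- **Tested Against:** HuggingFace Hub API v1.2.3+\n- **Authenticated User:** akseljoonas\n- **Target Repository:** HuggingFaceFW/fineweb-edu (dataset)\n\n---\n\n## Sources & References\n\n- [I got Authorization error - Hugging Face Forums](https://discuss.huggingface.co/t/i-got-authorization-error/32881)\n- [Can't push to a dataset repository - Hugging Face Forums](https://discuss.huggingface.co/t/cant-push-to-a-dataset-repository/36611)\n- [LFS: Authorization error when uploading large files](https://lightrun.com/answers/huggingface-huggingface_hub-lfs-authorization-error-when-uploading-manylarge-files)\n- [401 Client Error - huggingface_hub Issue #2586](https://github.com/huggingface/huggingface_hub/issues/2586)\n- [Modern Access Tokens API v2 issue - Issue #3479](https://github.com/huggingface/huggingface_hub/issues/3479)\n- [Hugging Face Hub Dataset Upload CAS Error - Issue #7760](https://github.com/huggingface/datasets/issues/7760)\n- [HuggingFace Security Tokens Documentation](https://huggingface.co/docs/hub/security-tokens)\n\"\"\"\n\n# Expand the ~ to the user's home directory\noutput_path = Path.home() / \"huggingface_401_fix_documentation.md\"\n\n# Write the documentation to the file\ntry:\n with open(output_path, 'w', encoding='utf-8') as f:\n f.write(documentation_content)\n print(f\"βœ“ Successfully created documentation at: {output_path}\")\n print(f\"βœ“ File size: {output_path.stat().st_size} bytes\")\nexcept Exception as e:\n print(f\"βœ— Error creating file: {e}\")\n raise", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "x6z3pkwzo8", + "source": "import csv\n\n# Model data collected from Hugging Face API - Apache-2.0 licensed text-classification models under 500MB\nmodel_data = [\n {'model_id': 'kmack/malicious-url-detection', 'downloads': 2000000, 'likes': 1, 'size_mb': 255.2, 'license': 'apache-2.0'},\n {'model_id': 'mixedbread-ai/mxbai-rerank-xsmall-v1', 'downloads': 960600, 'likes': 49, 'size_mb': 491.7, 'license': 'apache-2.0'},\n {'model_id': 'cross-encoder/ms-marco-TinyBERT-L2-v2', 'downloads': 598100, 'likes': 36, 'size_mb': 172.09, 'license': 'apache-2.0'},\n {'model_id': 'cybersectony/phishing-email-detection-distilbert_v2.4.1', 'downloads': 300500, 'likes': 23, 'size_mb': 255.26, 'license': 'apache-2.0'},\n {'model_id': 'jamal-ibrahim/risk_assesment', 'downloads': 98700, 'likes': 0, 'size_mb': 255.42, 'license': 'apache-2.0'},\n {'model_id': 'agufsamudra/indo-sentiment-analysis', 'downloads': 92100, 'likes': 0, 'size_mb': 475.0, 'license': 'apache-2.0'}\n]\n\n# Already sorted by downloads descending\ncsv_path = '/tmp/apache2_text_classification_models.csv'\nwith open(csv_path, 'w', newline='') as f:\n writer = csv.DictWriter(f, fieldnames=['model_id', 'downloads', 'likes', 'size_mb', 'license'])\n writer.writeheader()\n writer.writerows(model_data)\n\nprint(f'βœ“ CSV file created at: {csv_path}')\nprint(f'βœ“ Total models: {len(model_data)}')\nprint(f'βœ“ All models are Apache-2.0 licensed and under 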
500MB')\nprint(f'βœ“ Sorted by downloads (descending)')", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "t9et9n50wgr", + "source": "# This is just to check the notebook structure\nprint(\"test\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "n4awck8w5ok", + "source": "# Write the KV Cache benchmark script\nbenchmark_script = '''#!/usr/bin/env python3\n\"\"\"\nKV Cache Quantization Benchmark Script\nCompares FP16 vs INT8 quantized KV cache performance on CNN/DailyMail summarization task\n\"\"\"\n\nimport json\nimport time\nimport torch\nfrom datasets import load_dataset\nfrom transformers import AutoTokenizer, AutoModelForCausalLM\nfrom rouge_score import rouge_scorer\nimport gc\nfrom typing import Dict, List, Tuple\nimport numpy as np\n\n# Configuration\nMODEL_NAME = \"meta-llama/Llama-3.2-1B\"\nDATASET_NAME = \"cnn_dailymail\"\nDATASET_CONFIG = \"3.0.0\"\nNUM_SAMPLES = 100\nMAX_NEW_TOKENS = 128\nDO_SAMPLE = False\nDEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\nprint(f\"Using device: {DEVICE}\")\nprint(f\"PyTorch version: {torch.__version__}\")\n\n# Install required packages (instructions for user)\nprint(\"\\\\nRequired packages:\")\nprint(\"pip install transformers datasets rouge-score torch hqq accelerate\")\nprint(\"-\" * 80)\n\n\ndef load_model_and_tokenizer():\n \"\"\"Load the model and tokenizer\"\"\"\n print(f\"\\\\nLoading model: {MODEL_NAME}\")\n \n tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n \n # Set padding token if not set\n if tokenizer.pad_token is None:\n tokenizer.pad_token = tokenizer.eos_token\n \n model = AutoModelForCausalLM.from_pretrained(\n MODEL_NAME,\n torch_dtype=torch.float16 if DEVICE == \"cuda\" else torch.float32,\n device_map=\"auto\" if DEVICE == \"cuda\" else None,\n )\n \n if DEVICE == \"cpu\":\n model = model.to(DEVICE)\n \n model.eval()\n \n print(f\"Model loaded successfully on {DEVICE}\")\n return model, tokenizer\n\n\ndef load_data() -> List[Dict]:\n \"\"\"Load CNN/DailyMail dataset\"\"\"\n print(f\"\\\\nLoading {NUM_SAMPLES} samples from {DATASET_NAME} dataset...\")\n \n dataset = load_dataset(DATASET_NAME, DATASET_CONFIG, split=\"test\")\n samples = dataset.select(range(min(NUM_SAMPLES, len(dataset))))\n \n data = []\n for sample in samples:\n data.append({\n \"article\": sample[\"article\"],\n \"highlights\": sample[\"highlights\"],\n })\n \n print(f\"Loaded {len(data)} samples\")\n return data\n\n\ndef prepare_prompt(article: str) -> str:\n \"\"\"Prepare prompt for summarization\"\"\"\n prompt = f\"\"\"Summarize the following article in one or two sentences:\n\nArticle: {article[:1000]}\n\nSummary:\"\"\"\n return prompt\n\n\ndef generate_summaries(\n model, \n tokenizer, \n data: List[Dict], \n cache_implementation: str = \"default\",\n cache_config: Dict = None\n) -> Tuple[List[str], float, float]:\n \"\"\"\n Generate summaries and measure performance\n \n Returns:\n summaries: List of generated summaries\n tokens_per_sec: Throughput in tokens/second\n peak_memory_mb: Peak memory usage in MB\n \"\"\"\n summaries = []\n total_tokens = 0\n start_time = time.time()\n \n if DEVICE == \"cuda\":\n torch.cuda.reset_peak_memory_stats()\n initial_memory = torch.cuda.memory_allocated()\n \n print(f\"\\\\nGenerating summaries with cache_implementation='{cache_implementation}'...\")\n \n for i, sample in enumerate(data):\n prompt = prepare_prompt(sample[\"article\"])\n \n inputs = tokenizer(\n prompt, \n 
return_tensors=\"pt\", \n truncation=True, \n max_length=2048\n ).to(DEVICE)\n \n # Generate with specified cache configuration\n generation_kwargs = {\n \"max_new_tokens\": MAX_NEW_TOKENS,\n \"do_sample\": DO_SAMPLE,\n \"pad_token_id\": tokenizer.pad_token_id,\n }\n \n if cache_implementation != \"default\":\n generation_kwargs[\"cache_implementation\"] = cache_implementation\n if cache_config:\n generation_kwargs[\"cache_config\"] = cache_config\n \n with torch.no_grad():\n outputs = model.generate(**inputs, **generation_kwargs)\n \n # Decode only the generated tokens (exclude prompt)\n generated_tokens = outputs[0][inputs.input_ids.shape[1]:]\n summary = tokenizer.decode(generated_tokens, skip_special_tokens=True)\n summaries.append(summary.strip())\n \n total_tokens += len(generated_tokens)\n \n if (i + 1) % 10 == 0:\n print(f\" Processed {i + 1}/{len(data)} samples\")\n \n end_time = time.time()\n elapsed_time = end_time - start_time\n tokens_per_sec = total_tokens / elapsed_time\n \n if DEVICE == \"cuda\":\n peak_memory = torch.cuda.max_memory_allocated()\n peak_memory_mb = (peak_memory - initial_memory) / (1024 * 1024)\n else:\n peak_memory_mb = 0.0\n \n print(f\" Generated {total_tokens} tokens in {elapsed_time:.2f}s\")\n print(f\" Throughput: {tokens_per_sec:.2f} tokens/sec\")\n if DEVICE == \"cuda\":\n print(f\" Peak memory: {peak_memory_mb:.2f} MB\")\n \n return summaries, tokens_per_sec, peak_memory_mb\n\n\ndef calculate_rouge_scores(predictions: List[str], references: List[str]) -> Dict[str, float]:\n \"\"\"Calculate ROUGE-L scores\"\"\"\n print(\"\\\\nCalculating ROUGE-L scores...\")\n \n scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)\n scores = []\n \n for pred, ref in zip(predictions, references):\n score = scorer.score(ref, pred)\n scores.append(score['rougeL'].fmeasure)\n \n avg_score = np.mean(scores)\n std_score = np.std(scores)\n \n print(f\" ROUGE-L: {avg_score:.4f} Β± {std_score:.4f}\")\n \n return {\n \"mean\": float(avg_score),\n \"std\": float(std_score),\n \"scores\": [float(s) for s in scores]\n }\n\n\ndef benchmark_cache(\n model,\n tokenizer,\n data: List[Dict],\n cache_type: str,\n cache_implementation: str = \"default\",\n cache_config: Dict = None\n) -> Dict:\n \"\"\"Run benchmark for a specific cache configuration\"\"\"\n print(f\"\\\\n{'='*80}\")\n print(f\"Benchmarking {cache_type}\")\n print(f\"{'='*80}\")\n \n # Clear cache\n if DEVICE == \"cuda\":\n torch.cuda.empty_cache()\n gc.collect()\n \n # Generate summaries\n summaries, tokens_per_sec, peak_memory_mb = generate_summaries(\n model, \n tokenizer, \n data,\n cache_implementation=cache_implementation,\n cache_config=cache_config\n )\n \n # Calculate ROUGE scores\n references = [sample[\"highlights\"] for sample in data]\n rouge_scores = calculate_rouge_scores(summaries, references)\n \n results = {\n \"cache_type\": cache_type,\n \"cache_implementation\": cache_implementation,\n \"cache_config\": cache_config,\n \"tokens_per_sec\": float(tokens_per_sec),\n \"peak_memory_mb\": float(peak_memory_mb),\n \"rouge_l_mean\": rouge_scores[\"mean\"],\n \"rouge_l_std\": rouge_scores[\"std\"],\n \"num_samples\": len(data),\n \"total_tokens_generated\": len(summaries) * MAX_NEW_TOKENS,\n }\n \n return results, summaries\n\n\ndef main():\n \"\"\"Main benchmark function\"\"\"\n print(\"=\"*80)\n print(\"KV Cache Quantization Benchmark\")\n print(\"=\"*80)\n print(f\"Model: {MODEL_NAME}\")\n print(f\"Dataset: {DATASET_NAME}\")\n print(f\"Num samples: {NUM_SAMPLES}\")\n print(f\"Max new tokens: 
{MAX_NEW_TOKENS}\")\n \n # Load model and data\n model, tokenizer = load_model_and_tokenizer()\n data = load_data()\n \n # Benchmark FP16 (default) cache\n fp16_results, fp16_summaries = benchmark_cache(\n model, \n tokenizer, \n data,\n cache_type=\"FP16 (Default)\",\n cache_implementation=\"default\",\n cache_config=None\n )\n \n # Benchmark INT8 quantized cache with HQQ\n int8_results, int8_summaries = benchmark_cache(\n model,\n tokenizer,\n data,\n cache_type=\"INT8 (HQQ Quantized)\",\n cache_implementation=\"quantized\",\n cache_config={\n \"backend\": \"HQQ\",\n \"nbits\": 8,\n \"axis_key\": 1,\n \"axis_value\": 1\n }\n )\n \n # Compare results\n print(\"\\\\n\" + \"=\"*80)\n print(\"COMPARISON RESULTS\")\n print(\"=\"*80)\n \n speedup = int8_results[\"tokens_per_sec\"] / fp16_results[\"tokens_per_sec\"]\n rouge_diff = int8_results[\"rouge_l_mean\"] - fp16_results[\"rouge_l_mean\"]\n \n if fp16_results[\"peak_memory_mb\"] > 0:\n memory_savings_pct = (1 - int8_results[\"peak_memory_mb\"] / fp16_results[\"peak_memory_mb\"]) * 100\n else:\n memory_savings_pct = 0.0\n \n print(f\"\\\\nFP16 Cache:\")\n print(f\" Throughput: {fp16_results['tokens_per_sec']:.2f} tokens/sec\")\n print(f\" ROUGE-L: {fp16_results['rouge_l_mean']:.4f} Β± {fp16_results['rouge_l_std']:.4f}\")\n print(f\" Peak Memory: {fp16_results['peak_memory_mb']:.2f} MB\")\n \n print(f\"\\\\nINT8 Quantized Cache (HQQ):\")\n print(f\" Throughput: {int8_results['tokens_per_sec']:.2f} tokens/sec\")\n print(f\" ROUGE-L: {int8_results['rouge_l_mean']:.4f} Β± {int8_results['rouge_l_std']:.4f}\")\n print(f\" Peak Memory: {int8_results['peak_memory_mb']:.2f} MB\")\n \n print(f\"\\\\nComparison:\")\n print(f\" Speedup: {speedup:.2f}x\")\n print(f\" ROUGE-L Difference: {rouge_diff:+.4f}\")\n print(f\" Memory Savings: {memory_savings_pct:.1f}%\")\n \n # Prepare final results\n final_results = {\n \"config\": {\n \"model\": MODEL_NAME,\n \"dataset\": DATASET_NAME,\n \"num_samples\": NUM_SAMPLES,\n \"max_new_tokens\": MAX_NEW_TOKENS,\n \"do_sample\": DO_SAMPLE,\n \"device\": DEVICE,\n },\n \"fp16_cache\": fp16_results,\n \"int8_quantized_cache\": int8_results,\n \"comparison\": {\n \"speedup\": float(speedup),\n \"rouge_l_difference\": float(rouge_diff),\n \"memory_savings_percent\": float(memory_savings_pct),\n }\n }\n \n # Save results to JSON\n output_file = \"/tmp/kv_cache_benchmark_results.json\"\n with open(output_file, \"w\") as f:\n json.dump(final_results, f, indent=2)\n \n print(f\"\\\\n{'='*80}\")\n print(f\"Results saved to: {output_file}\")\n print(f\"{'='*80}\")\n \n return final_results\n\n\nif __name__ == \"__main__\":\n main()\n'''\n\n# Write benchmark script\nwith open('/tmp/kv_cache_benchmark.py', 'w') as f:\n f.write(benchmark_script)\n\nimport os\nos.chmod('/tmp/kv_cache_benchmark.py', 0o755)\n\n# Write README\nreadme_content = '''# KV Cache Quantization Benchmark\n\nThis benchmark compares FP16 (default) vs INT8 quantized KV cache performance using Llama-3.2-1B on the CNN/DailyMail summarization task.\n\n## Overview\n\nThe script evaluates:\n- **Throughput**: Tokens generated per second\n- **Memory Usage**: Peak memory consumption during generation\n- **Quality**: ROUGE-L scores comparing generated summaries to reference summaries\n\n## Requirements\n\nInstall the required packages:\n\n```bash\npip install transformers datasets rouge-score torch hqq accelerate\n```\n\n### GPU Requirements\n- CUDA-compatible GPU recommended (script will fall back to CPU if no GPU is available)\n- At least 8GB VRAM for Llama-3.2-1B 
with FP16\n- At least 4GB VRAM for INT8 quantized cache\n\n## Usage\n\n### Basic Usage\n\nRun the benchmark with default settings (100 samples):\n\n```bash\npython /tmp/kv_cache_benchmark.py\n```\n\n### Configuration\n\nYou can modify the configuration variables at the top of the script:\n\n```python\nMODEL_NAME = \"meta-llama/Llama-3.2-1B\" # Model to benchmark\nDATASET_NAME = \"cnn_dailymail\" # Dataset name\nDATASET_CONFIG = \"3.0.0\" # Dataset version\nNUM_SAMPLES = 100 # Number of test samples\nMAX_NEW_TOKENS = 128 # Max tokens to generate per sample\nDO_SAMPLE = False # Use greedy decoding\n```\n\n### Output\n\nThe script will:\n1. Load the model and dataset\n2. Run FP16 (default) cache benchmark\n3. Run INT8 quantized cache benchmark with HQQ\n4. Calculate ROUGE-L scores for both configurations\n5. Display comparison results\n6. Save detailed results to `/tmp/kv_cache_benchmark_results.json`\n\n## Results Format\n\nThe output JSON file contains:\n- Configuration details\n- FP16 cache results (throughput, memory, ROUGE-L)\n- INT8 quantized cache results\n- Comparison metrics (speedup, quality difference, memory savings)\n\nExample output:\n```json\n{\n \"config\": {\n \"model\": \"meta-llama/Llama-3.2-1B\",\n \"dataset\": \"cnn_dailymail\",\n \"num_samples\": 100,\n \"max_new_tokens\": 128,\n \"device\": \"cuda\"\n },\n \"fp16_cache\": {\n \"tokens_per_sec\": 150.5,\n \"peak_memory_mb\": 2048.3,\n \"rouge_l_mean\": 0.3245\n },\n \"int8_quantized_cache\": {\n \"tokens_per_sec\": 180.2,\n \"peak_memory_mb\": 1024.1,\n \"rouge_l_mean\": 0.3198\n },\n \"comparison\": {\n \"speedup\": 1.20,\n \"rouge_l_difference\": -0.0047,\n \"memory_savings_percent\": 50.0\n }\n}\n```\n\n## Understanding the Results\n\n### Speedup\n- Values > 1.0 indicate INT8 quantization is faster\n- Typical range: 1.1x - 1.5x speedup\n\n### Memory Savings\n- Percentage reduction in peak memory usage\n- Typical range: 40% - 50% reduction\n\n### ROUGE-L Difference\n- Negative values indicate slight quality degradation\n- Small differences (< 0.01) are generally acceptable\n- ROUGE-L measures overlap between generated and reference summaries\n
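\nFor intuition, here is a minimal example of what the metric computes, using the same `rouge-score` package the script imports (the two sentences are invented):\n\n```python\nfrom rouge_score import rouge_scorer\n\n# ROUGE-L F-measure between a reference and a generated summary\nscorer = rouge_scorer.RougeScorer([\"rougeL\"], use_stemmer=True)\nscore = scorer.score(\n    \"The cat sat on the mat.\",        # target (reference highlights)\n    \"A cat was sitting on the mat.\",  # prediction (model summary)\n)\nprint(score[\"rougeL\"].fmeasure)  # LCS-based F1 in [0, 1]\n```\n\n## Troubleshooting\n\n### CUDA Out of Memory\nIf you encounter OOM errors:\n1. Reduce `NUM_SAMPLES`\n2. Reduce `MAX_NEW_TOKENS`\n3. Ensure no other processes are using GPU memory\n\n### ImportError for HQQ\nMake sure you have installed the HQQ package:\n```bash\npip install hqq\n```\n\n### Slow Performance on CPU\nThe benchmark is designed for GPU. 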
CPU performance will be significantly slower but still functional.\n\n## Advanced Usage\n\n### Custom Cache Configurations\n\nYou can modify the cache configuration in the `benchmark_cache` function:\n\n```python\n# Example: Different quantization settings\nint4_results, int4_summaries = benchmark_cache(\n model,\n tokenizer,\n data,\n cache_type=\"INT4 (HQQ Quantized)\",\n cache_implementation=\"quantized\",\n cache_config={\n \"backend\": \"HQQ\",\n \"nbits\": 4, # 4-bit quantization\n \"axis_key\": 1,\n \"axis_value\": 1\n }\n)\n```\n\n### Different Models\n\nTo benchmark different models, change the `MODEL_NAME`:\n\n```python\nMODEL_NAME = \"meta-llama/Llama-3.2-3B\" # Larger model\n# or\nMODEL_NAME = \"microsoft/phi-2\" # Different architecture\n```\n\nMake sure the model supports the quantized cache implementation.\n\n## Technical Details\n\n### HQQ (Half-Quadratic Quantization)\n- Quantizes KV cache tensors to lower precision (INT8 or INT4)\n- Reduces memory footprint and bandwidth requirements\n- Minimal impact on generation quality\n- Supported by Hugging Face Transformers\n\n### KV Cache\n- Stores key and value tensors from attention layers\n- Avoids recomputing previous token representations\n- Memory usage grows with sequence length\n- Quantization reduces per-token memory cost\n
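\nAs a back-of-the-envelope sketch of that per-token cost (the layer and head counts below are assumed values for Llama-3.2-1B; verify them against the model's `config.json`, and note that a quantized cache also stores scales/zero-points that this estimate ignores):\n\n```python\n# Per-token KV cache bytes: 2 (K and V) x layers x kv_heads x head_dim x bytes per value\nlayers, kv_heads, head_dim = 16, 8, 64  # assumed Llama-3.2-1B config values\n\nfp16_bytes = 2 * layers * kv_heads * head_dim * 2  # 2 bytes per FP16 value\nint8_bytes = 2 * layers * kv_heads * head_dim * 1  # 1 byte per INT8 value\n\nseq_len = 2048 + 128  # max prompt + generated tokens in this benchmark\nprint(f\"FP16 cache at {seq_len} tokens: {fp16_bytes * seq_len / 1e6:.1f} MB\")  # ~71 MB\nprint(f\"INT8 cache at {seq_len} tokens: {int8_bytes * seq_len / 1e6:.1f} MB\")  # ~36 MB\n```\n\n## References\n\n- [Hugging Face Transformers Documentation](https://huggingface.co/docs/transformers)\n- [HQQ: Half-Quadratic Quantization](https://github.com/mobiusml/hqq)\n- [CNN/DailyMail Dataset](https://huggingface.co/datasets/cnn_dailymail)\n- [ROUGE Metric](https://huggingface.co/spaces/evaluate-metric/rouge)\n\n## License\n\nThis script is provided as-is for benchmarking purposes. Please refer to the licenses of the individual components (model, dataset, libraries) for their usage terms.\n'''\n\nwith open('/tmp/KV_CACHE_README.md', 'w') as f:\n f.write(readme_content)\n\nprint(\"βœ“ Created /tmp/kv_cache_benchmark.py\")\nprint(\"βœ“ Created /tmp/KV_CACHE_README.md\")\nprint(\"\\\\nFiles are ready to use!\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "u8gky39qvr", + "source": "# Create CSV with apache-2.0 models under 500MB\nimport csv\n\n# All apache-2.0 licensed models from search, excluding BART (>500MB)\nmodels = [\n [\"nateraw/codecarbon-text-classification\", 1000],\n [\"nickmuchi/distilroberta-finetuned-financial-text-classification\", 48900],\n [\"jxuhf/Fine-tuning-text-classification-model-Habana-Gaudi\", 736],\n [\"krupper/text-complexity-classification\", 8200],\n [\"DenilsenAxel/nlp-text-classification\", 999],\n [\"lucasgbezerra/classification_text_model\", 340],\n [\"fatenghali/text_classification_model\", 506],\n [\"maayansharon/climate_text_classification_mini_model\", 343],\n [\"Aaryan562/distilbert-base-uncased-fine-tuned-text-classification\", 283],\n [\"dmjimenezbravo/electra-small-discriminator-text-classification-en-finetuned-amazon_reviews_multi-en\", 312],\n [\"ratish/bert-textClassification_v1.1\", 278],\n [\"ratish/bert-textClassification_v1.4\", 313],\n [\"Amite5h/TextClassificationmulticlass\", 298],\n [\"Sleoruiz/roberta-base-fine-tuned-text-classification-pesos-fixed\", 7],\n [\"Sleoruiz/roberta-base-fine-tuned-text-classification-pesos-fixed-2\", 9],\n [\"Sleoruiz/roberta-bne-fine-tuned-text-classification-SL-data-augmentation-dss\", 314],\n [\"Sleoruiz/roberta-bne-fine-tuned-text-classification-SL-dss\", 454],\n [\"Cynthiaiii4/Text_classification_HW\", 8],\n [\"tKah/Textclassification-Bert\", 245],\n 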
[\"Sleoruiz/roberta-bne-fine-tuned-text-classification-SL-1200samples\", 287],\n [\"Leslie123/stackoverflow-text-classification\", 225],\n [\"Cynthiaiii4/Text_classification_bert-base-uncased\", 6],\n [\"Cynthiaiii4/Text_classification_model_blu\", 7],\n [\"Cynthiaiii4/Text_classification_model_bbc\", 6],\n [\"sfurkan/LexBERT-textclassification-turkish-uncased\", 8],\n]\n\n# Write to CSV\noutput_path = \"/Users/akseljoonas/Documents/hf-agent/text_classification_models.csv\"\nwith open(output_path, \"w\", newline=\"\", encoding=\"utf-8\") as f:\n writer = csv.writer(f)\n writer.writerow([\"model_id\", \"downloads\"])\n writer.writerows(models)\n\nprint(f\"βœ“ CSV file created: {output_path}\")\nprint(f\"βœ“ Total models: {len(models)}\")\nprint(f\"βœ“ Excluded: IT-community/BART_cnn_news_text_classification (>500MB)\")\n\n# Show first few rows\nprint(\"\\nFirst 5 rows:\")\nfor i, (model_id, downloads) in enumerate(models[:5], 1):\n print(f\" {i}. {model_id}: {downloads:,} downloads\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "njt45dhwbfb", + "source": "# Execute the cell above to create the CSV\n# Then verify it was created\nimport os\ncsv_path = \"/Users/akseljoonas/Documents/hf-agent/text_classification_models.csv\"\nif os.path.exists(csv_path):\n print(f\"βœ“ CSV file exists at: {csv_path}\")\n print(f\"βœ“ File size: {os.path.getsize(csv_path)} bytes\")\n \n # Read and display first few lines\n with open(csv_path, \"r\") as f:\n lines = f.readlines()\n print(f\"βœ“ Total lines: {len(lines)}\")\n print(\"\\nFirst 10 lines:\")\n for line in lines[:10]:\n print(f\" {line.rstrip()}\")\nelse:\n print(f\"βœ— CSV file not found at: {csv_path}\")\n print(\"Run the cell above first to create it.\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "704sq89c26n", + "source": "# Direct CSV creation without dependencies\ncsv_content = \"\"\"model_id,downloads\nnateraw/codecarbon-text-classification,1000\nnickmuchi/distilroberta-finetuned-financial-text-classification,48900\njxuhf/Fine-tuning-text-classification-model-Habana-Gaudi,736\nkrupper/text-complexity-classification,8200\nDenilsenAxel/nlp-text-classification,999\nlucasgbezerra/classification_text_model,340\nfatenghali/text_classification_model,506\nmaayansharon/climate_text_classification_mini_model,343\nAaryan562/distilbert-base-uncased-fine-tuned-text-classification,283\ndmjimenezbravo/electra-small-discriminator-text-classification-en-finetuned-amazon_reviews_multi-en,312\nratish/bert-textClassification_v1.1,278\nratish/bert-textClassification_v1.4,313\nAmite5h/TextClassificationmulticlass,298\nSleoruiz/roberta-base-fine-tuned-text-classification-pesos-fixed,7\nSleoruiz/roberta-base-fine-tuned-text-classification-pesos-fixed-2,9\nSleoruiz/roberta-bne-fine-tuned-text-classification-SL-data-augmentation-dss,314\nSleoruiz/roberta-bne-fine-tuned-text-classification-SL-dss,454\nCynthiaiii4/Text_classification_HW,8\ntKah/Textclassification-Bert,245\nSleoruiz/roberta-bne-fine-tuned-text-classification-SL-1200samples,287\nLeslie123/stackoverflow-text-classification,225\nCynthiaiii4/Text_classification_bert-base-uncased,6\nCynthiaiii4/Text_classification_model_blu,7\nCynthiaiii4/Text_classification_model_bbc,6\nsfurkan/LexBERT-textclassification-turkish-uncased,8\"\"\"\n\n# Write directly\nwith open(\"/Users/akseljoonas/Documents/hf-agent/text_classification_models.csv\", \"w\") as f:\n f.write(csv_content)\n\nprint(\"βœ“ CSV created 
successfully!\")\nprint(f\"βœ“ 25 models (apache-2.0 license, <500MB)\")\nprint(\"βœ“ 1 model excluded: IT-community/BART_cnn_news_text_classification (>500MB)\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "155tkweh88r", + "source": "# Create train_dpo.py file\nscript_content = '''\"\"\"DPO Training Script - Complete Implementation\"\"\"\nimport torch\nfrom datasets import load_dataset\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\nfrom trl import DPOTrainer, DPOConfig\n\nprint(\"=\"*80)\nprint(\"DPO Training - End-to-End Validation\")\nprint(\"=\"*80)\n\n# Configuration\nMODEL_NAME = \"Qwen/Qwen2-0.5B-Instruct\"\nDATASET_NAME = \"trl-lib/ultrafeedback_binarized\"\nOUTPUT_DIR = \"./dpo_output\"\nMAX_STEPS = 10\nBATCH_SIZE = 2\n\nprint(f\"\\\\n[CONFIG] Model: {MODEL_NAME}\")\nprint(f\"[CONFIG] Dataset: {DATASET_NAME}\")\nprint(f\"[CONFIG] Max steps: {MAX_STEPS}\")\nprint(f\"[CONFIG] Batch size: {BATCH_SIZE}\")\n\n# Step 1: Load tokenizer\nprint(\"\\\\n[1/6] Loading tokenizer...\")\ntokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\nif tokenizer.pad_token is None:\n tokenizer.pad_token = tokenizer.eos_token\nprint(f\"βœ“ Tokenizer loaded\")\n\n# Step 2: Load dataset\nprint(\"\\\\n[2/6] Loading dataset...\")\ndataset = load_dataset(DATASET_NAME, split=\"train[:100]\")\nprint(f\"βœ“ Dataset loaded: {len(dataset)} samples\")\n\n# Step 3: Load model\nprint(\"\\\\n[3/6] Loading model...\")\nmodel = AutoModelForCausalLM.from_pretrained(\n MODEL_NAME,\n torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,\n device_map=\"auto\",\n)\nprint(f\"βœ“ Model loaded: {model.num_parameters()/1e6:.1f}M parameters\")\n\n# Step 4: Configure training\nprint(\"\\\\n[4/6] Configuring DPO training...\")\ntraining_args = DPOConfig(\n output_dir=OUTPUT_DIR,\n max_steps=MAX_STEPS,\n per_device_train_batch_size=BATCH_SIZE,\n learning_rate=5e-7,\n logging_steps=2,\n save_steps=10,\n beta=0.1,\n fp16=torch.cuda.is_available(),\n remove_unused_columns=False,\n report_to=\"none\",\n)\nprint(\"βœ“ Configuration created\")\n\n# Step 5: Train\nprint(\"\\\\n[5/6] Starting DPO training...\")\nprint(\"-\"*80)\ntrainer = DPOTrainer(\n model=model,\n args=training_args,\n train_dataset=dataset,\n tokenizer=tokenizer,\n)\ntrain_result = trainer.train()\nprint(\"-\"*80)\nprint(f\"βœ“ Training completed! 
Loss: {train_result.training_loss:.4f}\")\n\n# Step 6: Save\nprint(\"\\\\n[6/6] Saving model...\")\ntrainer.save_model(OUTPUT_DIR)\nprint(f\"βœ“ Model saved to {OUTPUT_DIR}\")\n\nprint(\"\\\\n\" + \"=\"*80)\nprint(\"DPO TRAINING COMPLETED SUCCESSFULLY!\")\nprint(\"=\"*80)\nprint(f\"\\\\nOutput: {OUTPUT_DIR}\")\nprint(f\"Steps: {train_result.global_step}\")\nprint(f\"Final loss: {train_result.training_loss:.4f}\")\n'''\n\nimport os\nos.chdir('/Users/akseljoonas/Documents/hf-agent')\nwith open('train_dpo.py', 'w') as f:\n f.write(script_content)\n \nprint(\"βœ“ train_dpo.py created successfully!\")\nprint(f\"Location: {os.path.abspath('train_dpo.py')}\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "wdnesxsmfq", + "source": "# Check if required packages are installed\nimport subprocess\nimport sys\n\npackages = ['torch', 'transformers', 'datasets', 'trl']\n\nprint(\"Checking installed packages...\")\nfor package in packages:\n try:\n __import__(package)\n version = subprocess.run([sys.executable, '-m', 'pip', 'show', package], \n capture_output=True, text=True, check=True)\n version_line = [line for line in version.stdout.split('\\n') if line.startswith('Version:')]\n if version_line:\n print(f\"βœ“ {package}: {version_line[0].split(':')[1].strip()}\")\n else:\n print(f\"βœ“ {package}: installed\")\n except ImportError:\n print(f\"βœ— {package}: NOT INSTALLED\")\n print(f\" Installing {package}...\")\n subprocess.run([sys.executable, '-m', 'pip', 'install', package], check=True)", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "6lxro03b5k", + "source": "# Run the train_dpo.py script\nimport subprocess\nimport os\n\nos.chdir('/Users/akseljoonas/Documents/hf-agent')\n\nprint(\"Starting DPO training script...\")\nprint(\"=\"*80)\n\n# Run the script and capture output in real-time\nprocess = subprocess.Popen(\n ['python', 'train_dpo.py'],\n stdout=subprocess.PIPE,\n stderr=subprocess.STDOUT,\n text=True,\n bufsize=1\n)\n\n# Print output in real-time\nfor line in process.stdout:\n print(line, end='')\n\n# Wait for completion\nreturn_code = process.wait()\n\nprint(\"\\n\" + \"=\"*80)\nif return_code == 0:\n print(\"βœ“ Script completed successfully!\")\nelse:\n print(f\"βœ— Script failed with return code: {return_code}\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "kk03ij6wpx", + "source": "# Alternative: Run the training directly in the notebook for immediate feedback\nimport os\nos.chdir('/Users/akseljoonas/Documents/hf-agent')\n\n# Execute the script\nexec(open('train_dpo.py').read())", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "58ilnz6pedu", + "source": "# Write the file directly\nimport os\nos.chdir('/Users/akseljoonas/Documents/hf-agent')\n\nwith open('train_dpo.py', 'w', encoding='utf-8') as f:\n f.write('\"\"\"DPO Training Script - Complete Implementation\"\"\"\\n')\n f.write('import torch\\n')\n f.write('from datasets import load_dataset\\n')\n f.write('from transformers import AutoModelForCausalLM, AutoTokenizer\\n')\n f.write('from trl import DPOTrainer, DPOConfig\\n\\n')\n f.write('print(\"=\"*80)\\n')\n f.write('print(\"DPO Training - End-to-End Validation\")\\n')\n f.write('print(\"=\"*80)\\n\\n')\n f.write('# Configuration\\n')\n f.write('MODEL_NAME = \"Qwen/Qwen2-0.5B-Instruct\"\\n')\n f.write('DATASET_NAME = \"trl-lib/ultrafeedback_binarized\"\\n')\n 
f.write('OUTPUT_DIR = \"./dpo_output\"\\n')\n f.write('MAX_STEPS = 10\\n')\n f.write('BATCH_SIZE = 2\\n\\n')\n f.write('print(f\"\\\\n[CONFIG] Model: {MODEL_NAME}\")\\n')\n f.write('print(f\"[CONFIG] Dataset: {DATASET_NAME}\")\\n')\n f.write('print(f\"[CONFIG] Max steps: {MAX_STEPS}\")\\n')\n f.write('print(f\"[CONFIG] Batch size: {BATCH_SIZE}\")\\n\\n')\n f.write('# Step 1: Load tokenizer\\n')\n f.write('print(\"\\\\n[1/6] Loading tokenizer...\")\\n')\n f.write('tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\\n')\n f.write('if tokenizer.pad_token is None:\\n')\n f.write(' tokenizer.pad_token = tokenizer.eos_token\\n')\n f.write('print(f\"βœ“ Tokenizer loaded\")\\n\\n')\n f.write('# Step 2: Load dataset\\n')\n f.write('print(\"\\\\n[2/6] Loading dataset...\")\\n')\n f.write('dataset = load_dataset(DATASET_NAME, split=\"train[:100]\")\\n')\n f.write('print(f\"βœ“ Dataset loaded: {len(dataset)} samples\")\\n\\n')\n f.write('# Step 3: Load model\\n')\n f.write('print(\"\\\\n[3/6] Loading model...\")\\n')\n f.write('model = AutoModelForCausalLM.from_pretrained(\\n')\n f.write(' MODEL_NAME,\\n')\n f.write(' torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,\\n')\n f.write(' device_map=\"auto\",\\n')\n f.write(')\\n')\n f.write('print(f\"βœ“ Model loaded: {model.num_parameters()/1e6:.1f}M parameters\")\\n\\n')\n f.write('# Step 4: Configure training\\n')\n f.write('print(\"\\\\n[4/6] Configuring DPO training...\")\\n')\n f.write('training_args = DPOConfig(\\n')\n f.write(' output_dir=OUTPUT_DIR,\\n')\n f.write(' max_steps=MAX_STEPS,\\n')\n f.write(' per_device_train_batch_size=BATCH_SIZE,\\n')\n f.write(' learning_rate=5e-7,\\n')\n f.write(' logging_steps=2,\\n')\n f.write(' save_steps=10,\\n')\n f.write(' beta=0.1,\\n')\n f.write(' fp16=torch.cuda.is_available(),\\n')\n f.write(' remove_unused_columns=False,\\n')\n f.write(' report_to=\"none\",\\n')\n f.write(')\\n')\n f.write('print(\"βœ“ Configuration created\")\\n\\n')\n f.write('# Step 5: Train\\n')\n f.write('print(\"\\\\n[5/6] Starting DPO training...\")\\n')\n f.write('print(\"-\"*80)\\n')\n f.write('trainer = DPOTrainer(\\n')\n f.write(' model=model,\\n')\n f.write(' args=training_args,\\n')\n f.write(' train_dataset=dataset,\\n')\n f.write(' tokenizer=tokenizer,\\n')\n f.write(')\\n')\n f.write('train_result = trainer.train()\\n')\n f.write('print(\"-\"*80)\\n')\n f.write('print(f\"βœ“ Training completed! Loss: {train_result.training_loss:.4f}\")\\n\\n')\n f.write('# Step 6: Save\\n')\n f.write('print(\"\\\\n[6/6] Saving model...\")\\n')\n f.write('trainer.save_model(OUTPUT_DIR)\\n')\n f.write('print(f\"βœ“ Model saved to {OUTPUT_DIR}\")\\n\\n')\n f.write('print(\"\\\\n\" + \"=\"*80)\\n')\n f.write('print(\"DPO TRAINING COMPLETED SUCCESSFULLY!\")\\n')\n f.write('print(\"=\"*80)\\n')\n f.write('print(f\"\\\\nOutput: {OUTPUT_DIR}\")\\n')\n f.write('print(f\"Steps: {train_result.global_step}\")\\n')\n f.write('print(f\"Final loss: {train_result.training_loss:.4f}\")\\n')\n\nprint(\"βœ“ train_dpo.py created!\")\nprint(f\"File location: {os.path.abspath('train_dpo.py')}\")\n\n# Verify file exists\nif os.path.exists('train_dpo.py'):\n file_size = os.path.getsize('train_dpo.py')\n print(f\"File size: {file_size} bytes\")\nelse:\n print(\"ERROR: File was not created!\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "7qvebak22u2", + "source": "## Next Steps: Run the Training Script\n\nThe `train_dpo.py` file has been created. To run it:\n\n1. 
**From Terminal/Command Line:**\n ```bash\n cd /Users/akseljoonas/Documents/hf-agent\n python train_dpo.py\n ```\n\n2. **Or run directly in this notebook** by executing the next cell below.", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19f3dd6b", + "metadata": {}, + "outputs": [], + "source": [ + "# Seed examples for task bootstrapping\n", + "tasks_with_difficulty = {\n", + " # lewis\n", + " \"Evaluate models {M_i} on benchmarks {B_i}\": \"Easy\",\n", + " \"Train models {M_i} on datasets {D_i} with benchmarks {B_i}\": \"Medium\",\n", + " \"Run an ablation for hyperparameter P for model M on dataset D\": \"Hard\",\n", + " \"Generate completions with model M on dataset D using engine E\": \"Medium\",\n", + " \"Merge models {M_i} using linear averaging to find the best result on benchmarks {B_i}\": \"Hard\",\n", + " \"Given datasets {D_i}, ablate the best SFT mixture for model M across benchmarks {B_i}\": \"Very hard\",\n", + " \"Decontaminate dataset D against benchmarks {B_i}\": \"Hard\",\n", + " \"Benchmark RL framework F for best throughput on G GPUs\": \"Very hard\",\n", + " \"Implement post-training algorithm A from paper P in framework F. Validate it runs end-to-end\": \"Very hard\",\n", + " \"Implement benchmark B in framework F. Validate it reproduces some published results\": \"Very hard\",\n", + " \"Format dataset D for compatibility with framework F on task T\": \"Easy\",\n", + "\n", + " # abubakar\n", + " \"Remove the background from this image: [image path]\": \"Easy\",\n", + " \"Transcribe all of the audio files in this directory\": \"Easy\",\n", + " \"Transcribe all of the audio files in this directory, choose the model that'll be cheapest and also relatively accurate\": \"Medium (judgment call or interaction needed to figure out what accuracy levels are acceptable)\",\n", + " \"Remove the background music from this audio file\": \"Medium (needs to find Gradio Space and call its API)\",\n", + " \"Change this video track to be from English to Spanish\": \"Medium (needs to link several models together)\",\n", + " \"Translate this flyer from English to Spanish, keeping the layout and images the same\": \"Medium (needs to link several models together)\",\n", + "\n", + " # leandro\n", + " \"What's the best model for X?\": \"Easy\",\n", + " \"What datasets are available for X? (X={domain x task x modality})\": \"Easy\",\n", + " \"Is there a space to do Y?\": \"Easy\",\n", + " \"I have this script and this error - what's the issue?\": \"Medium\",\n", + " \"This space is broken, how can I fix it?\": \"Medium\",\n", + " \"I built a space but it is super slow. 
What can I do?\": \"Medium\",\n", + " \"How can I run modal X locally?\": \"Medium\",\n", + " \"I want to build a space with model Y to do X?\": \"Hard\",\n", + " \"How can I serve a model with multiple LoRAs?\": \"Hard\",\n", + "\n", + " # claude\n", + " \"What's the best model for sentiment analysis on financial text?\": \"Easy\",\n", + " \"Are there any medical image segmentation datasets on HuggingFace for CT scans?\": \"Easy\",\n", + " \"Which text classification models support 4-bit quantization?\": \"Medium\",\n", + " \"Are there inference endpoints available for Whisper large-v3?\": \"Easy\",\n", + " \"What's the license for the SA-Med2D-20M dataset?\": \"Easy\",\n", + " \"Which vision models fit in 8GB VRAM for image segmentation?\": \"Medium\",\n", + " \"What datasets are available for 3D medical image segmentation?\": \"Medium\",\n", + " \"Is there a space to do text-to-speech with emotion control?\": \"Medium\",\n", + " \"I'm getting \\\"CUDA out of memory\\\" when loading Llama-2-7b even though nvidia-smi shows I have 6GB free - what's the issue?\": \"Medium\",\n", + " \"My Gradio space shows \\\"Connection errored out\\\" after working fine yesterday, no code changes - how can I fix it?\": \"Medium\",\n", + " \"I built a Gradio space for Stable Diffusion but inference takes 5+ minutes on a 4090 - what can I do?\": \"Medium\",\n", + " \"My Whisper model outputs different transcriptions after quantization to int8 - why?\": \"Medium\",\n", + " \"Getting \\\"RuntimeError: CUDA error: out of memory. Tried to allocate 70.00 MiB\\\" but only 2.87 GiB is allocated - what's happening?\": \"Medium\",\n", + " \"My HuggingFace space build fails with \\\"failed to create containerd task\\\" - how to fix?\": \"Medium\",\n", + " \"DistilBERT model gives \\\"you should probably train your model\\\" warning even though it's a pretrained model from the Hub\": \"Easy\",\n", + " \"Space was working fine but now receiving build errors - receiving this error even with a new space\": \"Medium\",\n", + " \"Inference is correct locally but wrong on deployed space\": \"Medium\",\n", + " \"Getting CUDA OOM despite having enough memory according to nvidia-smi\": \"Medium\",\n", + " \"How can I run Mistral-7B-v0.1 locally with multiple LoRA adapters?\": \"Hard\",\n", + " \"How can I serve Llama-2-7b with vLLM and dynamically load multiple LoRA adapters?\": \"Hard\",\n", + " \"How do I batch inference requests in my Gradio space for better throughput?\": \"Medium\",\n", + " \"Can I run Whisper large-v3 with faster-whisper for 4x speedup?\": \"Medium\",\n", + " \"How to run Llama 2 on CPU after fine-tuning with LoRA?\": \"Medium\",\n", + " \"Best way to handle 50+ concurrent requests in a Gradio space without OOM?\": \"Hard\",\n", + " \"How do I add custom stopping criteria for text generation with Transformers?\": \"Hard\",\n", + " \"Can I merge multiple LoRA adapters before inference to reduce latency?\": \"Hard\",\n", + " \"How can I optimize my LLM inference with one base LLM and multiple LoRA adapters?\": \"Hard\",\n", + "}\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c7014bef", + "metadata": {}, + "outputs": [], + "source": [ + "len(tasks_with_difficulty)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3a8bd7ed", + "metadata": {}, + "outputs": [], + "source": [ + "import litellm\n", + "import json\n", + "from pydantic import BaseModel\n", + "from enum import Enum\n", + "\n", + "\n", + "class Difficulty(str, Enum):\n", + " EASY = \"Easy\"\n", + " 
MEDIUM = \"Medium\"\n", + " HARD = \"Hard\"\n", + " VERY_HARD = \"Very hard\"\n", + "\n", + "\n", + "class Task(BaseModel):\n", + " description: str\n", + " difficulty: Difficulty\n", + "\n", + "\n", + "class GeneratedTasks(BaseModel):\n", + " tasks: list[Task]\n", + "\n", + "\n", + "def build_prompt(tasks_dict: dict[str, str]) -> str:\n", + " task_descriptions = \"\".join(\n", + " [f'- \"{task}\" [{difficulty}]\\n' for task, difficulty in tasks_dict.items()]\n", + " )\n", + "\n", + " return f\"\"\"Given the following examples of tasks (with their estimated difficulty levels in brackets):\n", + "\n", + "{task_descriptions}\n", + "\n", + "Generate exactly 10 new unique tasks with their difficulty levels (Easy, Medium, Hard, or Very hard).\n", + "The new tasks should be bootstrapped by analogy or creative mutation of the provided ones, but not be direct copies.\n", + "Vary the domains, instructions, and scenario details. Write crisp, concrete task phrasing. Preserve variety in both tasks and difficulties.\n", + "Do not repeat any of the input tasks verbatim. Create plausible, meaningful tasks relevant to LLM training, evaluation, dataprocessing, issue handling, tooling, etc.\n", + "\"\"\"\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "85ef3dcb", + "metadata": {}, + "outputs": [], + "source": [ + "model_name = \"gpt-5\"\n", + "\n", + "# Number of iterations to generate tasks (10 tasks per iteration)\n", + "num_iterations = 20\n", + "\n", + "# Copy the seed tasks to avoid modifying the original\n", + "all_tasks = tasks_with_difficulty.copy()\n", + "\n", + "for i in range(num_iterations):\n", + " prompt = build_prompt(all_tasks)\n", + "\n", + " # Query LLM using litellm with structured output\n", + " response = litellm.completion(\n", + " model=model_name,\n", + " messages=[\n", + " {\n", + " \"role\": \"system\",\n", + " \"content\": \"You are an expert at generating diverse ML/AI task instructions using products from HuggingFace and can enumerate them with proper difficulty.\",\n", + " },\n", + " {\"role\": \"user\", \"content\": prompt},\n", + " ],\n", + " response_format=GeneratedTasks,\n", + " )\n", + "\n", + " # Parse the structured output\n", + " generated = GeneratedTasks.model_validate_json(\n", + " response.choices[0].message.content\n", + " )\n", + "\n", + " # Add new tasks to the dictionary\n", + " new_count = 0\n", + " for task in generated.tasks:\n", + " if task.description not in all_tasks:\n", + " all_tasks[task.description] = task.difficulty.value\n", + " new_count += 1\n", + "\n", + " print(f\"Iteration {i + 1}/{num_iterations}: Added {new_count} new tasks. 
Total: {len(all_tasks)}\")\n", + "\n", + "# Save to disk\n", + "with open(\"generated_tasks_with_difficulty.json\", \"w\") as f:\n", + " json.dump(all_tasks, f, indent=2)\n", + "\n", + "print(f\"\\nFinal task count: {len(all_tasks)}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9c0ad570", + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import Dataset\n", + "\n", + "# Convert dict to proper columns\n", + "questions = list(all_tasks.keys())\n", + "difficulties = list(all_tasks.values())\n", + "data = {\"question\": questions, \"difficulty\": difficulties}\n", + "\n", + "dataset = Dataset.from_dict(data)\n", + "print(f\"\\nDataset: {len(dataset)} rows\")\n", + "print(f\"Sample: {dataset[0]['question']} ({dataset[0]['difficulty']})\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "427a2186", + "metadata": {}, + "outputs": [], + "source": [ + "dataset.push_to_hub(\"akseljoonas/benchmark-tasks\", private=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "204b9760", + "metadata": {}, + "outputs": [], + "source": [ + "all_tasks = json.load(open(\"generated_tasks_with_difficulty.json\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "50e67652", + "metadata": {}, + "outputs": [], + "source": [ + "# Extract variables from each question using LLM\n", + "\n", + "class ExtractedVariables(BaseModel):\n", + " variables: list[str] # List of variable names/placeholders found in the question\n", + "\n", + "\n", + "def extract_variables_prompt(question: str) -> str:\n", + " return f\"\"\"Analyze this task description and list any variables or placeholders that would need to be filled in with specific values. This is an AI/ML/LLM task, so the variables are typically model names, dataset names, hyperparameter names, etc.\n", + "\n", + "Task: \"{question}\"\n", + "\n", + "Variables are typically indicated by:\n", + "- Curly braces like {{M_i}}, {{D_i}}, {{B_i}}\n", + "- Single letters representing placeholders like \"model M\", \"dataset D\", \"hyperparameter P\"\n", + "- Bracketed placeholders like [image path]\n", + "- Generic references like \"X\", \"Y\" that stand for specific values\n", + "\n", + "Examples of tasks with variables:\n", + "\n", + " \"Evaluate models {{M_i}} on benchmarks {{B_i}}\" -> variables: [\"M_i\", \"B_i\"]\n", + " \"Train models {{M_i}} on datasets {{D_i}} with benchmarks {{B_i}}\" -> variables: [\"M_i\", \"D_i\", \"B_i\"]\n", + " \"Run an ablation for hyperparameter P for model M on dataset D\" -> variables: [\"P\", \"M\", \"D\"]\n", + " \"Generate completions with model M on dataset D using engine E\" -> variables: [\"M\", \"D\", \"E\"]\n", + " \"Merge models {{M_i}} using linear averaging to find the best result on benchmarks {{B_i}}\" -> variables: [\"M_i\", \"B_i\"]\n", + " \"Given datasets {{D_i}}, ablate the best SFT mixture for model M across benchmarks {{B_i}}\" -> variables: [\"D_i\", \"M\", \"B_i\"]\n", + " \"Decontaminate dataset D against benchmarks {{B_i}}\" -> variables: [\"D\", \"B_i\"]\n", + " \"Benchmark RL framework F for best throughput on G GPUs\" -> variables: [\"F\", \"G\"]\n", + " \"Implement post-training algorithm A from paper P in framework F. Validate it runs end-to-end\" -> variables: [\"A\", \"P\", \"F\"]\n", + " \"Implement benchmark B in framework F. 
Validate it reproduces some published results\" -> variables: [\"B\", \"F\"]\n", + " \"Format dataset D for compatibility with framework F on task T\" -> variables: [\"D\", \"F\", \"T\"]\n", + " \"Remove the background from this image: [image path]\" -> variables: [\"[image path]\"]\n", + " \"Are there any medical image segmentation datasets on HuggingFace for CT scans?\" -> variables: []\n", + " \"Build a sharded FAISS IVF-PQ index for 100M embeddings stored on S3; integrate with HF datasets streaming and report recall@10 and QPS\" -> variables: []\n", + "\n", + "\n", + "Return an empty list if the question is fully concrete with no variables.\n", + "Only return the variable names/symbols, not their descriptions.\"\"\"\n", + "\n", + "\n", + "# Run extraction for each question in parallel\n", + "from concurrent.futures import ThreadPoolExecutor, as_completed\n", + "\n", + "variable_model = \"gpt-5-mini\"\n", + "\n", + "\n", + "def extract_variables_for_task(question: str, difficulty: str) -> dict:\n", + " \"\"\"Extract variables for a single task and return the record.\"\"\"\n", + " response = litellm.completion(\n", + " model=variable_model,\n", + " messages=[\n", + " {\n", + " \"role\": \"system\",\n", + " \"content\": \"You are an expert at identifying placeholder variables in task descriptions.\",\n", + " },\n", + " {\"role\": \"user\", \"content\": extract_variables_prompt(question)},\n", + " ],\n", + " response_format=ExtractedVariables,\n", + " )\n", + "\n", + " extracted = ExtractedVariables.model_validate_json(\n", + " response.choices[0].message.content\n", + " )\n", + "\n", + " return {\n", + " \"question\": question,\n", + " \"difficulty\": difficulty,\n", + " \"var_list\": extracted.variables,\n", + " }\n", + "\n", + "\n", + "# Run in parallel with 100 workers\n", + "tasks_with_metadata: list[dict] = []\n", + "all_variables: set[str] = set()\n", + "questions_with_vars: dict[str, list[str]] = {}\n", + "\n", + "with ThreadPoolExecutor(max_workers=100) as executor:\n", + " futures = {\n", + " executor.submit(extract_variables_for_task, q, d): q\n", + " for q, d in all_tasks.items()\n", + " }\n", + "\n", + " for future in as_completed(futures):\n", + " record = future.result()\n", + " tasks_with_metadata.append(record)\n", + "\n", + " if record[\"var_list\"]:\n", + " questions_with_vars[record[\"question\"]] = record[\"var_list\"]\n", + " all_variables.update(record[\"var_list\"])\n", + "\n", + " print(f\"Processed {len(tasks_with_metadata)} tasks\")\n", + "\n", + "# Save to JSONL\n", + "with open(\"tasks_with_variables.jsonl\", \"w\") as f:\n", + " for record in tasks_with_metadata:\n", + " f.write(json.dumps(record) + \"\\n\")\n", + "\n", + "print(f\"Saved {len(tasks_with_metadata)} tasks to tasks_with_variables.jsonl\")\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "548f1bf0", + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"Questions with variables: {len(questions_with_vars)} / {len(all_tasks)}\")\n", + "print(f\"\\nUnique variables found ({len(all_variables)}):\")\n", + "for var in sorted(all_variables):\n", + " print(f\" - {var}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "3cef6645", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded 250 tasks\n", + "Questions with variables: 111 / 250\n", + "\n", + "Unique variables found (29):\n", + " - A\n", + " - A_i\n", + " - B\n", + " - B_i\n", + " - C\n", + " - D\n", + " - D_i\n", + " - E\n", + " - 
F\n", + " - G\n", + " - M\n", + " - M0\n", + " - M_i\n", + " - N\n", + " - P\n", + " - R\n", + " - R_i\n", + " - S\n", + " - T\n", + " - T_i\n", + " - X\n", + " - Y\n", + " - [audio file]\n", + " - [directory]\n", + " - [image path]\n", + " - baseline\n", + " - domain\n", + " - modality\n", + " - task\n" + ] + } + ], + "source": [ + "# Load verified tasks and print all variables\n", + "with open(\"tasks_with_variables.jsonl\", \"r\") as f:\n", + " verified_tasks = [json.loads(line) for line in f]\n", + "\n", + "all_variables = set()\n", + "questions_with_vars = {}\n", + "\n", + "for task in verified_tasks:\n", + " if task[\"var_list\"]:\n", + " questions_with_vars[task[\"question\"]] = task[\"var_list\"]\n", + " all_variables.update(task[\"var_list\"])\n", + "\n", + "print(f\"Loaded {len(verified_tasks)} tasks\")\n", + "print(f\"Questions with variables: {len(questions_with_vars)} / {len(verified_tasks)}\")\n", + "print(f\"\\nUnique variables found ({len(all_variables)}):\")\n", + "for var in sorted(all_variables):\n", + " print(f\" - {var}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ca774044", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Filling variables: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 250/250 [21:21<00:00, 5.13s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Saved 250 tasks to filled_tasks.jsonl\n", + "Tasks that had variables filled: 111\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "import asyncio\n", + "import os\n", + "from claude_agent_sdk import (\n", + " query,\n", + " ClaudeAgentOptions,\n", + " AssistantMessage,\n", + " ResultMessage,\n", + " TextBlock,\n", + ")\n", + "\n", + "\n", + "def build_fill_prompt(task: dict) -> str:\n", + " vars_str = \", \".join(task[\"var_list\"])\n", + " return f\"\"\"You have access to HuggingFace tools via MCP. Use them to find real, concrete values to fill in the variables in this task.\n", + "\n", + "Task template: \"{task[\"question\"]}\"\n", + "Variables to fill: {vars_str}\n", + "\n", + "Search HuggingFace for real models, datasets, benchmarks, frameworks, etc. that would make this task concrete and executable.\n", + "Pick the most popular, well-known resources (models etc) when possible.\n", + "\n", + "Return ONLY the filled question in the end with variables replaced by concrete values. 
No JSON, no explanation, just the filled question.\n", + "\n", + "Example:\n", + "Task: \"Evaluate models {{M_i}} on benchmarks {{B_i}}\"\n", + "Variables: M_i, B_i\n", + "Response: Evaluate models Qwen/Qwen3-4B-Instruct-2507, mistralai/Devstral-Small-2-24B-Instruct-2512 on benchmarks hellaswag, google/frames-benchmark\n", + "\"\"\"\n", + "\n", + "\n", + "# Semaphore to limit concurrent processes\n", + "MAX_CONCURRENT = 5\n", + "semaphore = asyncio.Semaphore(MAX_CONCURRENT)\n", + "\n", + "\n", + "async def fill_task_variables(task: dict) -> dict:\n", + " \"\"\"Use Claude Agent SDK to fill in variables for a single task.\"\"\"\n", + " if not task[\"var_list\"]:\n", + " return task.copy()\n", + "\n", + " async with semaphore:\n", + " prompt = build_fill_prompt(task)\n", + " filled_question = None\n", + " all_messages = []\n", + "\n", + " async for message in query(\n", + " prompt=prompt,\n", + " options=ClaudeAgentOptions(\n", + " cwd=os.getcwd(),\n", + " permission_mode=\"bypassPermissions\",\n", + " disallowed_tools=[\n", + " \"Write\", \"Edit\", \"Bash\", \"Glob\", \"Grep\"\n", + " \n", + " ],\n", + " ),\n", + " ):\n", + " all_messages.append(message)\n", + "\n", + " # Extract text from assistant messages\n", + " if isinstance(message, AssistantMessage):\n", + " for block in message.content:\n", + " if isinstance(block, TextBlock):\n", + " filled_question = block.text\n", + " # Check for result messages\n", + " elif isinstance(message, ResultMessage):\n", + " if message.is_error:\n", + " print(\"\\n\" + \"=\" * 80)\n", + " print(f\"ERROR for task: {task['question']}\")\n", + " print(f\"Error subtype: {message.subtype}\")\n", + " print(\"\\nFull messages:\")\n", + " for msg in all_messages:\n", + " print(f\" {msg}\")\n", + " print(\"=\" * 80)\n", + " raise RuntimeError(f\"Agent error: {message.subtype}\")\n", + " elif message.result:\n", + " filled_question = message.result\n", + "\n", + " # Use filled question or fall back to original\n", + " if filled_question:\n", + " filled_question = filled_question.strip()\n", + " else:\n", + " filled_question = task[\"question\"]\n", + "\n", + " return {\n", + " \"question\": filled_question,\n", + " \"difficulty\": task[\"difficulty\"],\n", + " \"var_list\": task[\"var_list\"],\n", + " }\n", + "\n", + "\n", + "# Run all tasks in parallel with tqdm progress\n", + "from tqdm.asyncio import tqdm_asyncio\n", + "\n", + "\n", + "async def fill_all_tasks_parallel(tasks: list[dict]) -> list[dict]:\n", + " \"\"\"Fill all tasks with limited concurrency and progress bar.\"\"\"\n", + " coros = [fill_task_variables(t) for t in tasks]\n", + " return await tqdm_asyncio.gather(*coros, desc=\"Filling variables\")\n", + "\n", + "\n", + "# Process all tasks (with and without variables)\n", + "filled_tasks = await fill_all_tasks_parallel(verified_tasks)\n", + "\n", + "# Save to JSONL (same structure: question, difficulty, var_list)\n", + "with open(\"filled_tasks.jsonl\", \"w\") as f:\n", + " for task in filled_tasks:\n", + " f.write(json.dumps(task) + \"\\n\")\n", + "\n", + "tasks_with_vars_count = sum(1 for t in verified_tasks if t[\"var_list\"])\n", + "print(f\"Saved {len(filled_tasks)} tasks to filled_tasks.jsonl\")\n", + "print(f\"Tasks that had variables filled: {tasks_with_vars_count}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "44c4e671", + "metadata": {}, + "outputs": [], + "source": "from pathlib import Path\n\nfuse_lora_content = r'''#!/usr/bin/env python3\n\"\"\"\nLoRA Fusion and Verification Script\n\nThis script:\n1. 
Loads a base model (Llama-2-7b-hf) and LoRA adapter (alpaca-lora-7b)\n2. Merges/fuses the LoRA weights into the base model\n3. Exports the fused model as safetensors format\n4. Verifies logits parity between on-the-fly LoRA and fused model\n5. Reports detailed metrics (MSE, max absolute difference, relative error)\n\"\"\"\n\nimport os\nimport torch\nimport numpy as np\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\nfrom peft import PeftModel\nimport gc\n\n\ndef print_section(title):\n    \"\"\"Print a formatted section header\"\"\"\n    print(\"\\n\" + \"=\"*80)\n    print(f\" {title}\")\n    print(\"=\"*80 + \"\\n\")\n\n\ndef free_memory():\n    \"\"\"Free up GPU memory\"\"\"\n    gc.collect()\n    torch.cuda.empty_cache()\n\n\ndef load_models(base_model_name, lora_adapter_name):\n    \"\"\"\n    Load base model and LoRA adapter model\n    \n    Args:\n        base_model_name: HuggingFace model ID for base model\n        lora_adapter_name: HuggingFace model ID for LoRA adapter\n    \n    Returns:\n        tuple: (lora_model, tokenizer)\n    \"\"\"\n    print_section(\"Loading Base Model and LoRA Adapter\")\n    \n    print(f\"Loading base model: {base_model_name}\")\n    print(\"Using torch.float16 for memory efficiency...\")\n    \n    base_model = AutoModelForCausalLM.from_pretrained(\n        base_model_name,\n        torch_dtype=torch.float16,\n        device_map=\"auto\",\n        trust_remote_code=True\n    )\n    \n    print(f\"Base model loaded successfully\")\n    print(f\" - Model type: {type(base_model).__name__}\")\n    print(f\" - Device map: {base_model.hf_device_map}\")\n    \n    print(f\"\\nLoading LoRA adapter: {lora_adapter_name}\")\n    \n    lora_model = PeftModel.from_pretrained(\n        base_model,\n        lora_adapter_name,\n        torch_dtype=torch.float16,\n    )\n    \n    print(f\"LoRA adapter loaded successfully\")\n    print(f\" - Adapter type: {type(lora_model).__name__}\")\n    \n    print(f\"\\nLoading tokenizer from: {base_model_name}\")\n    tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)\n    \n    # Set pad token if not present\n    if tokenizer.pad_token is None:\n        tokenizer.pad_token = tokenizer.eos_token\n        print(\" - Set pad_token to eos_token\")\n    \n    print(f\"Tokenizer loaded successfully\")\n    \n    return lora_model, tokenizer\n\n\ndef merge_and_export(lora_model, tokenizer, output_dir):\n    \"\"\"\n    Merge LoRA weights into base model and export as safetensors\n    \n    Args:\n        lora_model: PEFT model with LoRA adapter\n        tokenizer: Tokenizer to save alongside the merged model\n        output_dir: Directory to save the fused model\n    \n    Returns:\n        merged_model: The fused model\n    \"\"\"\n    print_section(\"Merging LoRA Weights into Base Model\")\n    \n    print(\"Calling merge_and_unload()...\")\n    merged_model = lora_model.merge_and_unload()\n    \n    print(\"LoRA weights successfully merged into base model\")\n    print(f\" - Merged model type: {type(merged_model).__name__}\")\n    \n    print(f\"\\nExporting fused model to: {output_dir}\")\n    print(\"Format: safetensors (safe_serialization=True)\")\n    \n    # Create output directory if it doesn't exist\n    os.makedirs(output_dir, exist_ok=True)\n    \n    # Save the merged model\n    merged_model.save_pretrained(\n        output_dir,\n        safe_serialization=True,\n        max_shard_size=\"5GB\"\n    )\n    \n    print(f\"Model successfully saved to {output_dir}\")\n    \n    # Save the tokenizer alongside the merged weights so the output directory\n    # is directly loadable with from_pretrained(). PeftModel has no .tokenizer\n    # attribute, so the tokenizer must be passed in explicitly.\n    tokenizer.save_pretrained(output_dir)\n    print(f\"Tokenizer also saved to {output_dir}\")\n    \n    return merged_model\n\n\ndef generate_logits(model, tokenizer, prompt, max_length=50):\n    \"\"\"\n    Generate logits for a given prompt\n    \n    Args:\n        model: The model to use for generation\n        
tokenizer: Tokenizer for encoding the prompt\n prompt: Text prompt\n max_length: Maximum sequence length\n \n Returns:\n torch.Tensor: Logits from the model\n \"\"\"\n # Tokenize input\n inputs = tokenizer(prompt, return_tensors=\"pt\", padding=True, truncation=True, max_length=max_length)\n \n # Move inputs to the same device as model\n device = next(model.parameters()).device\n inputs = {k: v.to(device) for k, v in inputs.items()}\n \n # Generate logits\n with torch.no_grad():\n outputs = model(**inputs)\n logits = outputs.logits\n \n return logits\n\n\ndef calculate_metrics(logits1, logits2):\n \"\"\"\n Calculate metrics between two sets of logits\n \n Args:\n logits1: First set of logits\n logits2: Second set of logits\n \n Returns:\n dict: Dictionary containing various metrics\n \"\"\"\n # Convert to numpy for easier computation\n logits1_np = logits1.cpu().float().numpy()\n logits2_np = logits2.cpu().float().numpy()\n \n # Calculate metrics\n mse = np.mean((logits1_np - logits2_np) ** 2)\n mae = np.mean(np.abs(logits1_np - logits2_np))\n max_abs_diff = np.max(np.abs(logits1_np - logits2_np))\n \n # Relative error (avoid division by zero)\n epsilon = 1e-8\n relative_error = np.mean(np.abs(logits1_np - logits2_np) / (np.abs(logits1_np) + epsilon))\n \n # Cosine similarity (flatten the tensors)\n flat1 = logits1_np.flatten()\n flat2 = logits2_np.flatten()\n cosine_sim = np.dot(flat1, flat2) / (np.linalg.norm(flat1) * np.linalg.norm(flat2))\n \n return {\n 'mse': mse,\n 'mae': mae,\n 'max_abs_diff': max_abs_diff,\n 'relative_error': relative_error,\n 'cosine_similarity': cosine_sim\n }\n\n\ndef verify_logits_parity(lora_model, fused_model, tokenizer, test_prompts):\n \"\"\"\n Verify that logits from LoRA model match fused model\n \n Args:\n lora_model: Model with LoRA adapter applied on-the-fly\n fused_model: Model with merged LoRA weights\n tokenizer: Tokenizer for encoding prompts\n test_prompts: List of test prompts\n \n Returns:\n bool: True if all tests pass (MSE < 1e-5)\n \"\"\"\n print_section(\"Verifying Logits Parity\")\n \n all_passed = True\n results = []\n \n for i, prompt in enumerate(test_prompts, 1):\n print(f\"\\nTest {i}/{len(test_prompts)}\")\n print(f\"Prompt: {prompt[:100]}...\" if len(prompt) > 100 else f\"Prompt: {prompt}\")\n print(\"-\" * 80)\n \n # Generate logits from both models\n print(\"Generating logits from LoRA model (on-the-fly)...\")\n lora_logits = generate_logits(lora_model, tokenizer, prompt)\n \n print(\"Generating logits from fused model...\")\n fused_logits = generate_logits(fused_model, tokenizer, prompt)\n \n # Calculate metrics\n metrics = calculate_metrics(lora_logits, fused_logits)\n results.append(metrics)\n \n # Print results\n print(\"\\nMetrics:\")\n print(f\" MSE (Mean Squared Error): {metrics['mse']:.2e}\")\n print(f\" MAE (Mean Absolute Error): {metrics['mae']:.2e}\")\n print(f\" Max Absolute Difference: {metrics['max_abs_diff']:.2e}\")\n print(f\" Relative Error: {metrics['relative_error']:.2e}\")\n print(f\" Cosine Similarity: {metrics['cosine_similarity']:.6f}\")\n \n # Check if MSE is below threshold\n threshold = 1e-5\n passed = metrics['mse'] < threshold\n \n status = \"PASS\" if passed else \"FAIL\"\n print(f\"\\nStatus: {status} (MSE < {threshold}: {metrics['mse']:.2e} < {threshold})\")\n \n if not passed:\n all_passed = False\n \n # Print summary\n print_section(\"Summary\")\n \n avg_mse = np.mean([r['mse'] for r in results])\n avg_mae = np.mean([r['mae'] for r in results])\n max_abs_diff_overall = np.max([r['max_abs_diff'] for 
r in results])\n    avg_relative_error = np.mean([r['relative_error'] for r in results])\n    avg_cosine_sim = np.mean([r['cosine_similarity'] for r in results])\n    \n    print(f\"Tests run: {len(test_prompts)}\")\n    print(f\"\\nAverage Metrics Across All Tests:\")\n    print(f\"  Average MSE: {avg_mse:.2e}\")\n    print(f\"  Average MAE: {avg_mae:.2e}\")\n    print(f\"  Maximum Absolute Difference: {max_abs_diff_overall:.2e}\")\n    print(f\"  Average Relative Error: {avg_relative_error:.2e}\")\n    print(f\"  Average Cosine Similarity: {avg_cosine_sim:.6f}\")\n    \n    print(f\"\\nOverall Result: {'ALL TESTS PASSED' if all_passed else 'SOME TESTS FAILED'}\")\n    \n    return all_passed\n\n\ndef format_alpaca_prompt(instruction, input_text=\"\"):\n    \"\"\"\n    Format prompt in Alpaca instruction format\n    \n    Args:\n        instruction: The instruction text\n        input_text: Optional input context\n    \n    Returns:\n        str: Formatted prompt\n    \"\"\"\n    if input_text:\n        return f\"\"\"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n\"\"\"\n    else:\n        return f\"\"\"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n\"\"\"\n\n\ndef main():\n    \"\"\"Main execution function\"\"\"\n    print_section(\"LoRA Fusion and Verification Pipeline\")\n    \n    # Configuration\n    base_model_name = \"meta-llama/Llama-2-7b-hf\"\n    lora_adapter_name = \"tloen/alpaca-lora-7b\"\n    output_dir = \"./alpaca-llama2-7b-fused\"\n    \n    print(\"Configuration:\")\n    print(f\"  Base Model: {base_model_name}\")\n    print(f\"  LoRA Adapter: {lora_adapter_name}\")\n    print(f\"  Output Directory: {output_dir}\")\n    print(f\"  Device: {'cuda' if torch.cuda.is_available() else 'cpu'}\")\n    print(f\"  PyTorch Version: {torch.__version__}\")\n    \n    # Step 1: Load models\n    lora_model, tokenizer = load_models(base_model_name, lora_adapter_name)\n    \n    # Step 2: Merge and export\n    fused_model = merge_and_export(lora_model, tokenizer, output_dir)\n    \n    # merge_and_unload() merges the adapters into the wrapped base model in\n    # place, so at this point `lora_model` forwards through the already-merged\n    # weights and the parity check below would be vacuous. Reload the adapter\n    # model to get a genuine on-the-fly LoRA reference (at the cost of holding\n    # a second model in memory).\n    free_memory()\n    lora_model, tokenizer = load_models(base_model_name, lora_adapter_name)\n    \n    # Step 3: Prepare test prompts\n    test_prompts = [\n        # Test 1: Simple Alpaca instruction\n        format_alpaca_prompt(\"Tell me about alpacas.\"),\n        \n        # Test 2: Alpaca instruction with input\n        format_alpaca_prompt(\n            \"Summarize the following text.\",\n            \"Alpacas are domesticated South American camelids. 
They are raised for their soft fleece and are known for their gentle temperament.\"\n ),\n \n # Test 3: Complex instruction\n format_alpaca_prompt(\"Write a Python function that calculates the fibonacci sequence.\"),\n \n # Test 4: Simple question (non-Alpaca format for variety)\n \"What is the capital of France?\",\n \n # Test 5: Code generation\n format_alpaca_prompt(\"Explain what machine learning is in simple terms.\")\n ]\n \n print(f\"\\nPrepared {len(test_prompts)} test prompts\")\n \n # Step 4: Verify logits parity\n all_passed = verify_logits_parity(lora_model, fused_model, tokenizer, test_prompts)\n \n # Final summary\n print_section(\"Pipeline Complete\")\n \n print(f\"Fused model saved to: {os.path.abspath(output_dir)}\")\n print(f\"Format: safetensors\")\n print(f\"Verification: {'SUCCESS - All tests passed' if all_passed else 'FAILED - Some tests did not pass'}\")\n \n if all_passed:\n print(\"\\nThe fused model produces identical logits to the on-the-fly LoRA application.\")\n print(\"You can safely use the fused model as a drop-in replacement.\")\n else:\n print(\"\\nWARNING: The fused model does not produce identical logits.\")\n print(\"Please review the metrics above to understand the discrepancies.\")\n \n return 0 if all_passed else 1\n\n\nif __name__ == \"__main__\":\n import sys\n exit_code = main()\n sys.exit(exit_code)\n'''\n\n# Write to /tmp/fuse_lora.py\nPath('/tmp/fuse_lora.py').write_text(fuse_lora_content)\nprint(\"βœ“ Successfully created /tmp/fuse_lora.py\")\n" + }, + { + "cell_type": "code", + "id": "lm4uok5rtr", + "source": "from pathlib import Path\n\nfilter_toxic_content = r'''#!/usr/bin/env python3\n\"\"\"\nFilter Toxic Dataset Script\n\nThis script:\n1. Loads the lmsys/toxic-chat dataset (toxicchat0124 version)\n2. Loads the unitary/toxic-bert classifier model\n3. Runs inference on all examples to classify toxicity\n4. Logs detailed per-label removal statistics\n5. Filters out toxic content (using 0.5 threshold)\n6. Creates stratified train/validation/test splits (70/15/15)\n7. 
Saves the filtered dataset and generates a comprehensive JSON report\n\"\"\"\n\nimport json\nimport logging\nfrom collections import defaultdict\nfrom datetime import datetime\nfrom pathlib import Path\nfrom typing import Dict, List, Tuple\n\nimport numpy as np\nimport torch\nfrom datasets import Dataset, DatasetDict, load_dataset\nfrom sklearn.model_selection import train_test_split\nfrom tqdm import tqdm\nfrom transformers import AutoModelForSequenceClassification, AutoTokenizer\n\n# Configure logging\nlogging.basicConfig(\n level=logging.INFO,\n format=\"%(asctime)s - %(levelname)s - %(message)s\",\n handlers=[\n logging.FileHandler(\"filter_toxic_dataset.log\"),\n logging.StreamHandler()\n ]\n)\nlogger = logging.getLogger(__name__)\n\n# Toxic-BERT label indices\nTOXIC_LABELS = {\n 0: \"toxic\",\n 1: \"severe_toxic\",\n 2: \"obscene\",\n 3: \"threat\",\n 4: \"insult\",\n 5: \"identity_hate\"\n}\n\nclass ToxicityFilter:\n \"\"\"Main class for filtering toxic content from datasets.\"\"\"\n \n def __init__(\n self,\n model_name: str = \"unitary/toxic-bert\",\n threshold: float = 0.5,\n batch_size: int = 32,\n device: str = None\n ):\n \"\"\"Initialize the toxicity filter.\"\"\"\n self.model_name = model_name\n self.threshold = threshold\n self.batch_size = batch_size\n self.device = device or (\"cuda\" if torch.cuda.is_available() else \"cpu\")\n \n logger.info(f\"Initializing ToxicityFilter with model: {model_name}\")\n logger.info(f\"Device: {self.device}, Batch size: {batch_size}, Threshold: {threshold}\")\n \n # Load model and tokenizer\n self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n self.model = AutoModelForSequenceClassification.from_pretrained(model_name)\n self.model.to(self.device)\n self.model.eval()\n \n # Statistics tracking\n self.stats = {\n \"total_examples\": 0,\n \"filtered_examples\": 0,\n \"kept_examples\": 0,\n \"label_stats\": {label: {\"count\": 0, \"removed\": 0} for label in TOXIC_LABELS.values()},\n \"threshold\": threshold,\n \"model\": model_name,\n \"device\": self.device\n }\n \n logger.info(\"Model loaded successfully\")\n \n def classify_batch(self, texts: List[str]) -> Tuple[np.ndarray, np.ndarray]:\n \"\"\"Classify a batch of texts for toxicity.\"\"\"\n # Tokenize\n inputs = self.tokenizer(\n texts,\n padding=True,\n truncation=True,\n max_length=512,\n return_tensors=\"pt\"\n )\n inputs = {k: v.to(self.device) for k, v in inputs.items()}\n \n # Inference\n with torch.no_grad():\n outputs = self.model(**inputs)\n probabilities = torch.sigmoid(outputs.logits).cpu().numpy()\n \n # Determine if any label exceeds threshold\n predictions = (probabilities > self.threshold).any(axis=1)\n \n return predictions, probabilities\n \n def process_dataset(\n self,\n dataset: Dataset,\n text_column: str = \"user_input\"\n ) -> Tuple[Dataset, Dataset, Dict]:\n \"\"\"Process dataset and filter toxic content.\"\"\"\n logger.info(f\"Processing dataset with {len(dataset)} examples\")\n \n self.stats[\"total_examples\"] = len(dataset)\n \n # Storage for results\n all_predictions = []\n all_probabilities = []\n \n # Process in batches with progress bar\n num_batches = (len(dataset) + self.batch_size - 1) // self.batch_size\n \n for i in tqdm(range(0, len(dataset), self.batch_size), \n desc=\"Classifying toxicity\", \n total=num_batches):\n batch_texts = dataset[text_column][i:i + self.batch_size]\n predictions, probabilities = self.classify_batch(batch_texts)\n \n all_predictions.extend(predictions)\n all_probabilities.extend(probabilities)\n \n # Convert to numpy 
arrays\n all_predictions = np.array(all_predictions)\n all_probabilities = np.array(all_probabilities)\n \n # Calculate per-label statistics\n for label_idx, label_name in TOXIC_LABELS.items():\n label_probs = all_probabilities[:, label_idx]\n toxic_for_label = label_probs > self.threshold\n \n self.stats[\"label_stats\"][label_name][\"count\"] = int(toxic_for_label.sum())\n self.stats[\"label_stats\"][label_name][\"removal_rate\"] = float(\n toxic_for_label.sum() / len(dataset)\n )\n \n logger.info(\n f\"Label '{label_name}': {toxic_for_label.sum()} examples \"\n f\"({toxic_for_label.sum() / len(dataset) * 100:.2f}%) exceed threshold\"\n )\n \n # Add predictions and probabilities to dataset\n dataset_with_scores = dataset.add_column(\"is_toxic\", all_predictions.tolist())\n \n # Add individual label probabilities\n for label_idx, label_name in TOXIC_LABELS.items():\n dataset_with_scores = dataset_with_scores.add_column(\n f\"prob_{label_name}\",\n all_probabilities[:, label_idx].tolist()\n )\n \n # Split into filtered (clean) and toxic datasets\n filtered_dataset = dataset_with_scores.filter(lambda x: not x[\"is_toxic\"])\n toxic_dataset = dataset_with_scores.filter(lambda x: x[\"is_toxic\"])\n \n self.stats[\"filtered_examples\"] = len(toxic_dataset)\n self.stats[\"kept_examples\"] = len(filtered_dataset)\n self.stats[\"filter_rate\"] = self.stats[\"filtered_examples\"] / self.stats[\"total_examples\"]\n \n logger.info(f\"Filtered {len(toxic_dataset)} toxic examples ({self.stats['filter_rate']*100:.2f}%)\")\n logger.info(f\"Kept {len(filtered_dataset)} clean examples\")\n \n return filtered_dataset, toxic_dataset, self.stats\n \n def create_stratified_splits(\n self,\n dataset: Dataset,\n train_size: float = 0.7,\n val_size: float = 0.15,\n test_size: float = 0.15,\n stratify_column: str = None,\n random_state: int = 42\n ) -> DatasetDict:\n \"\"\"Create stratified train/validation/test splits.\"\"\"\n assert abs(train_size + val_size + test_size - 1.0) < 1e-6, \"Split sizes must sum to 1.0\"\n \n logger.info(f\"Creating stratified splits: train={train_size}, val={val_size}, test={test_size}\")\n \n # Convert to pandas for sklearn\n df = dataset.to_pandas()\n \n # Prepare stratification column if specified\n stratify = None\n if stratify_column and stratify_column in df.columns:\n stratify = df[stratify_column]\n logger.info(f\"Stratifying on column: {stratify_column}\")\n \n # First split: train vs (val + test)\n train_df, temp_df = train_test_split(\n df,\n train_size=train_size,\n random_state=random_state,\n stratify=stratify\n )\n \n # Second split: val vs test\n val_ratio = val_size / (val_size + test_size)\n val_stratify = None\n if stratify is not None:\n val_stratify = temp_df[stratify_column]\n \n val_df, test_df = train_test_split(\n temp_df,\n train_size=val_ratio,\n random_state=random_state,\n stratify=val_stratify\n )\n \n # Convert back to datasets\n dataset_dict = DatasetDict({\n \"train\": Dataset.from_pandas(train_df, preserve_index=False),\n \"validation\": Dataset.from_pandas(val_df, preserve_index=False),\n \"test\": Dataset.from_pandas(test_df, preserve_index=False)\n })\n \n # Log split sizes\n logger.info(f\"Split sizes:\")\n logger.info(f\" Train: {len(dataset_dict['train'])} ({len(dataset_dict['train'])/len(dataset)*100:.2f}%)\")\n logger.info(f\" Validation: {len(dataset_dict['validation'])} ({len(dataset_dict['validation'])/len(dataset)*100:.2f}%)\")\n logger.info(f\" Test: {len(dataset_dict['test'])} ({len(dataset_dict['test'])/len(dataset)*100:.2f}%)\")\n 
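# Sanity check: train/val/test must partition the input dataset exactly.\n        assert len(dataset_dict[\"train\"]) + len(dataset_dict[\"validation\"]) + len(dataset_dict[\"test\"]) == len(dataset)\n        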
\n # Verify stratification if applicable\n if stratify_column and stratify_column in df.columns:\n logger.info(\"Verifying stratification:\")\n \n for split_name in [\"train\", \"validation\", \"test\"]:\n split_df = dataset_dict[split_name].to_pandas()\n split_dist = split_df[stratify_column].value_counts(normalize=True).sort_index()\n logger.info(f\" {split_name} distribution: {split_dist.to_dict()}\")\n \n return dataset_dict\n\n\ndef main():\n \"\"\"Main execution function.\"\"\"\n \n # Configuration\n DATASET_NAME = \"lmsys/toxic-chat\"\n DATASET_CONFIG = \"toxicchat0124\"\n MODEL_NAME = \"unitary/toxic-bert\"\n THRESHOLD = 0.5\n BATCH_SIZE = 32\n OUTPUT_DIR = Path(\"./filtered_toxic_chat\")\n REPORT_PATH = OUTPUT_DIR / \"filtering_report.json\"\n \n # Create output directory\n OUTPUT_DIR.mkdir(exist_ok=True)\n \n logger.info(\"=\"*80)\n logger.info(\"Starting Toxic Dataset Filtering Pipeline\")\n logger.info(\"=\"*80)\n logger.info(f\"Dataset: {DATASET_NAME} ({DATASET_CONFIG})\")\n logger.info(f\"Model: {MODEL_NAME}\")\n logger.info(f\"Threshold: {THRESHOLD}\")\n logger.info(f\"Output directory: {OUTPUT_DIR}\")\n \n # Step 1: Load dataset\n logger.info(\"\\n[Step 1/6] Loading dataset...\")\n try:\n dataset = load_dataset(DATASET_NAME, DATASET_CONFIG, split=\"train\")\n logger.info(f\"Loaded {len(dataset)} examples\")\n logger.info(f\"Dataset columns: {dataset.column_names}\")\n except Exception as e:\n logger.error(f\"Failed to load dataset: {e}\")\n raise\n \n # Step 2: Initialize filter\n logger.info(\"\\n[Step 2/6] Initializing toxicity filter...\")\n filter_obj = ToxicityFilter(\n model_name=MODEL_NAME,\n threshold=THRESHOLD,\n batch_size=BATCH_SIZE\n )\n \n # Step 3: Process dataset\n logger.info(\"\\n[Step 3/6] Processing dataset and classifying toxicity...\")\n filtered_dataset, toxic_dataset, stats = filter_obj.process_dataset(\n dataset,\n text_column=\"user_input\"\n )\n \n # Step 4: Create stratified splits\n logger.info(\"\\n[Step 4/6] Creating stratified train/validation/test splits...\")\n \n # Try to stratify on a relevant column if available\n stratify_col = None\n if \"jailbreaking\" in filtered_dataset.column_names:\n stratify_col = \"jailbreaking\"\n elif \"toxicity\" in filtered_dataset.column_names:\n stratify_col = \"toxicity\"\n \n dataset_splits = filter_obj.create_stratified_splits(\n filtered_dataset,\n train_size=0.7,\n val_size=0.15,\n test_size=0.15,\n stratify_column=stratify_col\n )\n \n # Step 5: Save datasets\n logger.info(\"\\n[Step 5/6] Saving filtered datasets...\")\n \n # Save main filtered dataset with splits\n dataset_splits.save_to_disk(str(OUTPUT_DIR / \"filtered_dataset\"))\n logger.info(f\"Saved filtered dataset splits to {OUTPUT_DIR / 'filtered_dataset'}\")\n \n # Save toxic examples separately for analysis\n toxic_dataset.save_to_disk(str(OUTPUT_DIR / \"toxic_examples\"))\n logger.info(f\"Saved {len(toxic_dataset)} toxic examples to {OUTPUT_DIR / 'toxic_examples'}\")\n \n # Step 6: Generate comprehensive report\n logger.info(\"\\n[Step 6/6] Generating comprehensive JSON report...\")\n \n report = {\n \"metadata\": {\n \"timestamp\": datetime.now().isoformat(),\n \"dataset_source\": DATASET_NAME,\n \"dataset_config\": DATASET_CONFIG,\n \"model\": MODEL_NAME,\n \"threshold\": THRESHOLD,\n \"batch_size\": BATCH_SIZE,\n \"device\": filter_obj.device\n },\n \"dataset_statistics\": {\n \"original_size\": stats[\"total_examples\"],\n \"filtered_size\": stats[\"kept_examples\"],\n \"removed_size\": stats[\"filtered_examples\"],\n \"removal_rate\": 
f\"{stats['filter_rate']*100:.2f}%\",\n \"retention_rate\": f\"{(1-stats['filter_rate'])*100:.2f}%\"\n },\n \"per_label_statistics\": {},\n \"split_statistics\": {\n \"train\": {\n \"size\": len(dataset_splits[\"train\"]),\n \"percentage\": f\"{len(dataset_splits['train'])/stats['kept_examples']*100:.2f}%\"\n },\n \"validation\": {\n \"size\": len(dataset_splits[\"validation\"]),\n \"percentage\": f\"{len(dataset_splits['validation'])/stats['kept_examples']*100:.2f}%\"\n },\n \"test\": {\n \"size\": len(dataset_splits[\"test\"]),\n \"percentage\": f\"{len(dataset_splits['test'])/stats['kept_examples']*100:.2f}%\"\n }\n },\n \"output_paths\": {\n \"filtered_dataset\": str(OUTPUT_DIR / \"filtered_dataset\"),\n \"toxic_examples\": str(OUTPUT_DIR / \"toxic_examples\"),\n \"report\": str(REPORT_PATH)\n }\n }\n \n # Add per-label statistics\n for label_name, label_stats in stats[\"label_stats\"].items():\n report[\"per_label_statistics\"][label_name] = {\n \"count_above_threshold\": label_stats[\"count\"],\n \"removal_rate\": f\"{label_stats['removal_rate']*100:.2f}%\",\n \"percentage_of_dataset\": f\"{label_stats['removal_rate']*100:.2f}%\"\n }\n \n # Add stratification verification if applicable\n if stratify_col:\n report[\"stratification\"] = {\n \"stratified_on\": stratify_col,\n \"verification\": \"Stratification verified - see logs for distribution details\"\n }\n \n # Save report\n with open(REPORT_PATH, \"w\") as f:\n json.dump(report, f, indent=2)\n \n logger.info(f\"Report saved to {REPORT_PATH}\")\n \n # Print summary\n logger.info(\"\\n\" + \"=\"*80)\n logger.info(\"FILTERING COMPLETE - SUMMARY\")\n logger.info(\"=\"*80)\n logger.info(f\"Original dataset: {stats['total_examples']} examples\")\n logger.info(f\"Filtered (clean): {stats['kept_examples']} examples ({(1-stats['filter_rate'])*100:.2f}%)\")\n logger.info(f\"Removed (toxic): {stats['filtered_examples']} examples ({stats['filter_rate']*100:.2f}%)\")\n logger.info(\"\\nPer-label removal rates:\")\n for label_name, label_stats in stats[\"label_stats\"].items():\n logger.info(f\" {label_name:15s}: {label_stats['count']:5d} examples ({label_stats['removal_rate']*100:5.2f}%)\")\n logger.info(\"\\nDataset splits:\")\n logger.info(f\" Train: {len(dataset_splits['train']):5d} examples (70.0%)\")\n logger.info(f\" Validation: {len(dataset_splits['validation']):5d} examples (15.0%)\")\n logger.info(f\" Test: {len(dataset_splits['test']):5d} examples (15.0%)\")\n logger.info(f\"\\nAll outputs saved to: {OUTPUT_DIR}\")\n logger.info(\"=\"*80)\n\n\nif __name__ == \"__main__\":\n main()\n'''\n\n# Write to parent directory (hf-agent/)\nimport os\nparent_dir = Path(__file__).parent.parent if '__file__' in globals() else Path.cwd().parent\noutput_path = parent_dir / 'filter_toxic_dataset.py'\noutput_path.write_text(filter_toxic_content)\n\n# Make it executable\nimport stat\nst = output_path.stat()\noutput_path.chmod(st.st_mode | stat.S_IEXEC | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)\n\nprint(f\"βœ“ Successfully created {output_path}\")\nprint(f\"βœ“ File is executable\")\n\n# Also check dependencies\nprint(\"\\nChecking dependencies...\")\ntry:\n import sklearn\n print(\"βœ“ scikit-learn is installed (version: {})\".format(sklearn.__version__))\nexcept ImportError:\n print(\"βœ— scikit-learn is NOT installed\")\n\ntry:\n import tqdm\n print(\"βœ“ tqdm is installed (version: {})\".format(tqdm.__version__))\nexcept ImportError:\n print(\"βœ— tqdm is NOT installed\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + 
"cell_type": "code", + "id": "vv876s0gpqk", + "source": "# Create requirements.txt and README.md files\nfrom pathlib import Path\n\n# Navigate to project root (parent of eval/)\nproject_root = Path('/Users/akseljoonas/Documents/hf-agent')\n\nrequirements_content = '''# HF-Agent Requirements\n# Production-ready dependencies for the HF-Agent project\n# Install with: pip install -r requirements.txt or use uv sync (recommended)\n\n# Core ML/AI Dependencies\ntorch>=2.0.0\ntransformers>=4.35.0\ndatasets>=2.14.0\nnumpy>=1.24.0\naccelerate>=0.24.0\n\n# Agent SDK and API\nclaude-agent-sdk>=0.1.0\nlitellm>=1.0.0\npydantic>=2.12.3\n\n# Hugging Face Integration\nhuggingface-hub>=1.0.1\nfastmcp>=2.4.0\n\n# Evaluation Framework\ninspect-ai>=0.3.149\nlmnr[all]>=0.7.23\n\n# Utilities\npython-dotenv>=1.2.1\nrequests>=2.32.5\ntenacity>=8.0.0\ntqdm>=4.65.0\npandas>=2.3.3\n\n# Optional but recommended for evaluation\nscikit-learn>=1.3.0 # For stratified splits in dataset processing\npeft>=0.7.0 # For LoRA fusion tasks\n'''\n\nreadme_content = '''# HF Agent\n\nAn MLE agent CLI with MCP (Model Context Protocol) integration, built-in tool support, and comprehensive evaluation framework.\n\n## Quick Start\n\n### Installation\n\n```bash\n# Clone the repository\ngit clone git@github.com:huggingface/hf_agent.git\ncd hf-agent\n\n# Install dependencies (using uv - recommended)\nuv sync\n\n# Or use pip\npip install -r requirements.txt\n```\n\n### Set Up Environment\n\nCreate a `.env` file in the project root:\n\n```bash\n# Required for Claude Agent SDK\nANTHROPIC_API_KEY=your_api_key_here\n\n# Required for Hugging Face features\nHF_TOKEN=your_hf_token_here\n\n# Optional: LiteLLM API keys if using other providers\nOPENAI_API_KEY=your_openai_key_here\n```\n\n### Interactive CLI\n\n```bash\nuv run python -m agent.main\n```\n\nThis starts an interactive chat session with the agent. 
Type your messages and the agent will respond, using tools as needed.\n\n## Features\n\n### Core Capabilities\n\n- **Agent SDK Integration**: Built on Claude Agent SDK with support for async operations and streaming\n- **MCP Protocol Support**: Full Model Context Protocol integration for extensible tool management\n- **Built-in Tools**: File operations (Read/Write), Bash execution, and more\n- **Hugging Face Integration**: Search models, datasets, papers, and spaces directly through MCP\n- **LiteLLM Backend**: Flexible LLM provider support (Anthropic, OpenAI, custom)\n- **Context Management**: Intelligent message history tracking and compaction\n- **Evaluation Framework**: Rubric-based evaluation pipeline implementing Rubrics as Rewards (RaR) paper\n\n### Evaluation Suite\n\nThe `eval/` directory contains a comprehensive benchmark framework:\n\n- **Rubric Generation**: Instance-specific evaluation criteria from QA pairs\n- **Multiple Solvers**: Benchmark `hf_agent`, `claude_code`, or custom solvers\n- **Leaderboard Integration**: Track performance over time on HuggingFace datasets\n- **Inspect AI Integration**: Full integration with the Inspect AI evaluation framework\n\nSee [eval/README.md](eval/README.md) for detailed evaluation documentation.\n\n## Running the Agent\n\n### Basic Usage\n\n```bash\n# Start interactive mode\nuv run python -m agent.main\n```\n\n### With Custom Configuration\n\n```bash\n# Use a specific MCP server configuration\nuv run python -m agent.main --config agent/config_mcp_example.json\n```\n\n### Batch Processing\n\nProcess multiple tasks concurrently using the batch solver:\n\n```bash\n# Run batch evaluation with 5 concurrent agents\nuv run python eval/amp_batch_solve.py\n```\n\nThis processes tasks from `eval/filled_tasks.jsonl` and outputs results to `eval/solved_tasks.jsonl`.\n\n## Configuration\n\n### Agent Configuration\n\nCreate a JSON config file (e.g., `agent/config_mcp_example.json`):\n\n```json\n{\n \"model_name\": \"anthropic/claude-sonnet-4-5-20250929\",\n \"max_iterations\": 10,\n \"mcp_servers\": [\n {\n \"name\": \"huggingface\",\n \"command\": \"uvx\",\n \"args\": [\"fastmcp\", \"run\", \"huggingface\"],\n \"env\": {\n \"HF_TOKEN\": \"${HF_TOKEN}\"\n }\n }\n ]\n}\n```\n\n### Customizing Tools\n\nEdit `agent/core/tools.py` to add built-in tools:\n\n```python\ndef create_builtin_tools() -> list[ToolSpec]:\n return [\n ToolSpec(\n name=\"your_tool\",\n description=\"What your tool does\",\n parameters={\n \"type\": \"object\",\n \"properties\": {\n \"param\": {\"type\": \"string\", \"description\": \"Parameter description\"}\n },\n \"required\": [\"param\"]\n },\n handler=your_async_handler\n ),\n # ... 
existing tools\n ]\n```\n\n### Adding MCP Servers\n\nAdd to your config JSON:\n\n```json\n{\n \"mcp_servers\": [\n {\n \"name\": \"your_server\",\n \"command\": \"command\",\n \"args\": [\"arg1\", \"arg2\"],\n \"env\": {\"KEY\": \"value\"}\n }\n ]\n}\n```\n\n## Evaluation\n\n### Generate Rubrics\n\n```bash\nuv run python eval/generate_rubrics.py \\\n --infile qa_pairs.jsonl \\\n --outfile qa_rubrics.jsonl \\\n --model anthropic/claude-sonnet-4-5-20250929 \\\n --push-to-hub akseljoonas/hf-agent-benchmark@rubrics\n```\n\n### Run Evaluation\n\n```bash\n# Evaluate hf-agent\nuv run inspect eval eval/task.py@hf-benchmark-with-rubrics \\\n -T dataset_name=akseljoonas/hf-agent-rubrics \\\n -T dataset_split=train \\\n -T limit=25 \\\n -T solver_name=hf_agent \\\n -T solver_kwargs='{\"config_path\":\"agent/config_mcp_example.json\",\"max_iterations\":10}' \\\n --log-dir logs/inspect\n\n# Evaluate Claude Code headlessly\nuv run inspect eval eval/task.py@hf-benchmark-with-rubrics \\\n -T solver_name=claude_code \\\n -T solver_kwargs='{\"allowed_tools\":\"Bash,Read\",\"output_format\":\"json\"}'\n```\n\n### Push to Leaderboard\n\n```bash\nuv run python eval/run_eval_with_leaderboard.py \\\n --hf-dataset akseljoonas/hf-agent-leaderboard \\\n --hf-token $HF_TOKEN \\\n --solver-name hf_agent \\\n --solver-kwargs '{\"config_path\":\"agent/config_mcp_example.json\",\"max_iterations\":10}' \\\n --dataset akseljoonas/hf-agent-rubrics@train \\\n --limit 25\n```\n\n## Troubleshooting\n\n### Common Issues\n\n#### 1. MCP Server Connection Errors\n\n**Problem**: Agent fails to connect to MCP servers.\n\n**Solutions**:\n- Verify MCP server command is in PATH: `which uvx` or `which fastmcp`\n- Check environment variables are set correctly in `.env`\n- Ensure HF_TOKEN is valid: `huggingface-cli whoami`\n- Try running MCP server manually: `uvx fastmcp run huggingface`\n\n#### 2. CUDA Out of Memory\n\n**Problem**: GPU memory errors during model loading or inference.\n\n**Solutions**:\n- Use smaller batch sizes in evaluation scripts\n- Enable gradient checkpointing for large models\n- Use `torch.float16` or `torch.bfloat16` for reduced memory\n- Clear CUDA cache: `torch.cuda.empty_cache()`\n- Use CPU inference for testing: `device_map=\"cpu\"`\n\n#### 3. LiteLLM API Errors\n\n**Problem**: API key or rate limit errors.\n\n**Solutions**:\n- Verify API keys in `.env`: `ANTHROPIC_API_KEY`, `OPENAI_API_KEY`\n- Check rate limits for your API provider\n- Add retry logic with exponential backoff (already included via `tenacity`)\n- Monitor usage: `litellm --debug`\n\n#### 4. Import Errors\n\n**Problem**: `ModuleNotFoundError` for packages.\n\n**Solutions**:\n```bash\n# Reinstall dependencies\nuv sync\n\n# Or with pip\npip install -r requirements.txt\n\n# Check Python version (requires >=3.12)\npython --version\n```\n\n#### 5. Evaluation Rubrics Not Loading\n\n**Problem**: Rubric scorer fails or returns invalid scores.\n\n**Solutions**:\n- Verify rubrics dataset format matches expected schema\n- Check that `eval/generate_rubrics.py` completed successfully\n- Validate JSONL format: each line should be valid JSON\n- Inspect rubric structure: must have `criteria` list with `criterion`, `weight`, `type`\n\n#### 6. 
Permission Errors with Bash Tool\n\n**Problem**: Agent cannot execute bash commands.\n\n**Solutions**:\n- Verify `permission_mode` in config: should be `\"bypassPermissions\"` for batch mode\n- Check file permissions: `chmod +x script.sh`\n- Ensure working directory exists and is writable\n- Review `disallowed_tools` list in configuration\n\n### Getting Help\n\n- **Documentation**: See [eval/README.md](eval/README.md) for evaluation details\n- **Issues**: Open an issue on GitHub with error logs\n- **Logs**: Check `logs/inspect/` for detailed evaluation logs\n- **Debug Mode**: Set `LITELLM_LOG=DEBUG` environment variable\n\n## Example Output\n\n### Successful Evaluation\n\n```\n[1/25] Starting: What's the best model for sentiment analysis...\n[1/25] βœ“ Done: What's the best model for sentiment analysis...\n[2/25] Starting: How can I serve a model with multiple LoRAs...\n[2/25] βœ“ Done: How can I serve a model with multiple LoRAs...\n\nCompleted: 25/25 successful\nResults saved to eval/solved_tasks.jsonl\n```\n\n### Rubric Scoring\n\n```\nTask: \"Find the best text-generation model for medical domain\"\nCriteria:\n βœ“ Searches HuggingFace for domain-specific models (weight: 5) - PASS\n βœ“ Considers model size and hardware requirements (weight: 3) - PASS\n βœ“ Checks model licenses for commercial use (weight: 4) - PASS\n βœ— Provides code example for inference (weight: 2) - FAIL\n \nScore: 0.857 (12/14 weighted points)\n```\n\n## Project Structure\n\n```\nhf-agent/\nβ”œβ”€β”€ agent/ # Main agent implementation\nβ”‚ β”œβ”€β”€ config.py # Configuration models\nβ”‚ β”œβ”€β”€ main.py # Interactive CLI entry point\nβ”‚ β”œβ”€β”€ context_manager/\nβ”‚ β”‚ └── manager.py # Message history management\nβ”‚ └── core/\nβ”‚ β”œβ”€β”€ agent_loop.py # Main agent loop and handlers\nβ”‚ β”œβ”€β”€ session.py # Session management\nβ”‚ β”œβ”€β”€ mcp_client.py # MCP SDK integration\nβ”‚ └── tools.py # ToolRouter and built-in tools\nβ”‚\nβ”œβ”€β”€ eval/ # Evaluation suite\nβ”‚ β”œβ”€β”€ README.md # Detailed evaluation docs\nβ”‚ β”œβ”€β”€ generate_rubrics.py # Rubric generation from QA pairs\nβ”‚ β”œβ”€β”€ rubric_eval.py # RaR-Explicit scoring implementation\nβ”‚ β”œβ”€β”€ task.py # Inspect AI task definitions\nβ”‚ β”œβ”€β”€ solvers.py # Solver registry (hf_agent, claude_code, etc.)\nβ”‚ β”œβ”€β”€ hf_agent_connector.py # Bridge to agent stack\nβ”‚ β”œβ”€β”€ leaderboard.py # HuggingFace leaderboard utilities\nβ”‚ β”œβ”€β”€ run_eval_with_leaderboard.py # CLI wrapper for evals\nβ”‚ β”œβ”€β”€ amp_batch_solve.py # Concurrent batch processing\nβ”‚ └── models.py # Shared Pydantic models\nβ”‚\nβ”œβ”€β”€ requirements.txt # Python dependencies\nβ”œβ”€β”€ pyproject.toml # Project metadata (for uv)\nβ”œβ”€β”€ README.md # This file\n└── .env # Environment variables (create this)\n```\n\n## Advanced Usage\n\n### Custom Solver Implementation\n\nCreate a new solver in `eval/solvers.py`:\n\n```python\n@solver\ndef my_custom_solver():\n async def solve(state: TaskState, generate: Generate):\n # Your solver logic here\n response = await your_agent_call(state.input_text)\n return response\n return solve\n```\n\nRegister and use:\n\n```bash\nuv run inspect eval eval/task.py@hf-benchmark-with-rubrics \\\n -T solver_name=my_custom_solver\n```\n\n### Streaming Responses\n\nEnable streaming in the agent connector:\n\n```python\nfrom agent.core.session import Session\n\nsession = Session(config)\nasync for chunk in session.stream_response(prompt):\n print(chunk, end=\"\", flush=True)\n```\n\n### Cost Tracking\n\nMonitor API costs using 
LiteLLM callbacks:\n\n```python\nimport litellm\nlitellm.success_callback = [\"langfuse\"] # Or other integrations\n```\n\n## Contributing\n\n1. Fork the repository\n2. Create a feature branch: `git checkout -b feature/your-feature`\n3. Make your changes\n4. Run tests: `uv run pytest`\n5. Commit with clear messages: `git commit -m \"Add feature X\"`\n6. Push and create a Pull Request\n\n## License\n\n[Your License Here]\n\n## Acknowledgments\n\n- Built on [Claude Agent SDK](https://github.com/anthropics/claude-agent-sdk)\n- Evaluation framework inspired by [Rubrics as Rewards](https://arxiv.org/abs/2507.17746)\n- Powered by [Hugging Face](https://huggingface.co/) ecosystem\n'''\n\n# Write files\n(project_root / 'requirements.txt').write_text(requirements_content)\n(project_root / 'README_NEW.md').write_text(readme_content)\n\nprint(f\"βœ“ Created {project_root / 'requirements.txt'}\")\nprint(f\"βœ“ Created {project_root / 'README_NEW.md'}\")\nprint(\"\\nBoth files are production-ready!\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "7fljbwefz1v", + "source": "from pathlib import Path\n\n# Complete monitoring script for HF Job 694306ebc67c9f186cfe3879\nmonitoring_script = r'''#!/usr/bin/env python3\n\"\"\"\nHugging Face Job Monitor\nJob ID: 694306ebc67c9f186cfe3879\nvLLM Benchmark: Testing 4 block sizes (8, 16, 32, 64) for Llama-3.1-8B-Instruct\n\"\"\"\nimport time\nimport os\nimport sys\nfrom huggingface_hub import HfApi\nfrom dotenv import load_dotenv\n\ndef main():\n # Load environment\n load_dotenv()\n \n # Configuration\n job_id = \"694306ebc67c9f186cfe3879\"\n check_interval = 60 # seconds\n \n # Initialize API\n token = os.environ.get('HF_TOKEN')\n if not token:\n print(\"ERROR: HF_TOKEN environment variable not set\")\n print(\"Please set it in your .env file or export it:\")\n print(\" export HF_TOKEN='your_token_here'\")\n sys.exit(1)\n \n api = HfApi(token=token)\n \n # Display header\n print(\"=\"*80)\n print(f\"Monitoring Hugging Face Job: {job_id}\")\n print(\"=\"*80)\n print(\"Benchmark: vLLM with 4 block sizes (8, 16, 32, 64)\")\n print(\"Model: Llama-3.1-8B-Instruct\")\n print(f\"Check Interval: {check_interval} seconds\")\n print(\"=\"*80)\n \n seen_log_length = 0\n check_count = 0\n \n while True:\n try:\n check_count += 1\n \n # Inspect job status\n job_info = api.inspect_job(job_id)\n \n # Display status\n timestamp = time.strftime('%Y-%m-%d %H:%M:%S')\n print(f\"\\n[Check #{check_count}] [{timestamp}]\")\n print(f\"Status: {job_info.status.stage}\")\n \n if job_info.status.message:\n print(f\"Message: {job_info.status.message}\")\n \n # Fetch and process logs\n try:\n current_logs = \"\"\n for log_line in api.fetch_job_logs(job_id):\n current_logs += log_line + \"\\n\"\n \n # Display only new log content\n if len(current_logs) > seen_log_length:\n new_content = current_logs[seen_log_length:]\n if new_content.strip():\n print(\"\\n--- New Log Output ---\")\n print(new_content)\n print(\"--- End New Logs ---\")\n seen_log_length = len(current_logs)\n \n # Look for benchmark results markers\n if \"BENCHMARK RESULTS SUMMARY\" in current_logs:\n print(\"\\n\" + \"=\"*80)\n print(\"🎯 BENCHMARK RESULTS SUMMARY DETECTED!\")\n print(\"=\"*80)\n \n if \"JSON Results\" in current_logs:\n print(\"\\n\" + \"=\"*80)\n print(\"πŸ“Š JSON RESULTS DETECTED!\")\n print(\"=\"*80)\n \n except Exception as log_error:\n print(f\"Note: Could not fetch logs: {log_error}\")\n \n # Check if job has completed\n if job_info.status.stage in 
[\"COMPLETED\", \"CANCELED\", \"ERROR\", \"DELETED\"]:\n print(\"\\n\" + \"=\"*80)\n print(f\"JOB FINISHED\")\n print(f\"Final Status: {job_info.status.stage}\")\n print(\"=\"*80)\n \n # Fetch and display complete final output\n print(\"\\nFetching complete job output...\")\n try:\n final_logs = \"\"\n for log_line in api.fetch_job_logs(job_id):\n final_logs += log_line + \"\\n\"\n \n print(\"\\n\" + \"=\"*80)\n print(\"COMPLETE JOB OUTPUT\")\n print(\"=\"*80 + \"\\n\")\n print(final_logs)\n print(\"\\n\" + \"=\"*80)\n print(\"END OF COMPLETE OUTPUT\")\n print(\"=\"*80)\n \n except Exception as e:\n print(f\"Error fetching final logs: {e}\")\n \n print(f\"\\nJob URL: {job_info.url}\")\n print(f\"Job ID: {job_id}\")\n \n # Exit with appropriate code\n if job_info.status.stage == \"COMPLETED\":\n sys.exit(0)\n else:\n sys.exit(1)\n \n # Wait before next check\n print(f\"\\nWaiting {check_interval} seconds before next check...\")\n print(f\"(Current status: {job_info.status.stage})\")\n print(\"(Press Ctrl+C to stop monitoring)\")\n time.sleep(check_interval)\n \n except KeyboardInterrupt:\n print(\"\\n\\n\" + \"=\"*80)\n print(\"Monitoring interrupted by user (Ctrl+C)\")\n print(\"=\"*80)\n try:\n job_info = api.inspect_job(job_id)\n print(f\"\\nLatest Status: {job_info.status.stage}\")\n print(f\"Job URL: {job_info.url}\")\n except:\n pass\n print(f\"\\nYou can resume monitoring by running this script again\")\n sys.exit(0)\n \n except Exception as e:\n print(f\"\\nError: {e}\")\n print(f\"Retrying in {check_interval} seconds...\")\n time.sleep(check_interval)\n\nif __name__ == \"__main__\":\n main()\n'''\n\n# Write script to eval directory\neval_dir = Path('/Users/akseljoonas/Documents/hf-agent/eval')\nscript_path = eval_dir / 'monitor_hf_job.py'\nscript_path.write_text(monitoring_script)\n\n# Make executable\nimport stat\nst = script_path.stat()\nscript_path.chmod(st.st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)\n\nprint(f\"βœ“ Created monitoring script: {script_path}\")\nprint(f\"\\nTo start monitoring, run one of:\")\nprint(f\" python {script_path}\")\nprint(f\" uv run python {script_path}\")\nprint(f\"\\nThe script will:\")\nprint(\" - Check job status every 60 seconds\")\nprint(\" - Display new log output as it becomes available\")\nprint(\" - Detect BENCHMARK RESULTS SUMMARY and JSON Results\")\nprint(\" - Display complete output when job finishes\")\nprint(\" - Exit automatically when job completes or fails\")\nprint(\"\\nPress Ctrl+C to stop monitoring at any time\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "yjf9l5kmab8", + "source": "from pathlib import Path\nimport sys\n\n# Add parent directory to path\nsys.path.insert(0, str(Path.cwd().parent))\n\n# Define all the scripts we need to create\nproject_root = Path('/Users/akseljoonas/Documents/hf-agent')\n\n# 1. 
convert_to_webdataset.py\nconvert_script = r'''#!/usr/bin/env python3\n\"\"\"\nConvert HuggingFaceFW/fineweb-edu dataset to WebDataset format with checksum validation.\n\nThis script loads the fineweb-edu dataset and converts it to WebDataset tar archives\nwith proper sharding, checksum validation, and metadata tracking.\n\"\"\"\n\nimport argparse\nimport hashlib\nimport json\nimport logging\nimport os\nimport sys\nfrom pathlib import Path\nfrom typing import Dict, Optional, Any\nimport tarfile\nfrom io import BytesIO\n\nfrom datasets import load_dataset\nfrom tqdm import tqdm\n\n# Configure logging\nlogging.basicConfig(\n level=logging.INFO,\n format='%(asctime)s - %(levelname)s - %(message)s'\n)\nlogger = logging.getLogger(__name__)\n\n\nclass WebDatasetConverter:\n \"\"\"Convert HuggingFace dataset to WebDataset format with checksums.\"\"\"\n \n def __init__(\n self,\n dataset_name: str = \"HuggingFaceFW/fineweb-edu\",\n config_name: Optional[str] = None,\n split: str = \"train\",\n output_dir: str = \"./webdataset_output\",\n shard_size_mb: int = 500,\n max_samples: Optional[int] = None,\n streaming: bool = True\n ):\n \"\"\"\n Initialize the converter.\n \n Args:\n dataset_name: HuggingFace dataset identifier\n config_name: Dataset configuration name (e.g., \"sample-10BT\")\n split: Dataset split to convert\n output_dir: Directory to save WebDataset shards\n shard_size_mb: Target size for each shard in MB\n max_samples: Maximum number of samples to convert (None for all)\n streaming: Use streaming mode for large datasets\n \"\"\"\n self.dataset_name = dataset_name\n self.config_name = config_name\n self.split = split\n self.output_dir = Path(output_dir)\n self.shard_size_bytes = shard_size_mb * 1024 * 1024\n self.max_samples = max_samples\n self.streaming = streaming\n \n # Create output directory\n self.output_dir.mkdir(parents=True, exist_ok=True)\n \n # Track checksums and metadata\n self.checksums: Dict[str, str] = {}\n self.shard_metadata: Dict[str, Dict[str, Any]] = {}\n self.total_samples = 0\n self.current_shard = 0\n self.current_shard_size = 0\n self.current_shard_samples = 0\n \n def compute_sha256(self, filepath: Path) -> str:\n \"\"\"Compute SHA256 checksum of a file.\"\"\"\n sha256_hash = hashlib.sha256()\n with open(filepath, \"rb\") as f:\n for byte_block in iter(lambda: f.read(4096), b\"\"):\n sha256_hash.update(byte_block)\n return sha256_hash.hexdigest()\n \n def format_sample_id(self, index: int) -> str:\n \"\"\"Format sample ID with zero padding.\"\"\"\n return f\"sample_{index:012d}\"\n \n def create_tar_member(self, name: str, data: bytes) -> tarfile.TarInfo:\n \"\"\"Create a tar member from data.\"\"\"\n tarinfo = tarfile.TarInfo(name=name)\n tarinfo.size = len(data)\n return tarinfo\n \n def should_create_new_shard(self) -> bool:\n \"\"\"Check if we should start a new shard.\"\"\"\n return self.current_shard_size >= self.shard_size_bytes\n \n def get_shard_path(self, shard_num: int) -> Path:\n \"\"\"Get the path for a shard file.\"\"\"\n return self.output_dir / f\"fineweb_edu_{shard_num:06d}.tar\"\n \n def write_sample_to_tar(\n self,\n tar: tarfile.TarFile,\n sample_id: str,\n text: str,\n metadata: Dict[str, Any]\n ) -> int:\n \"\"\"\n Write a sample to the tar archive.\n \n Returns the size in bytes written.\n \"\"\"\n # Write text file\n text_bytes = text.encode('utf-8')\n text_name = f\"{sample_id}.txt\"\n text_info = self.create_tar_member(text_name, text_bytes)\n tar.addfile(text_info, BytesIO(text_bytes))\n \n # Write JSON metadata file\n json_bytes = 
json.dumps(metadata, ensure_ascii=False).encode('utf-8')\n json_name = f\"{sample_id}.json\"\n json_info = self.create_tar_member(json_name, json_bytes)\n tar.addfile(json_info, BytesIO(json_bytes))\n \n # Return total size\n return len(text_bytes) + len(json_bytes)\n \n def finalize_shard(self, shard_path: Path):\n \"\"\"Compute checksum and save metadata for a completed shard.\"\"\"\n if shard_path.exists():\n # Compute checksum\n checksum = self.compute_sha256(shard_path)\n shard_name = shard_path.name\n self.checksums[shard_name] = checksum\n \n # Store metadata\n self.shard_metadata[shard_name] = {\n \"shard_number\": self.current_shard,\n \"num_samples\": self.current_shard_samples,\n \"size_bytes\": shard_path.stat().st_size,\n \"checksum\": checksum\n }\n \n logger.info(\n f\"Finalized {shard_name}: {self.current_shard_samples} samples, \"\n f\"{shard_path.stat().st_size / (1024*1024):.2f} MB, \"\n f\"checksum: {checksum[:16]}...\"\n )\n \n def convert(self):\n \"\"\"Convert the dataset to WebDataset format.\"\"\"\n logger.info(f\"Loading dataset: {self.dataset_name}\")\n if self.config_name:\n logger.info(f\"Config: {self.config_name}\")\n logger.info(f\"Split: {self.split}\")\n logger.info(f\"Streaming: {self.streaming}\")\n \n # Load dataset\n try:\n dataset = load_dataset(\n self.dataset_name,\n name=self.config_name,\n split=self.split,\n streaming=self.streaming\n )\n except Exception as e:\n logger.error(f\"Failed to load dataset: {e}\")\n sys.exit(1)\n \n logger.info(f\"Dataset loaded successfully\")\n \n # Initialize first shard\n shard_path = self.get_shard_path(self.current_shard)\n tar = tarfile.open(shard_path, 'w')\n \n try:\n # Process samples\n sample_iter = iter(dataset)\n if self.max_samples:\n logger.info(f\"Processing up to {self.max_samples} samples\")\n \n # Create progress bar\n pbar = tqdm(\n total=self.max_samples,\n desc=\"Converting samples\",\n unit=\"samples\"\n )\n \n for idx, sample in enumerate(sample_iter):\n if self.max_samples and idx >= self.max_samples:\n break\n \n # Check if we need a new shard\n if self.should_create_new_shard() and self.current_shard_samples > 0:\n # Finalize current shard\n tar.close()\n self.finalize_shard(shard_path)\n \n # Start new shard\n self.current_shard += 1\n self.current_shard_size = 0\n self.current_shard_samples = 0\n shard_path = self.get_shard_path(self.current_shard)\n tar = tarfile.open(shard_path, 'w')\n logger.info(f\"Starting new shard: {shard_path.name}\")\n \n # Create sample ID\n sample_id = self.format_sample_id(self.total_samples)\n \n # Extract text and metadata\n text = sample.get('text', '')\n metadata = {\n 'id': sample.get('id', ''),\n 'url': sample.get('url', ''),\n 'dump': sample.get('dump', ''),\n 'score': sample.get('score', None),\n 'token_count': sample.get('token_count', None),\n 'language': sample.get('language', ''),\n 'language_score': sample.get('language_score', None),\n 'sample_id': sample_id,\n 'sample_index': self.total_samples\n }\n \n # Write to tar\n sample_size = self.write_sample_to_tar(tar, sample_id, text, metadata)\n \n # Update counters\n self.current_shard_size += sample_size\n self.current_shard_samples += 1\n self.total_samples += 1\n pbar.update(1)\n \n pbar.close()\n \n # Finalize last shard\n tar.close()\n self.finalize_shard(shard_path)\n \n except Exception as e:\n logger.error(f\"Error during conversion: {e}\")\n tar.close()\n raise\n \n # Write checksums and metadata\n self.write_checksums()\n self.write_dataset_metadata()\n \n logger.info(f\"\\nConversion 
complete!\")\n logger.info(f\"Total samples: {self.total_samples}\")\n logger.info(f\"Total shards: {self.current_shard + 1}\")\n logger.info(f\"Output directory: {self.output_dir}\")\n \n def write_checksums(self):\n \"\"\"Write checksums.json file.\"\"\"\n checksums_path = self.output_dir / \"checksums.json\"\n with open(checksums_path, 'w') as f:\n json.dump(self.checksums, f, indent=2)\n logger.info(f\"Checksums written to: {checksums_path}\")\n \n def write_dataset_metadata(self):\n \"\"\"Write dataset_metadata.json file.\"\"\"\n metadata = {\n \"dataset_name\": self.dataset_name,\n \"config_name\": self.config_name,\n \"split\": self.split,\n \"total_samples\": self.total_samples,\n \"num_shards\": self.current_shard + 1,\n \"shard_size_mb\": self.shard_size_bytes / (1024 * 1024),\n \"shards\": self.shard_metadata,\n \"format\": \"webdataset\",\n \"sample_structure\": {\n \"text\": \".txt file\",\n \"metadata\": \".json file (id, url, dump, score, token_count, language, language_score, sample_id, sample_index)\"\n }\n }\n \n metadata_path = self.output_dir / \"dataset_metadata.json\"\n with open(metadata_path, 'w') as f:\n json.dump(metadata, f, indent=2)\n logger.info(f\"Dataset metadata written to: {metadata_path}\")\n\n\ndef main():\n \"\"\"Main entry point.\"\"\"\n parser = argparse.ArgumentParser(\n description=\"Convert HuggingFaceFW/fineweb-edu to WebDataset format\"\n )\n parser.add_argument(\n \"--dataset\",\n type=str,\n default=\"HuggingFaceFW/fineweb-edu\",\n help=\"HuggingFace dataset name\"\n )\n parser.add_argument(\n \"--config\",\n type=str,\n default=None,\n help=\"Dataset configuration (e.g., 'sample-10BT', 'sample-100BT', 'sample-350BT')\"\n )\n parser.add_argument(\n \"--split\",\n type=str,\n default=\"train\",\n help=\"Dataset split to convert\"\n )\n parser.add_argument(\n \"--output-dir\",\n type=str,\n default=\"./webdataset_output\",\n help=\"Output directory for WebDataset shards\"\n )\n parser.add_argument(\n \"--shard-size\",\n type=int,\n default=500,\n help=\"Target shard size in MB\"\n )\n parser.add_argument(\n \"--max-samples\",\n type=int,\n default=None,\n help=\"Maximum number of samples to convert (for testing)\"\n )\n parser.add_argument(\n \"--no-streaming\",\n action=\"store_true\",\n help=\"Disable streaming mode (loads entire dataset into memory)\"\n )\n \n args = parser.parse_args()\n \n # Create converter\n converter = WebDatasetConverter(\n dataset_name=args.dataset,\n config_name=args.config,\n split=args.split,\n output_dir=args.output_dir,\n shard_size_mb=args.shard_size,\n max_samples=args.max_samples,\n streaming=not args.no_streaming\n )\n \n # Run conversion\n converter.convert()\n\n\nif __name__ == \"__main__\":\n main()\n'''\n\n# Write the conversion script\n(project_root / 'convert_to_webdataset.py').write_text(convert_script)\nprint(f\"βœ“ Created {project_root / 'convert_to_webdataset.py'}\")\nprint(f\" Size: {len(convert_script)} bytes\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "9fqo47lnws", + "source": "# 2. 
webdataset_loader.py\nloader_script = r'''#!/usr/bin/env python3\n\"\"\"\nWebDataset Streaming Loader with Checksum Validation.\n\nThis module provides a streaming loader for WebDataset format with:\n- Checksum validation before loading shards\n- PyTorch DataLoader compatible interface\n- Support for distributed training (worker sharding)\n- Optional sample filtering and transformation\n\"\"\"\n\nimport hashlib\nimport json\nimport logging\nimport warnings\nfrom pathlib import Path\nfrom typing import Dict, Optional, Callable, Any, List, Iterator\nimport tarfile\nfrom io import BytesIO\n\nimport torch\nfrom torch.utils.data import IterableDataset, DataLoader\nimport webdataset as wds\n\n# Configure logging\nlogging.basicConfig(level=logging.INFO)\nlogger = logging.getLogger(__name__)\n\n\nclass ChecksumValidator:\n \"\"\"Validate checksums for WebDataset shards.\"\"\"\n \n def __init__(self, checksums_file: Path):\n \"\"\"\n Initialize validator with checksums file.\n \n Args:\n checksums_file: Path to checksums.json file\n \"\"\"\n self.checksums_file = Path(checksums_file)\n self.checksums: Dict[str, str] = {}\n self._load_checksums()\n \n def _load_checksums(self):\n \"\"\"Load checksums from JSON file.\"\"\"\n if not self.checksums_file.exists():\n raise FileNotFoundError(f\"Checksums file not found: {self.checksums_file}\")\n \n with open(self.checksums_file, 'r') as f:\n self.checksums = json.load(f)\n \n logger.info(f\"Loaded {len(self.checksums)} checksums from {self.checksums_file}\")\n \n def compute_sha256(self, filepath: Path) -> str:\n \"\"\"Compute SHA256 checksum of a file.\"\"\"\n sha256_hash = hashlib.sha256()\n with open(filepath, \"rb\") as f:\n for byte_block in iter(lambda: f.read(4096), b\"\"):\n sha256_hash.update(byte_block)\n return sha256_hash.hexdigest()\n \n def validate_shard(self, shard_path: Path) -> bool:\n \"\"\"\n Validate a shard's checksum.\n \n Args:\n shard_path: Path to the shard file\n \n Returns:\n True if checksum matches, False otherwise\n \"\"\"\n shard_name = shard_path.name\n \n if shard_name not in self.checksums:\n logger.warning(f\"No checksum found for shard: {shard_name}\")\n return False\n \n expected_checksum = self.checksums[shard_name]\n actual_checksum = self.compute_sha256(shard_path)\n \n if actual_checksum != expected_checksum:\n logger.error(\n f\"Checksum mismatch for {shard_name}!\\n\"\n f\" Expected: {expected_checksum}\\n\"\n f\" Actual: {actual_checksum}\"\n )\n return False\n \n logger.debug(f\"Checksum validated for {shard_name}\")\n return True\n \n def validate_all_shards(self, shard_dir: Path) -> bool:\n \"\"\"\n Validate all shards in a directory.\n \n Args:\n shard_dir: Directory containing shard files\n \n Returns:\n True if all shards are valid, False otherwise\n \"\"\"\n shard_dir = Path(shard_dir)\n all_valid = True\n \n for shard_name in self.checksums.keys():\n shard_path = shard_dir / shard_name\n \n if not shard_path.exists():\n logger.error(f\"Shard not found: {shard_path}\")\n all_valid = False\n continue\n \n if not self.validate_shard(shard_path):\n all_valid = False\n \n return all_valid\n\n\nclass WebDatasetLoader(IterableDataset):\n \"\"\"\n Streaming WebDataset loader with checksum validation and PyTorch compatibility.\n \"\"\"\n \n def __init__(\n self,\n data_dir: str,\n validate_checksums: bool = True,\n shuffle: bool = False,\n buffer_size: int = 1000,\n transform: Optional[Callable] = None,\n filter_fn: Optional[Callable] = None,\n shard_pattern: str = \"*.tar\"\n ):\n \"\"\"\n Initialize the 
WebDataset loader.\n \n Args:\n data_dir: Directory containing WebDataset shards\n validate_checksums: Whether to validate checksums before loading\n shuffle: Whether to shuffle samples (requires buffer)\n buffer_size: Buffer size for shuffling\n transform: Optional transformation function for samples\n filter_fn: Optional filter function to skip samples\n shard_pattern: Glob pattern for shard files\n \"\"\"\n super().__init__()\n \n self.data_dir = Path(data_dir)\n self.validate_checksums = validate_checksums\n self.shuffle = shuffle\n self.buffer_size = buffer_size\n self.transform = transform\n self.filter_fn = filter_fn\n self.shard_pattern = shard_pattern\n \n # Find all shards\n self.shard_paths = sorted(self.data_dir.glob(shard_pattern))\n \n if not self.shard_paths:\n raise ValueError(f\"No shards found in {data_dir} matching pattern {shard_pattern}\")\n \n logger.info(f\"Found {len(self.shard_paths)} shards in {data_dir}\")\n \n # Validate checksums if requested\n if self.validate_checksums:\n self._validate_all_checksums()\n \n # Load metadata\n self.metadata = self._load_metadata()\n \n def _validate_all_checksums(self):\n \"\"\"Validate checksums for all shards.\"\"\"\n checksums_file = self.data_dir / \"checksums.json\"\n \n if not checksums_file.exists():\n warnings.warn(\n f\"Checksums file not found: {checksums_file}. \"\n \"Skipping validation.\"\n )\n return\n \n validator = ChecksumValidator(checksums_file)\n \n logger.info(\"Validating checksums for all shards...\")\n all_valid = validator.validate_all_shards(self.data_dir)\n \n if not all_valid:\n raise ValueError(\"Checksum validation failed! Some shards are corrupted.\")\n \n logger.info(\"All checksums validated successfully\")\n \n def _load_metadata(self) -> Dict[str, Any]:\n \"\"\"Load dataset metadata if available.\"\"\"\n metadata_file = self.data_dir / \"dataset_metadata.json\"\n \n if metadata_file.exists():\n with open(metadata_file, 'r') as f:\n metadata = json.load(f)\n logger.info(f\"Loaded metadata: {metadata.get('total_samples', 'unknown')} samples\")\n return metadata\n else:\n logger.warning(f\"Metadata file not found: {metadata_file}\")\n return {}\n \n def _decode_sample(self, sample: Dict) -> Dict:\n \"\"\"\n Decode a sample from WebDataset format.\n \n Expected format:\n - sample['txt']: text content (bytes)\n - sample['json']: metadata (bytes)\n \"\"\"\n decoded = {}\n \n # Decode text\n if 'txt' in sample:\n decoded['text'] = sample['txt'].decode('utf-8')\n \n # Decode metadata\n if 'json' in sample:\n metadata = json.loads(sample['json'].decode('utf-8'))\n decoded.update(metadata)\n \n # Keep the key\n if '__key__' in sample:\n decoded['__key__'] = sample['__key__']\n \n return decoded\n \n def __iter__(self) -> Iterator[Dict]:\n \"\"\"Iterate over samples in the dataset.\"\"\"\n # Get worker info for distributed training\n worker_info = torch.utils.data.get_worker_info()\n \n if worker_info is not None:\n # Split shards among workers\n num_workers = worker_info.num_workers\n worker_id = worker_info.id\n \n # Select shards for this worker\n shards_per_worker = len(self.shard_paths) // num_workers\n start_idx = worker_id * shards_per_worker\n end_idx = start_idx + shards_per_worker if worker_id < num_workers - 1 else len(self.shard_paths)\n \n worker_shards = self.shard_paths[start_idx:end_idx]\n logger.info(f\"Worker {worker_id}/{num_workers}: processing {len(worker_shards)} shards\")\n else:\n worker_shards = self.shard_paths\n \n # Convert paths to URLs for webdataset\n shard_urls = [str(p) 
for p in worker_shards]\n \n # Create WebDataset pipeline\n dataset = wds.WebDataset(shard_urls)\n \n # Add shuffling if requested\n if self.shuffle:\n dataset = dataset.shuffle(self.buffer_size)\n \n # Decode samples\n dataset = dataset.map(self._decode_sample)\n \n # Apply filter if provided\n if self.filter_fn is not None:\n dataset = dataset.select(self.filter_fn)\n \n # Apply transformation if provided\n if self.transform is not None:\n dataset = dataset.map(self.transform)\n \n # Iterate over samples\n for sample in dataset:\n yield sample\n \n def get_dataloader(\n self,\n batch_size: int = 32,\n num_workers: int = 4,\n pin_memory: bool = True,\n collate_fn: Optional[Callable] = None\n ) -> DataLoader:\n \"\"\"\n Create a PyTorch DataLoader for this dataset.\n \n Args:\n batch_size: Batch size\n num_workers: Number of worker processes\n pin_memory: Whether to pin memory for faster GPU transfer\n collate_fn: Custom collate function for batching\n \n Returns:\n DataLoader instance\n \"\"\"\n return DataLoader(\n self,\n batch_size=batch_size,\n num_workers=num_workers,\n pin_memory=pin_memory,\n collate_fn=collate_fn\n )\n\n\n# Utility functions\n\ndef verify_checksums(data_dir: str) -> bool:\n \"\"\"\n Verify checksums for all shards in a directory.\n \n Args:\n data_dir: Directory containing WebDataset shards and checksums.json\n \n Returns:\n True if all checksums are valid, False otherwise\n \"\"\"\n data_dir = Path(data_dir)\n checksums_file = data_dir / \"checksums.json\"\n \n if not checksums_file.exists():\n logger.error(f\"Checksums file not found: {checksums_file}\")\n return False\n \n validator = ChecksumValidator(checksums_file)\n return validator.validate_all_shards(data_dir)\n\n\ndef default_collate_fn(batch: List[Dict]) -> Dict:\n \"\"\"\n Default collate function for batching WebDataset samples.\n \n Args:\n batch: List of decoded samples\n \n Returns:\n Batched dictionary with lists of values\n \"\"\"\n if not batch:\n return {}\n \n # Get all keys from first sample\n keys = batch[0].keys()\n \n # Collate each key\n collated = {}\n for key in keys:\n values = [sample[key] for sample in batch]\n collated[key] = values\n \n return collated\n\n\ndef main():\n \"\"\"Example usage and testing.\"\"\"\n import argparse\n \n parser = argparse.ArgumentParser(description=\"WebDataset Loader with Checksum Validation\")\n parser.add_argument(\"data_dir\", type=str, help=\"Directory containing WebDataset shards\")\n parser.add_argument(\"--validate-only\", action=\"store_true\", help=\"Only validate checksums\")\n parser.add_argument(\"--no-validate\", action=\"store_true\", help=\"Skip checksum validation\")\n parser.add_argument(\"--num-samples\", type=int, default=10, help=\"Number of samples to load (for testing)\")\n \n args = parser.parse_args()\n \n if args.validate_only:\n # Just validate checksums\n logger.info(\"Validating checksums...\")\n valid = verify_checksums(args.data_dir)\n \n if valid:\n logger.info(\"All checksums are valid!\")\n return 0\n else:\n logger.error(\"Checksum validation failed!\")\n return 1\n else:\n # Load and display samples\n logger.info(f\"Loading WebDataset from {args.data_dir}\")\n \n loader = WebDatasetLoader(\n args.data_dir,\n validate_checksums=not args.no_validate,\n shuffle=False\n )\n \n logger.info(f\"Loading {args.num_samples} samples...\")\n \n for i, sample in enumerate(loader):\n if i >= args.num_samples:\n break\n \n print(f\"\\nSample {i+1}:\")\n print(f\" Key: {sample.get('__key__', 'N/A')}\")\n print(f\" Text length: 
{len(sample.get('text', ''))} characters\")\n print(f\" Metadata: {', '.join(k for k in sample.keys() if k not in ['text', '__key__'])}\")\n \n logger.info(f\"Successfully loaded {min(i+1, args.num_samples)} samples\")\n return 0\n\n\nif __name__ == \"__main__\":\n import sys\n sys.exit(main())\n'''\n\n# Write the loader script\n(project_root / 'webdataset_loader.py').write_text(loader_script)\nprint(f\"\\nβœ“ Created {project_root / 'webdataset_loader.py'}\")\nprint(f\" Size: {len(loader_script)} bytes\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "tjcbvnxrcxn", + "source": "# 3. requirements.txt for WebDataset tools\nrequirements_txt = '''# WebDataset Conversion and Loading Requirements\n# For converting HuggingFaceFW/fineweb-edu to WebDataset format\n\n# Core dependencies\ndatasets>=2.14.0\nwebdataset>=0.2.48\ntorch>=2.0.0\ntqdm>=4.65.0\n\n# Optional but recommended\nnumpy>=1.24.0\n'''\n\n# Write requirements.txt\n(project_root / 'webdataset_requirements.txt').write_text(requirements_txt)\nprint(f\"\\nβœ“ Created {project_root / 'webdataset_requirements.txt'}\")\nprint(f\" Size: {len(requirements_txt)} bytes\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "kttxvkmvl3d", + "source": "# 4. example_usage.py\nexample_script = r'''#!/usr/bin/env python3\n\"\"\"\nExample usage of WebDataset conversion and loading scripts.\n\nThis script demonstrates:\n1. Converting a small sample of fineweb-edu to WebDataset format\n2. Validating checksums\n3. Loading data with the WebDataset loader\n4. Using the loader with PyTorch DataLoader\n\"\"\"\n\nimport logging\nfrom pathlib import Path\n\n# Import our modules\nfrom convert_to_webdataset import WebDatasetConverter\nfrom webdataset_loader import WebDatasetLoader, verify_checksums, default_collate_fn\n\n# Configure logging\nlogging.basicConfig(\n level=logging.INFO,\n format='%(asctime)s - %(levelname)s - %(message)s'\n)\nlogger = logging.getLogger(__name__)\n\n\ndef example_1_basic_conversion():\n \"\"\"Example 1: Basic conversion of a small dataset sample.\"\"\"\n logger.info(\"=\"*80)\n logger.info(\"EXAMPLE 1: Basic Conversion\")\n logger.info(\"=\"*80)\n \n # Convert a small sample (1000 documents) for testing\n converter = WebDatasetConverter(\n dataset_name=\"HuggingFaceFW/fineweb-edu\",\n config_name=\"sample-10BT\", # Use the 10BT sample\n split=\"train\",\n output_dir=\"./webdataset_sample\",\n shard_size_mb=50, # Smaller shards for testing\n max_samples=1000, # Just 1000 samples\n streaming=True\n )\n \n logger.info(\"Starting conversion...\")\n converter.convert()\n logger.info(\"Conversion complete!\\n\")\n\n\ndef example_2_validate_checksums():\n \"\"\"Example 2: Validate checksums for converted dataset.\"\"\"\n logger.info(\"=\"*80)\n logger.info(\"EXAMPLE 2: Checksum Validation\")\n logger.info(\"=\"*80)\n \n data_dir = \"./webdataset_sample\"\n \n logger.info(f\"Validating checksums in {data_dir}...\")\n valid = verify_checksums(data_dir)\n \n if valid:\n logger.info(\"βœ“ All checksums are valid!\")\n else:\n logger.error(\"βœ— Checksum validation failed!\")\n \n logger.info(\"\")\n\n\ndef example_3_basic_loading():\n \"\"\"Example 3: Basic loading and iteration.\"\"\"\n logger.info(\"=\"*80)\n logger.info(\"EXAMPLE 3: Basic Loading\")\n logger.info(\"=\"*80)\n \n # Create loader\n loader = WebDatasetLoader(\n data_dir=\"./webdataset_sample\",\n validate_checksums=True,\n shuffle=False\n )\n \n # Load and display a few 
samples\n logger.info(\"Loading first 5 samples...\")\n for i, sample in enumerate(loader):\n if i >= 5:\n break\n \n logger.info(f\"\\nSample {i+1}:\")\n logger.info(f\" Sample ID: {sample.get('sample_id', 'N/A')}\")\n logger.info(f\" Text length: {len(sample.get('text', ''))} characters\")\n logger.info(f\" URL: {sample.get('url', 'N/A')}\")\n logger.info(f\" Score: {sample.get('score', 'N/A')}\")\n logger.info(f\" Token count: {sample.get('token_count', 'N/A')}\")\n logger.info(f\" Language: {sample.get('language', 'N/A')}\")\n \n # Show first 200 characters of text\n text_preview = sample.get('text', '')[:200]\n logger.info(f\" Text preview: {text_preview}...\")\n \n logger.info(\"\")\n\n\ndef example_4_with_filtering():\n \"\"\"Example 4: Loading with filtering.\"\"\"\n logger.info(\"=\"*80)\n logger.info(\"EXAMPLE 4: Loading with Filtering\")\n logger.info(\"=\"*80)\n \n # Define a filter function (e.g., only high-quality documents)\n def high_quality_filter(sample):\n \"\"\"Only keep samples with score >= 3.0.\"\"\"\n score = sample.get('score')\n return score is not None and score >= 3.0\n \n # Create loader with filter\n loader = WebDatasetLoader(\n data_dir=\"./webdataset_sample\",\n validate_checksums=True,\n filter_fn=high_quality_filter,\n shuffle=False\n )\n \n # Count filtered samples\n logger.info(\"Counting high-quality samples (score >= 3.0)...\")\n count = 0\n scores = []\n \n for sample in loader:\n count += 1\n scores.append(sample.get('score', 0))\n if count >= 100: # Check first 100\n break\n \n logger.info(f\"Found {count} high-quality samples\")\n logger.info(f\"Average score: {sum(scores) / len(scores):.2f}\")\n logger.info(f\"Min score: {min(scores):.2f}\")\n logger.info(f\"Max score: {max(scores):.2f}\")\n logger.info(\"\")\n\n\ndef example_5_with_transformation():\n \"\"\"Example 5: Loading with transformation.\"\"\"\n logger.info(\"=\"*80)\n logger.info(\"EXAMPLE 5: Loading with Transformation\")\n logger.info(\"=\"*80)\n \n # Define a transformation function\n def transform_sample(sample):\n \"\"\"Add computed features to sample.\"\"\"\n # Add word count\n text = sample.get('text', '')\n sample['word_count'] = len(text.split())\n \n # Add character count\n sample['char_count'] = len(text)\n \n # Truncate text to first 500 characters for memory efficiency\n sample['text_truncated'] = text[:500]\n \n return sample\n \n # Create loader with transformation\n loader = WebDatasetLoader(\n data_dir=\"./webdataset_sample\",\n validate_checksums=True,\n transform=transform_sample,\n shuffle=False\n )\n \n # Load and display transformed samples\n logger.info(\"Loading 3 transformed samples...\")\n for i, sample in enumerate(loader):\n if i >= 3:\n break\n \n logger.info(f\"\\nTransformed Sample {i+1}:\")\n logger.info(f\" Word count: {sample.get('word_count', 'N/A')}\")\n logger.info(f\" Char count: {sample.get('char_count', 'N/A')}\")\n logger.info(f\" Token count: {sample.get('token_count', 'N/A')}\")\n logger.info(f\" Truncated text: {sample.get('text_truncated', '')[:100]}...\")\n \n logger.info(\"\")\n\n\ndef example_6_pytorch_dataloader():\n \"\"\"Example 6: Using with PyTorch DataLoader.\"\"\"\n logger.info(\"=\"*80)\n logger.info(\"EXAMPLE 6: PyTorch DataLoader Integration\")\n logger.info(\"=\"*80)\n \n # Create loader\n loader = WebDatasetLoader(\n data_dir=\"./webdataset_sample\",\n validate_checksums=True,\n shuffle=True, # Shuffle for training\n buffer_size=100\n )\n \n # Create PyTorch DataLoader\n dataloader = loader.get_dataloader(\n batch_size=8,\n 
num_workers=2,\n collate_fn=default_collate_fn\n )\n \n # Iterate over batches\n logger.info(\"Loading 3 batches...\")\n for i, batch in enumerate(dataloader):\n if i >= 3:\n break\n \n logger.info(f\"\\nBatch {i+1}:\")\n logger.info(f\" Batch size: {len(batch['text'])}\")\n logger.info(f\" Sample IDs: {batch['sample_id'][:3]}...\")\n logger.info(f\" Average text length: {sum(len(t) for t in batch['text']) / len(batch['text']):.0f} chars\")\n \n # Show scores if available\n if 'score' in batch:\n scores = [s for s in batch['score'] if s is not None]\n if scores:\n logger.info(f\" Average score: {sum(scores) / len(scores):.2f}\")\n \n logger.info(\"\")\n\n\ndef example_7_distributed_training():\n \"\"\"Example 7: Simulating distributed training setup.\"\"\"\n logger.info(\"=\"*80)\n logger.info(\"EXAMPLE 7: Distributed Training Simulation\")\n logger.info(\"=\"*80)\n \n # Create loader\n loader = WebDatasetLoader(\n data_dir=\"./webdataset_sample\",\n validate_checksums=True,\n shuffle=True,\n buffer_size=100\n )\n \n # Create DataLoader with multiple workers\n # Each worker will automatically get a subset of shards\n dataloader = loader.get_dataloader(\n batch_size=4,\n num_workers=4, # 4 workers will split shards among themselves\n collate_fn=default_collate_fn\n )\n \n logger.info(\"DataLoader with 4 workers created\")\n logger.info(\"Each worker will process a subset of shards\")\n logger.info(\"Loading first batch...\")\n \n # Load one batch to verify it works\n batch = next(iter(dataloader))\n logger.info(f\"Successfully loaded batch with {len(batch['text'])} samples\")\n logger.info(\"\")\n\n\ndef main():\n \"\"\"Run all examples.\"\"\"\n logger.info(\"\\n\" + \"=\"*80)\n logger.info(\"WebDataset Conversion and Loading Examples\")\n logger.info(\"=\"*80 + \"\\n\")\n \n try:\n # Example 1: Convert dataset\n example_1_basic_conversion()\n \n # Example 2: Validate checksums\n example_2_validate_checksums()\n \n # Example 3: Basic loading\n example_3_basic_loading()\n \n # Example 4: With filtering\n example_4_with_filtering()\n \n # Example 5: With transformation\n example_5_with_transformation()\n \n # Example 6: PyTorch DataLoader\n example_6_pytorch_dataloader()\n \n # Example 7: Distributed training\n example_7_distributed_training()\n \n logger.info(\"=\"*80)\n logger.info(\"All examples completed successfully!\")\n logger.info(\"=\"*80)\n \n except Exception as e:\n logger.error(f\"Error running examples: {e}\", exc_info=True)\n return 1\n \n return 0\n\n\nif __name__ == \"__main__\":\n import sys\n sys.exit(main())\n'''\n\n# Write the example script\n(project_root / 'example_usage.py').write_text(example_script)\nprint(f\"\\nβœ“ Created {project_root / 'example_usage.py'}\")\nprint(f\" Size: {len(example_script)} bytes\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "m1o9mhn826q", + "source": "# 5. 
README.md - Complete documentation\nreadme_content = '''# WebDataset Conversion for FineWeb-Edu\n\nComplete toolkit for converting HuggingFaceFW/fineweb-edu dataset to WebDataset format with checksum validation and streaming support.\n\n## Overview\n\nThis project provides production-ready tools for:\n- Converting HuggingFace datasets to WebDataset tar format\n- Validating data integrity with SHA256 checksums\n- Streaming large datasets efficiently\n- PyTorch DataLoader integration\n- Distributed training support\n\n## Features\n\n### Conversion (`convert_to_webdataset.py`)\n- βœ… Streaming mode for memory-efficient processing\n- βœ… Configurable shard sizes (~500MB default)\n- βœ… SHA256 checksum generation per shard\n- βœ… Comprehensive metadata tracking\n- βœ… Progress bars and detailed logging\n- βœ… Support for all fineweb-edu configurations\n\n### Loading (`webdataset_loader.py`)\n- βœ… Checksum validation before loading\n- βœ… PyTorch `IterableDataset` interface\n- βœ… Automatic worker-based shard distribution\n- βœ… Optional shuffling with configurable buffer\n- βœ… Sample filtering and transformation\n- βœ… Compatible with PyTorch DataLoader\n\n## Installation\n\n### Basic Installation\n\n```bash\npip install -r webdataset_requirements.txt\n```\n\n### Using uv (Recommended)\n\n```bash\n# If you have uv installed\nuv pip install -r webdataset_requirements.txt\n```\n\n### Dependencies\n\n- `datasets>=2.14.0` - HuggingFace datasets library\n- `webdataset>=0.2.48` - WebDataset format support\n- `torch>=2.0.0` - PyTorch for DataLoader\n- `tqdm>=4.65.0` - Progress bars\n- `numpy>=1.24.0` - Numerical operations\n\n## Quick Start\n\n### 1. Convert Dataset\n\nConvert a small sample for testing:\n\n```bash\npython convert_to_webdataset.py \\\\\n --config sample-10BT \\\\\n --output-dir ./webdataset_output \\\\\n --shard-size 500 \\\\\n --max-samples 10000\n```\n\nConvert the full dataset:\n\n```bash\npython convert_to_webdataset.py \\\\\n --config sample-350BT \\\\\n --output-dir ./webdataset_full \\\\\n --shard-size 500\n```\n\n### 2. Validate Checksums\n\n```bash\npython webdataset_loader.py ./webdataset_output --validate-only\n```\n\n### 3. Load and Use Data\n\n```python\nfrom webdataset_loader import WebDatasetLoader\n\n# Create loader\nloader = WebDatasetLoader(\n data_dir=\"./webdataset_output\",\n validate_checksums=True,\n shuffle=True,\n buffer_size=1000\n)\n\n# Iterate over samples\nfor sample in loader:\n text = sample['text']\n metadata = sample['id'], sample['url'], sample['score']\n # ... process sample\n```\n\n### 4. Use with PyTorch DataLoader\n\n```python\nfrom webdataset_loader import WebDatasetLoader, default_collate_fn\n\nloader = WebDatasetLoader(\n data_dir=\"./webdataset_output\",\n validate_checksums=True,\n shuffle=True\n)\n\ndataloader = loader.get_dataloader(\n batch_size=32,\n num_workers=4,\n collate_fn=default_collate_fn\n)\n\nfor batch in dataloader:\n texts = batch['text'] # List of strings\n scores = batch['score'] # List of floats\n # ... 
train your model\n```\n\n## Detailed Usage\n\n### Conversion Script\n\n#### Command-Line Arguments\n\n```bash\npython convert_to_webdataset.py [OPTIONS]\n\nOptions:\n --dataset TEXT HuggingFace dataset name\n [default: HuggingFaceFW/fineweb-edu]\n \n --config TEXT Dataset configuration\n Options: sample-10BT, sample-100BT, sample-350BT\n [default: None]\n \n --split TEXT Dataset split to convert\n [default: train]\n \n --output-dir TEXT Output directory for shards\n [default: ./webdataset_output]\n \n --shard-size INT Target shard size in MB\n [default: 500]\n \n --max-samples INT Maximum samples to convert (for testing)\n [default: None (all samples)]\n \n --no-streaming Disable streaming mode\n [default: streaming enabled]\n```\n\n#### Python API\n\n```python\nfrom convert_to_webdataset import WebDatasetConverter\n\nconverter = WebDatasetConverter(\n dataset_name=\"HuggingFaceFW/fineweb-edu\",\n config_name=\"sample-10BT\",\n split=\"train\",\n output_dir=\"./my_dataset\",\n shard_size_mb=500,\n max_samples=None, # Convert all samples\n streaming=True\n)\n\nconverter.convert()\n```\n\n#### Output Structure\n\n```\nwebdataset_output/\nβ”œβ”€β”€ fineweb_edu_000000.tar # Shard 0 (~500MB)\nβ”œβ”€β”€ fineweb_edu_000001.tar # Shard 1 (~500MB)\nβ”œβ”€β”€ ...\nβ”œβ”€β”€ checksums.json # SHA256 checksums\n└── dataset_metadata.json # Dataset info\n```\n\n#### Sample Format in Tar Files\n\nEach sample consists of two files:\n- `sample_000000000000.txt` - Plain text content\n- `sample_000000000000.json` - Metadata with fields:\n - `id`: Document ID\n - `url`: Source URL\n - `dump`: Dump identifier\n - `score`: Quality score\n - `token_count`: Number of tokens\n - `language`: Language code\n - `language_score`: Language detection confidence\n - `sample_id`: WebDataset sample ID\n - `sample_index`: Index in original dataset\n\n### Loading Script\n\n#### Command-Line Usage\n\n```bash\n# Validate checksums only\npython webdataset_loader.py ./webdataset_output --validate-only\n\n# Load and display samples\npython webdataset_loader.py ./webdataset_output --num-samples 10\n\n# Skip validation (faster, but risky)\npython webdataset_loader.py ./webdataset_output --no-validate\n```\n\n#### Python API - Basic Usage\n\n```python\nfrom webdataset_loader import WebDatasetLoader\n\nloader = WebDatasetLoader(\n data_dir=\"./webdataset_output\",\n validate_checksums=True, # Validate before loading\n shuffle=False, # Don't shuffle\n buffer_size=1000, # Buffer size for shuffling\n transform=None, # No transformation\n filter_fn=None, # No filtering\n shard_pattern=\"*.tar\" # Glob pattern for shards\n)\n\n# Iterate over samples\nfor sample in loader:\n print(sample['text'])\n print(sample['score'])\n```\n\n#### Python API - With Filtering\n\n```python\ndef high_quality_filter(sample):\n \"\"\"Only keep high-quality documents.\"\"\"\n return sample.get('score', 0) >= 3.0\n\nloader = WebDatasetLoader(\n data_dir=\"./webdataset_output\",\n validate_checksums=True,\n filter_fn=high_quality_filter\n)\n\nfor sample in loader:\n # All samples have score >= 3.0\n process(sample)\n```\n\n#### Python API - With Transformation\n\n```python\ndef add_features(sample):\n \"\"\"Add computed features.\"\"\"\n text = sample['text']\n sample['word_count'] = len(text.split())\n sample['char_count'] = len(text)\n return sample\n\nloader = WebDatasetLoader(\n data_dir=\"./webdataset_output\",\n validate_checksums=True,\n transform=add_features\n)\n\nfor sample in loader:\n print(f\"Words: {sample['word_count']}\")\n```\n\n#### Python API 
- PyTorch DataLoader\n\n```python\nfrom webdataset_loader import WebDatasetLoader, default_collate_fn\nimport torch\n\nloader = WebDatasetLoader(\n data_dir=\"./webdataset_output\",\n validate_checksums=True,\n shuffle=True,\n buffer_size=10000\n)\n\n# Create DataLoader\ndataloader = loader.get_dataloader(\n batch_size=32,\n num_workers=4,\n pin_memory=True,\n collate_fn=default_collate_fn\n)\n\n# Training loop\nfor epoch in range(10):\n for batch in dataloader:\n texts = batch['text'] # List of 32 strings\n scores = batch['score'] # List of 32 floats\n \n # Your training code here\n loss = model(texts, scores)\n loss.backward()\n optimizer.step()\n```\n\n#### Distributed Training\n\nThe loader automatically handles worker-based shard distribution:\n\n```python\n# Each worker gets a subset of shards\ndataloader = loader.get_dataloader(\n batch_size=32,\n num_workers=8, # 8 workers split shards among themselves\n pin_memory=True\n)\n\n# No additional code needed - sharding is automatic!\n```\n\n### Example Usage Script\n\nRun all examples:\n\n```bash\npython example_usage.py\n```\n\nThis demonstrates:\n1. Basic conversion\n2. Checksum validation\n3. Basic loading\n4. Loading with filtering\n5. Loading with transformation\n6. PyTorch DataLoader integration\n7. Distributed training simulation\n\n## Advanced Usage\n\n### Custom Collate Function\n\nCreate a custom collate function for batching:\n\n```python\nimport torch\n\ndef custom_collate_fn(batch):\n \"\"\"Custom batching with tokenization.\"\"\"\n from transformers import AutoTokenizer\n \n tokenizer = AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n \n # Extract texts\n texts = [sample['text'] for sample in batch]\n \n # Tokenize\n encoded = tokenizer(\n texts,\n padding=True,\n truncation=True,\n max_length=512,\n return_tensors='pt'\n )\n \n return {\n 'input_ids': encoded['input_ids'],\n 'attention_mask': encoded['attention_mask'],\n 'scores': torch.tensor([s['score'] for s in batch])\n }\n\ndataloader = loader.get_dataloader(\n batch_size=32,\n num_workers=4,\n collate_fn=custom_collate_fn\n)\n```\n\n### Multi-GPU Training\n\n```python\nimport torch\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel as DDP\n\n# Initialize distributed training\ndist.init_process_group(\"nccl\")\nrank = dist.get_rank()\nworld_size = dist.get_world_size()\n\n# Create loader (same on all processes)\nloader = WebDatasetLoader(\n data_dir=\"./webdataset_output\",\n validate_checksums=True,\n shuffle=True\n)\n\n# Create DataLoader with appropriate workers\ndataloader = loader.get_dataloader(\n batch_size=32,\n num_workers=4\n)\n\n# Wrap model with DDP\nmodel = DDP(model, device_ids=[rank])\n\n# Training loop (each GPU processes different shards)\nfor batch in dataloader:\n # ... 
training code\n```\n\n### Checksum Validation Utilities\n\n```python\nfrom webdataset_loader import ChecksumValidator, verify_checksums\n\n# Method 1: Simple validation\nvalid = verify_checksums(\"./webdataset_output\")\nprint(f\"Checksums valid: {valid}\")\n\n# Method 2: Detailed validation\nfrom pathlib import Path\n\nvalidator = ChecksumValidator(\n Path(\"./webdataset_output/checksums.json\")\n)\n\n# Validate specific shard\nshard_path = Path(\"./webdataset_output/fineweb_edu_000000.tar\")\nis_valid = validator.validate_shard(shard_path)\n\n# Validate all shards\nall_valid = validator.validate_all_shards(\n Path(\"./webdataset_output\")\n)\n```\n\n## Configuration Examples\n\n### Small Test Dataset\n\n```bash\npython convert_to_webdataset.py \\\\\n --config sample-10BT \\\\\n --output-dir ./test_dataset \\\\\n --shard-size 50 \\\\\n --max-samples 1000\n```\n\nOutput: ~1000 samples in small shards for quick testing\n\n### Medium Dataset\n\n```bash\npython convert_to_webdataset.py \\\\\n --config sample-100BT \\\\\n --output-dir ./medium_dataset \\\\\n --shard-size 500\n```\n\nOutput: ~100B tokens in 500MB shards\n\n### Full Dataset\n\n```bash\npython convert_to_webdataset.py \\\\\n --config sample-350BT \\\\\n --output-dir ./full_dataset \\\\\n --shard-size 500\n```\n\nOutput: ~350B tokens in 500MB shards\n\n### Custom Dataset\n\n```python\nfrom convert_to_webdataset import WebDatasetConverter\n\n# Convert any HuggingFace dataset\nconverter = WebDatasetConverter(\n dataset_name=\"your-org/your-dataset\",\n config_name=\"your-config\",\n split=\"train\",\n output_dir=\"./custom_dataset\",\n shard_size_mb=500,\n streaming=True\n)\n\nconverter.convert()\n```\n\n## Performance Tips\n\n### Conversion Performance\n\n1. **Use streaming mode** (default) for large datasets\n2. **Adjust shard size** based on your storage:\n - Smaller shards (100MB): More files, faster per-shard processing\n - Larger shards (1GB): Fewer files, better for slow filesystems\n3. **Set max_samples** for testing before full conversion\n\n### Loading Performance\n\n1. **Use multiple workers**: `num_workers=4-8` for DataLoader\n2. **Enable pin_memory**: `pin_memory=True` for GPU training\n3. **Tune buffer_size**: Larger = better shuffling, more memory\n4. **Skip validation** after first check: `validate_checksums=False`\n\n### Memory Usage\n\n- Streaming mode: O(1) memory during conversion\n- Loading: O(buffer_size) for shuffling\n- Workers: Each worker loads one shard at a time\n\n## Troubleshooting\n\n### Issue: Checksum validation fails\n\n**Cause**: Corrupted shard or interrupted download\n\n**Solution**:\n```bash\n# Re-validate to identify corrupt shards\npython webdataset_loader.py ./webdataset_output --validate-only\n\n# Re-convert if needed\npython convert_to_webdataset.py --config sample-10BT --output-dir ./webdataset_output\n```\n\n### Issue: Out of memory during conversion\n\n**Cause**: Not using streaming mode\n\n**Solution**:\n```bash\n# Ensure streaming is enabled (default)\npython convert_to_webdataset.py --config sample-10BT\n```\n\n### Issue: Slow data loading\n\n**Cause**: Not using enough workers\n\n**Solution**:\n```python\ndataloader = loader.get_dataloader(\n batch_size=32,\n num_workers=8, # Increase workers\n pin_memory=True\n)\n```\n\n### Issue: Workers getting same data\n\n**Cause**: Not using `IterableDataset` correctly\n\n**Solution**: The WebDatasetLoader automatically handles worker sharding. 
Make sure you're using PyTorch >= 2.0.\n\n### Issue: Shards not found\n\n**Cause**: Wrong directory or glob pattern\n\n**Solution**:\n```python\n# Check the directory\nimport os\nprint(os.listdir(\"./webdataset_output\"))\n\n# Adjust shard_pattern if needed\nloader = WebDatasetLoader(\n data_dir=\"./webdataset_output\",\n shard_pattern=\"fineweb_edu_*.tar\" # More specific pattern\n)\n```\n\n## File Structure\n\n```\n.\nβ”œβ”€β”€ convert_to_webdataset.py # Conversion script\nβ”œβ”€β”€ webdataset_loader.py # Loading script\nβ”œβ”€β”€ example_usage.py # Usage examples\nβ”œβ”€β”€ webdataset_requirements.txt # Dependencies\n└── README.md # This file\n\n# After conversion:\nwebdataset_output/\nβ”œβ”€β”€ fineweb_edu_000000.tar # Shard 0\nβ”œβ”€β”€ fineweb_edu_000001.tar # Shard 1\nβ”œβ”€β”€ ...\nβ”œβ”€β”€ checksums.json # Checksums\n└── dataset_metadata.json # Metadata\n```\n\n## Dataset Information\n\n### HuggingFaceFW/fineweb-edu\n\nFineWeb-Edu is a high-quality educational subset of the FineWeb dataset:\n- **Size**: Up to 1.3T tokens (full version)\n- **Quality**: Filtered for educational content\n- **Language**: Primarily English\n- **Source**: Common Crawl\n- **License**: ODC-By 1.0\n\n### Configurations\n\n- `sample-10BT`: 10B token sample (~10M documents)\n- `sample-100BT`: 100B token sample (~100M documents)\n- `sample-350BT`: 350B token sample (~350M documents)\n- Full dataset: 1.3T tokens\n\n## API Reference\n\n### `WebDatasetConverter`\n\nMain class for converting HuggingFace datasets to WebDataset format.\n\n```python\nclass WebDatasetConverter:\n def __init__(\n self,\n dataset_name: str = \"HuggingFaceFW/fineweb-edu\",\n config_name: Optional[str] = None,\n split: str = \"train\",\n output_dir: str = \"./webdataset_output\",\n shard_size_mb: int = 500,\n max_samples: Optional[int] = None,\n streaming: bool = True\n )\n \n def convert(self) -> None:\n \"\"\"Run the conversion.\"\"\"\n \n def compute_sha256(self, filepath: Path) -> str:\n \"\"\"Compute SHA256 checksum.\"\"\"\n```\n\n### `WebDatasetLoader`\n\nMain class for loading WebDataset with validation.\n\n```python\nclass WebDatasetLoader(IterableDataset):\n def __init__(\n self,\n data_dir: str,\n validate_checksums: bool = True,\n shuffle: bool = False,\n buffer_size: int = 1000,\n transform: Optional[Callable] = None,\n filter_fn: Optional[Callable] = None,\n shard_pattern: str = \"*.tar\"\n )\n \n def __iter__(self) -> Iterator[Dict]:\n \"\"\"Iterate over samples.\"\"\"\n \n def get_dataloader(\n self,\n batch_size: int = 32,\n num_workers: int = 4,\n pin_memory: bool = True,\n collate_fn: Optional[Callable] = None\n ) -> DataLoader:\n \"\"\"Create PyTorch DataLoader.\"\"\"\n```\n\n### Utility Functions\n\n```python\ndef verify_checksums(data_dir: str) -> bool:\n \"\"\"Verify all checksums in directory.\"\"\"\n\ndef default_collate_fn(batch: List[Dict]) -> Dict:\n \"\"\"Default batching function.\"\"\"\n```\n\n## Contributing\n\nContributions are welcome! Please:\n1. Fork the repository\n2. Create a feature branch\n3. Make your changes\n4. Add tests if applicable\n5. 
Submit a pull request\n\n## License\n\nThis code is provided as-is for working with the FineWeb-Edu dataset.\n\nFineWeb-Edu dataset license: ODC-By 1.0\n\n## Acknowledgments\n\n- HuggingFace for the `datasets` library and FineWeb-Edu dataset\n- WebDataset project for the tar-based format\n- PyTorch team for excellent data loading primitives\n\n## References\n\n- [FineWeb-Edu Dataset](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu)\n- [WebDataset Documentation](https://webdataset.github.io/webdataset/)\n- [PyTorch DataLoader](https://pytorch.org/docs/stable/data.html)\n- [HuggingFace Datasets](https://huggingface.co/docs/datasets/)\n\n## Support\n\nFor issues or questions:\n1. Check this README\n2. Run `example_usage.py` for working examples\n3. Validate your setup with small samples first\n4. Check PyTorch and HuggingFace documentation\n\nHappy dataset processing!\n'''\n\n# Write the README\n(project_root / 'WEBDATASET_README.md').write_text(readme_content)\nprint(f\"\\nβœ“ Created {project_root / 'WEBDATASET_README.md'}\")\nprint(f\" Size: {len(readme_content)} bytes\")\n\nprint(\"\\n\" + \"=\"*80)\nprint(\"ALL FILES CREATED SUCCESSFULLY!\")\nprint(\"=\"*80)\nprint(f\"\\nCreated files in {project_root}:\")\nprint(\" 1. convert_to_webdataset.py - Main conversion script\")\nprint(\" 2. webdataset_loader.py - Streaming loader with validation\")\nprint(\" 3. webdataset_requirements.txt - Dependencies\")\nprint(\" 4. example_usage.py - Complete usage examples\")\nprint(\" 5. WEBDATASET_README.md - Comprehensive documentation\")\nprint(\"\\nNext steps:\")\nprint(\" 1. Install dependencies: pip install -r webdataset_requirements.txt\")\nprint(\" 2. Test conversion: python convert_to_webdataset.py --config sample-10BT --max-samples 100\")\nprint(\" 3. Validate: python webdataset_loader.py ./webdataset_output --validate-only\")\nprint(\" 4. 
Run examples: python example_usage.py\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "oj722n8665e", + "source": "# Execute all the file creation code above\n# This cell runs all the code to actually create the files\n\nprint(\"Starting file creation process...\")\nprint(\"=\"*80)\n\n# The cells above already contain the code to write the files\n# They should have executed when the notebook cells were added\n\n# Let's verify the files exist\nimport os\nfrom pathlib import Path\n\nproject_root = Path('/Users/akseljoonas/Documents/hf-agent')\n\nfiles_to_check = [\n 'convert_to_webdataset.py',\n 'webdataset_loader.py',\n 'webdataset_requirements.txt',\n 'example_usage.py',\n 'WEBDATASET_README.md'\n]\n\nprint(\"\\nChecking created files:\")\nfor filename in files_to_check:\n filepath = project_root / filename\n if filepath.exists():\n size = filepath.stat().st_size\n print(f\" βœ“ {filename} ({size:,} bytes)\")\n else:\n print(f\" βœ— {filename} - NOT FOUND\")\n\nprint(\"\\n\" + \"=\"*80)\nprint(\"File creation verification complete!\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/eval/generate_rubrics.py b/eval/generate_rubrics.py new file mode 100644 index 0000000000000000000000000000000000000000..081d2db1be3a775b3a3d5de96328296094f5de8c --- /dev/null +++ b/eval/generate_rubrics.py @@ -0,0 +1,403 @@ +#!/usr/bin/env env python3 +""" +Rubric Generation Script for HF-Agent Benchmark + +Generates instance-specific evaluation rubrics following the "Rubrics as Rewards" paper. +Uses LiteLLM to call LLM models for rubric synthesis with expert grounding via reference answers. +""" + +import argparse +import json +import os +import sys +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path +from typing import Any, Dict, List + +import litellm +import pandas as pd +from dotenv import load_dotenv +from pydantic import BaseModel + +from eval.hf_io import df_to_hub + + +class Rubric(BaseModel): + title: str + description: str + weight: int + + +class RubricList(BaseModel): + rubrics: List[Rubric] + + +# Load environment variables +load_dotenv() + +# Rubric generation prompt template based on RaR paper + + +PROMPT_TEMPLATE = """You are an expert rubric writer. Your job is to generate a self-contained set of evaluation criteria ("rubrics") for judging how good, helpful and complete an agent's trajectory is to a given user question/request. + +Rubrics can cover aspects of a response such as, but not limited to, factual correctness, helpfulness, completeness, harmlessness, correctness of using Hugging Face best practices (based on HF documentation), depth of +reasoning, contextual relevance and usefulness. Each item must be self-contained – non expert readers should not need to +infer anything or consult external information. Begin each description with its category: "Essential Criteria: . . . ", "Important +Criteria: . . . ", "Optional Criteria: . . . ", or "Pitfall Criteria: Does not mention . . . ". 
+
+
+Inputs:
+- question: <<<{question}>>>
+- example_solution (NOT ground truth - just an okay attempt): <<<{example_solution}>>>
+- example_trace (NOT ground truth - just an okay attempt showing what tool usage might look like): <<<{example_trace}>>>
+
+IMPORTANT: The example_solution and example_trace provided are NOT ground truth or ideal solutions. They represent
+an attempt at solving the task - they give you a general idea of the shape of the problem and what tool usage
+might look like, but they may contain mistakes, suboptimal approaches, or incomplete answers. Your rubrics MUST be designed to fairly grade a PERFECT solution. The perfect solution is complete in all aspects of solving the task and verifies its correctness before giving the final answer. It tells the user what was done and why, and provides the final answer clearly answering the user's question.
+
+Total items:
+• Choose 7–20 rubric items based on the complexity of the question.
+
+Each rubric item:
+• title (2–4 words).
+• description: One sentence starting with its category prefix that explicitly states exactly what to look for. For example:
+– Essential Criteria: Writes an up-to-date, correct, complete, and working training loop using the latest Hugging Face best practices. Launches the training with hf-jobs.
+– Pitfall Criteria: Deprecated launcher usage. Uses python -m torch.distributed.launch instead of torchrun / accelerate.
+– Important Criteria: Explains common DDP knobs. Mentions ddp_find_unused_parameters=False for models with conditional branches; optional ddp_timeout; brief note on when they matter and why.
+– Optional Criteria: Briefly notes --deepspeed ds_config.json as an alternative scaler when models get big (but stays on DDP for this Q).
+• weight: For Essential/Important/Optional, use 1–5 (5 = most important); for Pitfall, use –1 or –2.
+
+Category guidance:
+• Essential: Critical actions to answer/complete the user's question/request; if missing, the response is invalid and useless (weight 5).
+• Important: Key reasoning, completeness, or clarity; strongly affects quality and usefulness (weight 3–4).
+• Optional: Helpfulness in educating the user or providing extra depth; nice to have but not deal-breaking (weight 1–2).
+• Pitfall: Common mistakes or omissions specific to this prompt—identify things a respondent often forgets or misstates.
+Each Pitfall description must begin with "Pitfall Criteria: Does not mention . . . " or "Pitfall Criteria: Recommends . . . "
+and use weight –1 or –2.
+
+To ensure self-contained guidance:
+• When referring to answer choices, explicitly say "Identifies (A)", "Identifies (B)", etc., rather than vague phrasing.
+• If the format requires an action like calling a tool or launching a training run, include a rubric item such as:
+– Essential Criteria: Includes a clear statement "Launches the training with hf-jobs.".
+• If reasoning should precede the answer, include a rubric like:
+– Important Criteria: Presents the explanation and reasoning before stating the final answer.
+• If brevity is valued, include a rubric like:
+– Optional Criteria: Remains concise and avoids unnecessary detail.
+• If the question context demands mention of specific findings/best practices, include that explicitly (e.g., "Essential Criteria: Mentions
+that training data must be in "messages" column for LLM training").
+
+Output: Provide a JSON array of rubric objects. Each object must contain exactly three keys—title, description, and weight.
+Do not copy large blocks of the question or example_solution into the text. Each description must begin with its category
+prefix, and no extra keys are allowed.
+
+Remember: The example_solution and example_trace are NOT ideal answers - they are just rough attempts to show the
+general approach. Design rubrics that can fairly evaluate any solution, including ones that are better than the example."""
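+
+# A minimal illustration of the output shape the template above requests
+# (hypothetical values, not taken from any real run) - a JSON array that
+# parses into RubricList:
+# [
+#   {"title": "Names chosen model", "description": "Essential Criteria: States the selected model id and why it fits the request.", "weight": 5},
+#   {"title": "Reasoning before answer", "description": "Important Criteria: Presents the explanation and reasoning before stating the final answer.", "weight": 4},
+#   {"title": "Deprecated launcher", "description": "Pitfall Criteria: Recommends python -m torch.distributed.launch instead of torchrun or accelerate.", "weight": -2}
+# ]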
+
+
+def build_prompt(
+    question: str,
+    example_solution: str,
+    example_trace: List[Dict[str, Any]],
+) -> List[Dict[str, str]]:
+    """
+    Build the messages list for LiteLLM completion.
+
+    Args:
+        question: The question/task to evaluate
+        example_solution: An example solution attempt (not ground truth)
+        example_trace: The agent's message trace showing tool usage
+
+    Returns:
+        List of message dicts for LiteLLM
+    """
+    # Format the trace for readability - only include key parts
+    formatted_trace = format_trace_for_prompt(example_trace)
+
+    prompt = PROMPT_TEMPLATE.format(
+        question=question,
+        example_solution=example_solution,
+        example_trace=formatted_trace,
+    )
+
+    return [{"role": "user", "content": prompt}]
+
+
+def format_trace_for_prompt(messages: List[Dict[str, Any]]) -> str:
+    """
+    Format the agent message trace for inclusion in the prompt.
+    Extracts key information while keeping it readable.
+    """
+    if not messages:
+        return "(No trace available)"
+
+    formatted_parts = []
+    for msg in messages:
+        role = msg.get("role", "unknown")
+        content = msg.get("content", "")
+
+        # Skip system messages
+        if role == "system":
+            continue
+
+        # Handle tool calls
+        if "tool_calls" in msg and msg["tool_calls"]:
+            tool_info = []
+            for tc in msg["tool_calls"]:
+                if isinstance(tc, dict) and "function" in tc:
+                    func = tc["function"]
+                    tool_name = func.get("name", "unknown_tool")
+                    tool_info.append(f"  - Called: {tool_name}")
+            if tool_info:
+                formatted_parts.append(
+                    "[Assistant Tool Calls]\n" + "\n".join(tool_info)
+                )
+
+        # Handle regular content
+        if content:
+            # Truncate very long content
+            if len(content) > 500:
+                content = content[:500] + "... (truncated)"
+            formatted_parts.append(f"[{role.title()}]\n{content}")
+
+    return "\n\n".join(formatted_parts) if formatted_parts else "(Empty trace)"
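+
+# For illustration (hypothetical trace, assumed message shapes), a list like
+#   [{"role": "assistant", "content": "", "tool_calls": [{"function": {"name": "dataset_search"}}]},
+#    {"role": "tool", "content": "found 3 datasets"}]
+# is rendered by format_trace_for_prompt as roughly:
+#   [Assistant Tool Calls]
+#     - Called: dataset_search
+#
+#   [Tool]
+#   found 3 datasets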
+
+
+def validate_rubric(rubric_list: List[Dict[str, Any]]) -> bool:
+    """
+    Validate that a rubric meets basic requirements.
+
+    Args:
+        rubric_list: List of rubric items to validate
+
+    Returns:
+        True if valid, False otherwise
+    """
+    # Check count
+    if not (7 <= len(rubric_list) <= 20):
+        return False
+
+    # Check each item
+    category_prefixes = [
+        "Essential Criteria:",
+        "Important Criteria:",
+        "Optional Criteria:",
+        "Pitfall Criteria:",
+    ]
+
+    for item in rubric_list:
+        # Check keys
+        if set(item.keys()) != {"title", "description", "weight"}:
+            return False
+
+        # Check description starts with category prefix
+        if not any(
+            item["description"].startswith(prefix) for prefix in category_prefixes
+        ):
+            return False
+
+    return True
+
+
+def generate_rubric(row: pd.Series, model: str, timeout: int = 120) -> Optional[str]:
+    """
+    Generate a rubric for a single question using LiteLLM.
+
+    Args:
+        row: DataFrame row containing question, solution, and optionally messages
+        model: Model name for LiteLLM
+        timeout: Request timeout in seconds
+
+    Returns:
+        JSON string serializing the generated RubricList, or None on failure
+    """
+
+    messages = build_prompt(
+        question=row["question"],
+        example_solution=row["solution"],
+        example_trace=row.get("messages", []),
+    )
+
+    try:
+        response = litellm.completion(
+            model=model,
+            messages=messages,
+            timeout=timeout,
+            response_format=RubricList,
+        )
+
+        # Parse structured output
+        rubric_list: RubricList = RubricList.model_validate_json(
+            response.choices[0].message.content
+        )
+
+        return rubric_list.model_dump_json()
+    except Exception as e:
+        print(f"Error generating rubric: {e}", file=sys.stderr)
+        return None
+
+
+def load_input_data(infile: str) -> pd.DataFrame:
+    """
+    Load input data from CSV or JSONL file.
+
+    Args:
+        infile: Path to input file
+
+    Returns:
+        DataFrame with loaded data
+    """
+    path = Path(infile)
+
+    if not path.exists():
+        raise FileNotFoundError(f"Input file not found: {infile}")
+
+    if path.suffix == ".csv":
+        # Try to auto-detect delimiter (comma or semicolon)
+        df = pd.read_csv(infile, sep=None, engine="python")
+    elif path.suffix == ".jsonl":
+        df = pd.read_json(infile, lines=True)
+    else:
+        raise ValueError(f"Unsupported file format: {path.suffix}. Use .csv or .jsonl")
+
+    # Validate required columns
+    required_cols = [
+        "question",
+        "solution",
+    ]
+    optional_cols = ["difficulty", "messages", "error"]
+    missing_cols = [col for col in required_cols if col not in df.columns]
+
+    if missing_cols:
+        raise ValueError(f"Missing required columns: {missing_cols}")
+
+    # Log available optional columns
+    available_optional = [col for col in optional_cols if col in df.columns]
+    print(f"Found optional columns: {available_optional}")
+
+    return df
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Generate rubrics for HF-agent benchmark evaluation"
+    )
+    parser.add_argument(
+        "--infile", type=str, required=True, help="Input file path (.csv or .jsonl)"
+    )
+    parser.add_argument(
+        "--outfile", type=str, required=True, help="Output JSONL file path"
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default=None,
+        help="LiteLLM model name (default: LITELLM_MODEL env var, else anthropic/claude-sonnet-4-5-20250929)",
+    )
+    parser.add_argument(
+        "--timeout",
+        type=int,
+        default=120,
+        help="Request timeout in seconds (default: 120)",
+    )
+    parser.add_argument(
+        "--max-concurrent",
+        type=int,
+        default=30,
+        help="Maximum number of concurrent workers (default: 30)",
+    )
+    parser.add_argument(
+        "--push-to-hub",
+        type=str,
+        default=None,
+        help="Push to HuggingFace dataset (e.g., username/dataset@rubrics)",
+    )
+
+    args = parser.parse_args()
+
+    # Determine model: CLI flag, then LITELLM_MODEL env var, then a fixed default
+    model = args.model or os.getenv(
+        "LITELLM_MODEL", "anthropic/claude-sonnet-4-5-20250929"
+    )
+    print(f"Using model: {model}")
+
+    # Load input data
+    print(f"Loading data from {args.infile}...")
+    df = load_input_data(args.infile)
+    print(f"Loaded {len(df)} examples")
+
+    # Run rubric generation in parallel using ThreadPoolExecutor
+    print(f"Running generation with {args.max_concurrent} parallel workers...")
+
+    with ThreadPoolExecutor(max_workers=args.max_concurrent) as executor:
+        # Submit all tasks
+        future_to_idx = {}
+        for idx, row in df.iterrows():
+            future = executor.submit(
+                generate_rubric,
+                row=row,
+                model=model,
+                timeout=args.timeout,
+            )
+            future_to_idx[future] = idx
+
+        # Collect results in order
+        results = [None] * len(df)
+        completed = 0
+        for future in as_completed(future_to_idx):
+            idx = future_to_idx[future]
+            results[idx] = future.result()
+            completed += 1
+            print(f"Completed: {completed}/{len(df)}", end="\r")
+
+    print()  # New line after progress
+
+    # Prepare results DataFrame
+    print("Preparing results...")
+    output_rows = []
+    success_count = 0
+    failure_count = 0
+
+    for idx, (_, row) in enumerate(df.iterrows()):
+        rubric_result = results[idx]
+
+        if rubric_result is None:
+            failure_count += 1
+            continue
+
+        # Merge with original data
+        output_row = row.to_dict()
+        # "messages" is optional in the input, so only serialize it when present
+        if "messages" in output_row:
+            output_row["messages"] = json.dumps(output_row["messages"])
+        output_row["rubric"] = rubric_result
+        output_rows.append(output_row)
+        success_count += 1
+
+    # Create DataFrame with results
+    results_df = pd.DataFrame(output_rows)
+
+    # Upload to HuggingFace if specified (before saving JSONL)
+    upload_success = False
+    if args.push_to_hub:
+        print(f"\nUploading to HuggingFace: {args.push_to_hub}")
+        upload_success = df_to_hub(
+            df=results_df,
+            dataset_spec=args.push_to_hub,
+            split="train",
+            private=False,
+        )
+        if not upload_success:
+            print("Warning: HuggingFace push failed, but continuing to save JSONL...")
+
+    # Write results to JSONL file
+    print(f"\nWriting results to {args.outfile}...")
+    with open(args.outfile, "w") as outf:
+        for output_row in output_rows:
+            outf.write(json.dumps(output_row, default=str) + "\n")
+
+    print("\nComplete!")
+    print(f"Success: {success_count}/{len(df)}")
+    print(f"Failures: {failure_count}/{len(df)}")
+    print(f"Output written to: {args.outfile}")
+    if args.push_to_hub and upload_success:
+        print(f"Pushed to: {args.push_to_hub}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/eval/generated_tasks_with_difficulty.json b/eval/generated_tasks_with_difficulty.json
new file mode 100644
index 0000000000000000000000000000000000000000..344347fe1e25398dc00a7f7c2ebb7f50e02660d5
--- /dev/null
+++ b/eval/generated_tasks_with_difficulty.json
@@ -0,0 +1,255 @@
+{
+  "Evaluate models {M_i} on benchmarks {B_i}": "Easy",
+  "Train models {M_i} on datasets {D_i} with benchmarks {B_i}": "Medium",
+  "Run an ablation for hyperparameter P for model M on dataset D": "Hard",
+  "Generate completions with model M on dataset D using engine E": "Medium",
+  "Merge models {M_i} using linear averaging to find the best result on benchmarks {B_i}": "Hard",
+  "Given datasets {D_i}, ablate the best SFT mixture for model M across benchmarks {B_i}": "Very hard",
+  "Decontaminate dataset D against benchmarks {B_i}": "Hard",
+  "Benchmark RL framework F for best throughput on G GPUs": "Very hard",
+  "Implement post-training algorithm A from paper P in framework F. Validate it runs end-to-end": "Very hard",
+  "Implement benchmark B in framework F. Validate it reproduces some published results": "Very hard",
+  "Format dataset D for compatibility with framework F on task T": "Easy",
+  "Remove the background from this image: [image path]": "Easy",
+  "Transcribe all of the audio files in this directory": "Easy",
+  "Transcribe all of the audio files in this directory, choose the model that'll be cheapest and also relatively accurate": "Medium (judgment call or interaction needed to figure out what accuracy levels are acceptable)",
+  "Remove the background music from this audio file": "Medium (needs to find Gradio Space and call its API)",
+  "Change this video track to be from English to Spanish": "Medium (needs to link several models together)",
+  "Translate this flyer from English to Spanish, keeping the layout and images the same": "Medium (needs to link several models together)",
+  "What's the best model for X?": "Easy",
+  "What datasets are available for X? (X={domain x task x modality})": "Easy",
+  "Is there a space to do Y?": "Easy",
+  "I have this script and this error - what's the issue?": "Medium",
+  "This space is broken, how can i fix it?": "Medium",
+  "I built a space but it is super slow. What can I do?": "Medium",
+  "How can I run model X locally?": "Medium",
+  "I want to build a space with model Y to do X?": "Hard",
+  "How can I serve a model with multiple LoRAs?": "Hard",
+  "What's the best model for sentiment analysis on financial text?": "Easy",
+  "Are there any medical image segmentation datasets on HuggingFace for CT scans?": "Easy",
+  "Which text classification models support 4-bit quantization?": "Medium",
+  "Are there inference endpoints available for Whisper large-v3?": "Easy",
+  "What's the license for the SA-Med2D-20M dataset?": "Easy",
+  "Which vision models fit in 8GB VRAM for image segmentation?": "Medium",
+  "What datasets are available for 3D medical image segmentation?": "Medium",
+  "Is there a space to do text-to-speech with emotion control?": "Medium",
+  "I'm getting \"CUDA out of memory\" when loading Llama-2-7b even though nvidia-smi shows I have 6GB free - what's the issue?": "Medium",
+  "My Gradio space shows \"Connection errored out\" after working fine yesterday, no code changes - how can I fix it?": "Medium",
+  "I built a Gradio space for Stable Diffusion but inference takes 5+ minutes on a 4090 - what can I do?": "Medium",
+  "My Whisper model outputs different transcriptions after quantization to int8 - why?": "Medium",
+  "Getting \"RuntimeError: CUDA error: out of memory.
Tried to allocate 70.00 MiB\" but only 2.87 GiB is allocated - what's happening?": "Medium", + "My HuggingFace space build fails with \"failed to create containerd task\" - how to fix?": "Medium", + "DistilBERT model gives \"you should probably train your model\" warning even though it's a pretrained model from the Hub": "Easy", + "Space was working fine but now receiving build errors - receiving this error even with a new space": "Medium", + "Inference is correct locally but wrong on deployed space": "Medium", + "Getting CUDA OOM despite having enough memory according to nvidia-smi": "Medium", + "How can I run Mistral-7B-v0.1 locally with multiple LoRA adapters?": "Hard", + "How can I serve Llama-2-7b with vLLM and dynamically load multiple LoRA adapters?": "Hard", + "How do I batch inference requests in my Gradio space for better throughput?": "Medium", + "Can I run Whisper large-v3 with faster-whisper for 4x speedup?": "Medium", + "How to run Llama 2 on CPU after fine-tuning with LoRA?": "Medium", + "Best way to handle 50+ concurrent requests in a Gradio space without OOM?": "Hard", + "How do I add custom stopping criteria for text generation with Transformers?": "Hard", + "Can I merge multiple LoRA adapters before inference to reduce latency?": "Hard", + "How can I optimize my LLM inference with one base LLM and multiple LoRA adapters?": "Hard", + "Compare tokenizers {T_i} for model M on tasks {classification, QA}; report accuracy and average sequence length per task": "Medium", + "Run a LoRA rank sweep (r in {4, 8, 16, 32}) for model M on dataset D; plot validation perplexity vs VRAM usage and select Pareto-optimal settings": "Hard", + "Build a streaming dataloader from Parquet on S3 with deterministic shuffling across N workers; validate epoch reproducibility": "Very hard", + "Find three open-source TTS models with emotion control and list their sample rates and licenses": "Easy", + "Create a retrieval-augmented QA pipeline: index corpus C with FAISS, connect to model M, and benchmark top-1 accuracy and p95 latency": "Hard", + "Diagnose a Space where memory grows per request; add no-grad guards, free caches, and demonstrate stable RSS over 10,000 calls": "Hard", + "Deduplicate dataset D using MinHash LSH at Jaccard >= 0.9 and publish a cleaned HF dataset with provenance columns": "Medium", + "Add special tokens to tokenizer T and resize model M embeddings; resume pretraining for 10k steps without loss spikes": "Hard", + "Create a HuggingFace Dataset from CSV file data.csv and push to repo username/my_dataset": "Easy", + "Build a real-time Whisper transcription Space with VAD and chunked decoding; keep end-to-end latency under 200 ms": "Hard", + "Quantize model M to 4-bit (bnb.int4) with bitsandbytes; compare perplexity and p95 latency to 8-bit on dataset D; select config with <1% perplexity increase": "Medium", + "Fuse LoRA adapter A into base model M and export a single safetensors checkpoint; verify logits parity (<1e-5 MSE) vs on-the-fly LoRA": "Hard", + "Redact PII from dataset D using a transformer NER pipeline; produce a cleaned HuggingFace Dataset with per-entity removal stats and provenance": "Medium", + "Train a SentencePiece tokenizer (vocab=64k, byte fallback) on corpus C; compare tokenization speed, unknown-token rate, and bytes/token vs tokenizer T": "Hard", + "Build a sharded FAISS IVF-PQ index for 100M embeddings stored on S3; integrate with HF datasets streaming and report recall@10 and QPS": "Very hard", + "Fine-tune model M with QLoRA using TRL PPO on dataset D; log 
KL, reward, and throughput; validate no divergence on a held-out eval": "Hard", + "Resolve HfHubHTTPError 401 when pushing dataset repo R: diagnose token scopes, git-lfs config, and large file thresholds; document the fix": "Medium", + "Implement a custom Transformers LogitsProcessor that bans repeated bigrams; add unit tests and benchmark generation quality (BLEU) on dataset D": "Hard", + "List and download all Hub models tagged 'text-classification' with Apache-2.0 license and size <500MB; save model ids and downloads to CSV": "Easy", + "Enable speculative decoding in vLLM with draft model D for base model M; benchmark tokens/sec speedup at batch sizes {1,4,16} and max_new_tokens {64,256}": "Very hard", + "Profile model M under torch.compile modes {reduce-overhead, max-autotune} on GPU G; report tokens/sec, peak VRAM, and compile overhead": "Medium", + "Detect and remove near-duplicate images in dataset D using CLIP ViT-L/14 embeddings at cosine >= 0.95; publish a cleaned dataset with duplicate_group ids": "Medium", + "Convert a TensorFlow SavedModel of T5-base to Transformers PyTorch format; verify logits parity (MSE < 1e-4) on 1,000 random prompts": "Hard", + "Enable FlashAttention-2 in a Transformers training loop for model M; benchmark step time and confirm loss parity over 2,000 steps vs baseline": "Hard", + "Deploy vLLM for model M with hot-swappable LoRA adapters {A_i}; provide an API to switch adapters and demonstrate <200 ms switch latency under load": "Very hard", + "Implement a custom Trainer callback to log gradient norms, activation histograms, and learning rate; diagnose periodic loss spikes and propose a fix": "Hard", + "Build a bilingual RAG pipeline indexing corpora {en, es} with FAISS HNSW; evaluate exact match@1 on dataset D and report p95 latency": "Hard", + "Run a mixed-precision sweep (fp16 vs bf16) for model M on A100 and RTX 3090; compare convergence, throughput, and numerical stability issues": "Medium", + "Create a Gradio Space that batches Whisper-large-v3 transcription via queue + chunked decoding; maintain real-time factor <= 0.5 on a T4": "Hard", + "List five OCR datasets on the Hub with line-level annotations; include licenses and approximate image counts": "Easy", + "List models on the Hub tagged 'summarization' that offer safetensors weights and 4-bit quantization; output model ids": "Easy", + "Evaluate safety filters of models {M_i} on red-team prompt set R; report jailbreak rate and false positive rate": "Medium", + "Run a prompt template ablation for chat model M on dataset D; compare {alpaca, chatml, llama2} formats and report exact match and average output length": "Hard", + "Implement tensor parallelism for model M in framework F and show linear scaling across 2\u20138 GPUs with <=10% gap from ideal": "Very hard", + "Convert and shard dataset D into WebDataset tar files (~500MB/shard); build a streaming loader with checksum validation": "Medium", + "Deploy a Spaces app serving Stable Diffusion XL with ControlNet; add output caching and keep p95 latency <1s for 20 concurrent users": "Hard", + "Diagnose and fix 'shape mismatch' when loading LoRA into model M after tokenizer resize; provide minimal repro and patch": "Medium", + "Add a detailed model card to repo username/model_M with training data, intended use, limitations, and evaluation results": "Easy", + "Enable KV cache quantization (int8) in Transformers for model M; compare tokens/sec and ROUGE-L on dataset D vs fp16 cache": "Hard", + "Detect and redact license-incompatible samples in dataset D 
by matching SPDX identifiers and source domains; publish a compliance report": "Medium", + "Profile vLLM serving of model M with paged attention; tune block_size to maximize tokens/sec and report p50/p95 latency and peak VRAM": "Medium", + "Filter dataset D for toxic content using classifier C; log per-label removal rates and recreate stratified train/valid/test splits": "Medium", + "Train a unigram tokenizer (vocab=80k) on corpora {en, fr}; fine-tune T5-small and compare BLEU vs a BPE baseline; report tokenization speed and OOV rate": "Hard", + "Run distributed evaluation of models {M_i} on benchmark B across 4 GPUs with DeepSpeed-Inference; ensure identical metrics across 3 seeds": "Hard", + "Find three open-source ASR models that provide word-level timestamps; record licenses and expected WER on LibriSpeech": "Easy", + "Diagnose intermittent 'Address already in use' crashes in a FastAPI Space; add graceful shutdown and port probing, verifying stability over 1,000 restart cycles": "Medium", + "Export a LoRA-finetuned Llama checkpoint to GGUF for llama.cpp; validate perplexity parity (<=1% drift) on WikiText-2": "Hard", + "Construct a streaming RAG pipeline over S3-stored corpus C with Chroma; index ~1B tokens, implement shard rebalancing, and benchmark recall@5 and QPS": "Very hard", + "List Hub datasets tagged 'speech-emotion-recognition' with CC-BY or CC-BY-SA licenses and >=10k utterances; write dataset ids and sizes to JSON": "Easy", + "Train a summarization reward model via pairwise ranking on dataset D; apply DPO to model M and report ROUGE-L and human win rate": "Hard", + "Find four open-source OCR models that output line- or paragraph-level text and provide ONNX or TensorRT exports; list their licenses and maximum input resolutions": "Easy", + "Verify tokenizer special tokens for model M are preserved after adding new tokens; write a unit test that asserts CLS/SEP/PAD ids are unchanged before and after resize": "Medium", + "Implement a constrained decoder for model M that enforces a JSON schema via a custom Transformers LogitsProcessor; add unit tests and benchmark latency on dataset D": "Hard", + "Build a multilingual RAG index for 50M documents using mDPR with sharded storage on S3; support hot index reloads and report recall@10 and p95 latency at 100 QPS": "Very hard", + "Quantize T5-base to 8-bit with bitsandbytes (LLM.int8) and compare ROUGE-L and tokens/sec to fp16 on CNN/DailyMail; keep ROUGE-L drop <=1%": "Medium", + "Diagnose VRAM growth in a vLLM server at batch size 32; add profiling, fix cache eviction behavior, and demonstrate flat memory over 10,000 requests": "Hard", + "Convert a HuggingFace TokenizerFast to a SentencePiece model; verify >=99.9% token-level agreement on 10,000 sentences and measure tokenization speed delta": "Medium", + "Train a multi-task adapter stack for {summarization, QA, NLI} on model M; implement routing by prompt prefix and report per-task metrics and cross-task interference": "Very hard", + "Assess license compatibility between model M (Apache-2.0) and dataset D (CC-BY-SA); produce a one-paragraph verdict with rationale and reference links": "Easy", + "Enable FSDP with activation checkpointing for a 13B model across 2\u00d7A100 GPUs; achieve <=10% throughput loss vs baseline and verify loss parity over 1,000 steps": "Hard", + "List three datasets for code summarization with permissive licenses; output their dataset ids and license names": "Easy", + "Set up nightly continuous evaluation of model M on benchmarks {B_i}; log metrics to Weights 
& Biases and alert on >2% regression vs last 7-day rolling mean": "Medium", + "Implement streaming text generation in a Gradio Space for model M using server-sent events; cap median token emission delay at <50 ms": "Hard", + "Scale out training of a 7B model with FSDP + ZeRO across 8 GPUs; demonstrate checkpoint save/restore and achieve throughput within 15% of ideal linear scaling": "Very hard", + "Export a mixture-of-experts PyTorch model to ONNX and run with TensorRT; verify top-1 accuracy within 0.5% of PyTorch on dataset D": "Medium", + "Identify whether model M supports FlashAttention-2 from its config or source; provide supporting repo links and a yes/no compatibility flag": "Easy", + "Build an audio deduplication pipeline for dataset D using embedding model E with cosine similarity >= 0.98; publish grouped duplicate ids and a cleaned manifest": "Hard", + "Diagnose slow tokenization in a Transformers pipeline; profile, switch to a fast tokenizer, and demonstrate 2\u00d7 end-to-end speedup on 1M lines": "Medium", + "Implement a contrastive preference learning loss in TRL; train model M on dataset D and compare KL, reward variance, and human win rate vs a PPO baseline": "Hard", + "Build an elastic RAG service with Ray that autoscales FAISS shards on S3, supports live corpus updates, and maintains p95 latency <500 ms at 200 QPS": "Very hard", + "List five chat-optimized LLMs on the Hub that include a tokenizer chat_template and safetensors weights; output model ids": "Easy", + "Find three biomedical NER datasets with Apache-2.0 or MIT licenses; return dataset ids and license names": "Easy", + "Create a dataset viewer Space that streams Parquet shards from the Hub using datasets streaming; implement server-side filtering and pagination": "Medium", + "Enable gradient checkpointing and optimizer state offloading for model M with Accelerate; report step time and peak VRAM vs baseline on a single A100": "Medium", + "Diagnose and fix 'size mismatch for position_embeddings' after increasing max_position_embeddings; provide a minimal repro and a migration script": "Medium", + "Implement a regex-constrained Transformers LogitsProcessor that enforces ISO-8601 timestamps; add unit tests and report generation latency overhead on dataset D": "Hard", + "Train language-specific LoRA adapters for {en, es, de} on model M; add an automatic language router and report per-language BLEU and cross-language interference": "Hard", + "Build a speaker diarization + ASR Gradio Space using pyannote and Whisper-large-v3; achieve DER <= 12% and real-time factor <= 0.75 on a T4": "Hard", + "Implement multi-draft speculative decoding with dynamic draft-model selection per prompt; integrate with vLLM and benchmark tokens/sec speedup at batch sizes {1,8,32}": "Very hard", + "Convert a TensorFlow DistilBERT SavedModel to ONNX (opset 17) and validate logits parity (MSE < 1e-4) on 1,000 random inputs; measure CPU inference speedup vs TensorFlow": "Medium", + "Evaluate alignment drift after SFT: compare model M vs base M0 on prompt set P; report win rate, refusal rate, and average output length": "Medium", + "Enable KV cache int4 quantization in vLLM for model M; benchmark tokens/sec and exact match on dataset D vs fp16 cache": "Hard", + "Implement variable-length packing in a HF Datasets + Transformers training loop; ensure epoch-level sample coverage matches baseline and no truncation beyond max_length": "Medium", + "Build a multi-tenant LoRA router over vLLM: on-demand load adapters from the Hub with LRU eviction; 
sustain 100 tenants and <300 ms adapter swap latency under load": "Very hard", + "Audit generations for PII leakage on prompt set P using detector C; compute precision, recall, and false positive rate; redact before logging and publish a compliance summary": "Medium", + "Merge a stack of PEFT adapters {A_i} into base model M to produce a single FP16 checkpoint; validate perplexity drift <=0.5% on dataset D and export safetensors": "Hard", + "Find three Spaces that demonstrate constrained JSON generation; return Space ids and URLs": "Easy", + "Deploy a cross-lingual vector search service with multilingual-e5-large; shard FAISS across 3 nodes and measure mAP@10 and p95 latency at 500 QPS": "Very hard", + "Quantize attention and MLP projections only with bitsandbytes (selective 8-bit); compare peak VRAM, tokens/sec, and ROUGE-L vs full-model 8-bit on dataset D": "Hard", + "Fix \"Token indices sequence length is longer than the specified maximum\" after tokenizer resize; add truncation with stride and update generation config; verify no validation metric regression": "Medium", + "Identify splits for dataset D and output split names with sample counts": "Easy", + "Find five multilingual sentence-embedding models on the Hub with Apache-2.0 license; return model ids": "Easy", + "Set up CI to run evaluation suite E for model M nightly; fail the job if any metric drops >1% vs 7-day rolling mean": "Medium", + "Add length normalization to beam search for model M; compare vs baseline on dataset D and report ROUGE-L and average output length": "Medium", + "Detect per-sample language for dataset D; add a 'lang' column and recreate train/valid/test splits preserving language proportions": "Medium", + "Benchmark vLLM KV-cache eviction strategies (e.g., LRU vs TTL) for model M at batch sizes {1,8,32}; report tokens/sec and peak VRAM": "Medium", + "Implement a custom DataCollator that packs multiple documents for summarization with separator tokens; add unit tests to prevent cross-sample leakage": "Hard", + "Build a PDF-to-dataset pipeline: OCR pages with model Donut, store word-level bboxes, and publish a HuggingFace Dataset with a viewer Space": "Hard", + "Train a ColBERT reranker on corpus C + pairs dataset D; integrate into a RAG search service and report recall@10 and p95 latency delta": "Hard", + "Deploy vLLM for model M with multi-GPU tensor-parallel inference across 2 nodes using NCCL; demonstrate near-linear throughput scaling and deterministic outputs across 3 seeds": "Very hard", + "List four Hub models tagged 'named-entity-recognition' that declare bitsandbytes 8-bit support in their README; output model ids": "Easy", + "Find three Spaces that provide real-time TTS streaming demos; return Space ids and reported sample rates": "Easy", + "Create a Spaces app that visualizes transformer attention maps for a ViT model using Captum; keep heatmap rendering under 200 ms for 224x224 images": "Medium", + "Set up datasets streaming with resumable downloads and exponential backoff for S3-hosted Parquet shards; verify checksum integrity after killing and resuming the job": "Medium", + "Build a tokenizer migration tool to convert a SentencePiece model to a HuggingFace tokenizers JSON with byte-fallback; assert >=99.95% token-level agreement on 20k sentences and report speed delta": "Medium", + "Implement a custom DataCollator for span masking with variable block sizes for byte-level BPE; add unit tests and demonstrate MLM loss parity over 10k steps on WikiText-103": "Hard", + "Add speculative decoding with a 
small draft model to a Transformers-based text-generation server; expose a per-request flag and benchmark tokens/sec speedup at batch sizes {1,8,32}": "Hard", + "Train an online knowledge-distillation SFT: teacher M0 -> student M on dataset D; log KL divergence, token agreement, and throughput; cap metric drop at <=2% vs teacher": "Hard", + "Deploy a multi-region vLLM service on Kubernetes with adaptive batching and hot LoRA adapter loading; sustain 200 QPS with p95 latency <300 ms and zero-downtime rollouts": "Very hard", + "Build a sharded cross-encoder reranking service with Ray: distribute ColBERT scoring across nodes, integrate with FAISS retrieval, and maintain recall@10 within 1% of single-node baseline at 500 QPS": "Very hard", + "List four Spaces that perform multilingual OCR with layout extraction; return Space ids and supported languages": "Easy", + "Find five Hub datasets for code generation evaluation with permissive licenses; output dataset ids and license names": "Easy", + "Add gradient accumulation and gradient clipping to a Transformers Trainer finetune of model M; report step time, peak VRAM, and validation metric vs baseline": "Medium", + "Implement document chunking with sliding windows and overlap in a Datasets map pipeline; add doc_id and span indices and verify no segment exceeds max_length": "Medium", + "Export a fine-tuned BERT model to TorchScript and ONNX; verify logits parity (MSE < 1e-4) on 1,000 samples and compare CPU throughput": "Medium", + "Diagnose 'pad_token_id is not set' warnings during generation; add a PAD token, resize embeddings, and write a unit test asserting identical logits pre/post fix on 200 prompts": "Medium", + "Implement diverse beam search (group_beam_search) for model M; evaluate on dataset D and report ROUGE-L, distinct-n, and average output length vs standard beam search": "Hard", + "Build a multi-modal RAG demo that indexes image captions with CLIP and uses LLM M to answer visual questions; report top-1 accuracy and p95 latency": "Hard", + "Profile activation and KV-cache memory during generation for model M; log per-layer footprints and reduce peak usage via attention slicing; show tokens/sec and VRAM deltas": "Hard", + "Construct a 200M-document FAISS hybrid (IVF-PQ + HNSW) index with memory-mapped shards on S3; support live add/delete and benchmark recall@10 and QPS at 300 QPS": "Very hard", + "List five Hub datasets tagged 'topic-modeling' with MIT or Apache-2.0 licenses; output dataset ids": "Easy", + "Find three Spaces that offer real-time grammar correction with streaming tokens; return Space ids and URLs": "Easy", + "Convert a spaCy en_core_web_trf NER model to ONNX and wrap it in a Transformers TokenClassification pipeline; verify entity text/label/span parity on 1,000 sentences": "Medium", + "Set up a GitHub Actions workflow that snapshots tokenizer T weekly and fails if vocab or special token ids drift vs the last snapshot; upload a diff artifact": "Medium", + "Profile a Datasets map pipeline on corpus C; refactor to use batched=True, num_proc>1, and caching; achieve >=2\u00d7 speedup while preserving deterministic ordering across runs": "Medium", + "Implement a custom Transformers StoppingCriteria that halts when JSON braces are balanced or max nesting depth is reached; add unit tests and benchmark latency overhead on dataset D": "Hard", + "Build a visual-and-tabular RAG pipeline: index images with CLIP and CSV tables with TAPAS; answer mixed queries using LLM M; report EM@1 and p95 latency at 50 QPS": "Hard", + "Enable 
KV-cache int4 quantization during generation in Transformers for model M; compare tokens/sec and exact match vs fp16 cache on dataset D; keep metric drop <=1%": "Hard", + "Implement a hot-reloadable sharded FAISS IVF-PQ index for multilingual-e5-base with live add/delete and background re-training; sustain 200 QPS with p95 latency <400 ms across 3 nodes": "Very hard", + "Deploy a geo-distributed vLLM + LoRA adapter gateway across two regions with consistent hashing and zero-downtime adapter updates; ensure identical outputs across 3 seeds and report cross-region p95 latency": "Very hard", + "List five Hub LLM repos that disclose training token counts in their model cards; output model ids and token totals": "Easy", + "Find two ready-to-use Spaces for speaker diarization compatible with Whisper; return Space ids and URLs": "Easy", + "Create a hashing-based dataset splitter using column 'doc_id' to produce reproducible train/valid/test; verify identical splits across two machines and Python versions": "Medium", + "Resolve HTTP 403 when creating an organization dataset via the Hub API; diagnose token scopes and org permissions; provide a minimal repro script and the fix": "Medium", + "Export a PEFT LoRA adapter from a fine-tuned Llama checkpoint as standalone safetensors with a correct adapter_config.json; push to the Hub and verify PEFT.from_pretrained loads it": "Medium", + "Enable multi-query attention in model M within Transformers; benchmark tokens/sec and peak VRAM vs multi-head attention and verify perplexity parity over 2,000 steps": "Hard", + "Audit code dataset D for contamination against {HumanEval, MBPP} using exact substring and 3-gram Jaccard >= 0.9; publish per-source contamination rates and a cleaned dataset": "Hard", + "Implement contrastive search decoding for model M with tunable alpha; compare ROUGE-L, distinct-n, and latency vs nucleus sampling on dataset D": "Hard", + "Implement pipeline parallelism for model M across 4 GPUs with Accelerate; achieve near-linear scaling (<=15% gap), support checkpoint save/restore, and ensure deterministic outputs across 3 seeds": "Very hard", + "Deploy a Spaces app that serves two ASR models with automatic language ID routing; maintain real-time factor <= 0.6 on a single T4 and log per-language latency": "Hard", + "Benchmark JSON-constrained decoding across models {M_i}; report JSON validity rate, exact match on dataset D, and p95 latency under streaming": "Hard", + "Filter a multilingual dataset D to non-English using fastText language ID; recreate stratified splits and report per-language retention and drop rates": "Medium", + "Enable paged attention in a custom Transformers generation loop for model M; verify token-level parity on 500 prompts and measure peak VRAM change": "Hard", + "Shard a 1B-token text corpus into deterministic HF Datasets processing across 16 workers; validate byte-for-byte identical outputs across two runs": "Very hard", + "Compare LoRA vs QLoRA fine-tunes of Mistral-7B on GSM8K; track loss, exact match, and throughput; select the lowest-VRAM config within 2% EM of best": "Hard", + "Deploy a quantized T5 encoder-decoder on Triton Inference Server via a Python backend; add token streaming and achieve >=1.5x throughput vs PyTorch baseline": "Hard", + "Find three Spaces that perform audio source separation (vocals/music); return Space ids and reported sample rates": "Easy", + "Merge a PEFT IA3 adapter stack into Llama-3-8B base weights; verify perplexity drift <=0.3% on WikiText-103 and export safetensors": "Hard", 
+ "Resolve DeepSpeed ZeRO-3 stalls during S3 checkpointing; implement async multipart uploads and show stable 5-minute checkpoint cadence over 2 hours": "Very hard", + "Set up CI to run contamination checks on dataset R against {TruthfulQA, SQuAD} using 4-gram overlap; fail if rate >0.5% and attach offending ids as artifacts": "Medium", + "List four Hub datasets for sarcasm detection in English; return dataset ids and license tags": "Easy", + "Identify whether tokenizer T enables byte_fallback in tokenizer.json; output true/false and the file path": "Easy", + "Find three Spaces that showcase streaming chat with token-by-token updates; return Space ids and whether they use SSE or websockets": "Easy", + "Create a Datasets loader that parses Praat TextGrid files into word-level timestamps aligned with audio; publish a dataset with an 'audio' column and validate 100 sample alignments": "Medium", + "Set up a GitHub Actions workflow that lints model cards for repos {R_i} to require intended use, training data, and limitations; fail PRs and post a summary comment on violations": "Medium", + "Containerize a Gradio Space with optional FlashAttention build: detect GPU capability at startup, compile kernels if supported, and fall back gracefully on unsupported GPUs; test on T4 and A100": "Medium", + "Evaluate long-context retrieval via needle-in-a-haystack for models {M_i} at context lengths {8k, 32k, 64k}; report retrieval accuracy, tokens/sec, and the max stable context length": "Hard", + "Implement a curriculum sampler as a HuggingFace Trainer callback that schedules sample difficulty over epochs; compare convergence and final eval metrics vs random sampling": "Hard", + "Add on-the-fly near-duplicate filtering during training using SimHash over token ids; log per-epoch removal rates and verify no convergence regressions vs a deduplicated baseline": "Hard", + "Deploy a dual-backend inference router using vLLM and TensorRT-LLM that selects backend per prompt length to minimize latency; maintain deterministic outputs across 3 seeds and sustain 300 QPS with p95 latency SLOs": "Very hard", + "Identify max_position_embeddings and whether rope_scaling is enabled for model M from its config; output both values.": "Easy", + "List five Vision Transformer models on the Hub that provide safetensors and have a default image size >= 384; output model ids.": "Easy", + "Find three Spaces that stream machine-translation outputs token-by-token; return Space ids and whether they use SSE or websockets.": "Easy", + "Diagnose bursts of [UNK] after adding special tokens to tokenizer T; enable byte_fallback, retrain embeddings for 2k steps, and show unknown-token rate <= baseline+0.1% on corpus C.": "Medium", + "Create a dataset viewer Space for a dataset with a nested JSON column; convert to Arrow struct arrays, implement server-side filtering on nested keys, and verify row counts match the source.": "Medium", + "Set up a GitHub Action that hits /health and a no-op inference on Space S after each deploy; fail if cold-start median latency >10s and attach server logs as an artifact.": "Medium", + "Implement a SQL grammar-constrained Transformers LogitsProcessor using an LL(1) parser; evaluate on Spider dev and report exact match and p95 latency overhead vs nucleus sampling.": "Hard", + "Add CPU-tier KV-cache offloading with pinned memory for model M in a custom generation loop; compare tokens/sec and peak VRAM vs baseline at context lengths {4k, 16k, 32k}.": "Hard", + "Deploy a batched cross-encoder reranker microservice 
using bge-reranker-base; keep recall@10 within 1% of single-request baseline and achieve >=2\u00d7 QPS at 100 concurrent users.": "Hard", + "Build a heterogeneous inference gateway that routes requests to vLLM or llama.cpp based on prompt length and GPU load; ensure identical normalized outputs across 3 seeds and sustain 200 QPS with p95 latency <300 ms.": "Very hard", + "Determine whether tokenizer T strips accents (strip_accents); output true/false and the file path where the setting is defined.": "Easy", + "List four Hub datasets for hate-speech detection in English; return dataset ids and license tags.": "Easy", + "Write a Datasets loader for a paginated OAuth2 REST API; cache pages, support streaming, and provide deterministic sharding across 8 workers; verify identical row counts across two runs.": "Medium", + "Add request-level caching (ETag/If-None-Match) to a Gradio summarization Space; achieve >=1.8\u00d7 QPS at 50 concurrent users and report cache hit ratio and p95 latency.": "Medium", + "Enable HuggingFace tokenizers parallelism and batched encoding for corpus C; benchmark throughput and memory on 10M lines and ensure deterministic outputs across 3 runs.": "Medium", + "Set up CI to lint dataset cards in repos {R_i} for required fields {license, citation, dataset_summary}; fail PRs and post a summary comment with missing keys.": "Medium", + "Run a parameter-efficient finetuning sweep comparing LoRA, IA3, and prefix-tuning on RoBERTa-base for MNLI; report accuracy, training time, and peak VRAM; select a Pareto-optimal config.": "Hard", + "Implement a Transformers LogitsProcessor that enforces balanced parentheses and proper quoted-string escaping; add unit tests and benchmark latency overhead on dataset D.": "Hard", + "Export Whisper-medium to ONNX with dynamic axes and int8 weights; verify word-timestamp parity on 500 clips and measure CPU real-time factor improvement >=1.3\u00d7 vs PyTorch.": "Hard", + "Deploy a geo-replicated RAG service: shard FAISS HNSW across three regions with conflict-free index metadata sync; sustain 300 QPS with p95 latency <450 ms and recall@10 within 1% of single-region baseline.": "Very hard", + "Compare cased vs uncased tokenization for BERT on CoNLL-2003 NER; train both, and report F1, average tokens per sentence, and training time.": "Medium", + "Create a HuggingFace Datasets loader for EPUB files: extract chapter text and embedded images into Arrow columns, support streaming and deterministic sharding across 8 workers; verify identical row counts across two runs.": "Medium", + "Configure a Hub webhook to trigger CI when a model card (README.md) changes; fail the job if sections {intended use, limitations} are missing and post a checklist comment on the PR.": "Medium", + "Add a reranking cache to a RAG service keyed by (query, candidate_ids); achieve >=50% cache hit at 100 QPS and keep recall@10 within 0.5% of baseline.": "Hard", + "Fix torch.compile graph breaks in a Transformers training loop; patch non-compilable ops, re-enable compilation, and demonstrate >=1.4\u00d7 step-time speedup with matching loss over 2,000 steps.": "Hard", + "Compute 95% bootstrap confidence intervals for ROUGE-L on dataset D over 3 random seeds; flag regressions when the new CI lies entirely below last week's baseline CI.": "Medium", + "Build a batch image-captioning Space with ViT-GPT2: accept ZIP uploads, use queue-based batching, and keep p95 latency <2s for 32 images.": "Medium", + "Implement hybrid parallelism (tensor + pipeline) for a 13B encoder-decoder using 
Accelerate; scale across 8 GPUs with <=15% gap from linear, support elastic resize (8->6 GPUs) without losing determinism, and verify checkpoint save/restore.": "Very hard", + "Find five Spaces that stream live vision-language captioning (e.g., LLaVA or BLIP); return Space ids and reported FPS.": "Easy", + "Identify whether tokenizer T applies Unicode normalization (NFKC/NFC/NFD/NFKD) and where it is configured; output the mode and file path.": "Easy", + "Identify whether model repo M stores weights exclusively as safetensors; output true/false and list the .safetensors file paths.": "Easy", + "List three multilingual sentence-embedding models on the Hub that provide ONNX exports; return model ids.": "Easy", + "Determine if tokenizer T lowercases text (do_lower_case or lowercase flag); output true/false and the file path or JSON key where it is set.": "Easy", + "Set up a GitHub Action to run a smoke-test text generation for model M on each push; fail if median time to first token >2s and attach container logs as an artifact.": "Medium", + "Create a Datasets preprocessing pipeline that tokenizes to max_length=512 with stride=64 and retains an 'orig_text' column; verify row counts match input and no NaNs after caching.": "Medium", + "Resolve 'git-lfs: command not found' when pushing model repo R to the Hub; install and configure Git LFS, set an appropriate large file threshold, and provide a minimal repro plus the verified fix.": "Medium", + "Enable KV-cache CPU offloading in a custom Transformers generation loop for model M; benchmark tokens/sec and peak VRAM vs baseline at context lengths {4k, 8k}.": "Hard", + "Implement LoRA rank warmup (r: 4\u219232 over the first 1,000 steps) in a custom Trainer; fine-tune model M on dataset D and report validation perplexity and peak VRAM vs fixed r=32.": "Hard", + "Export Whisper-small to TensorRT via ONNX (opset 18) with dynamic axes; verify word-timestamp parity (median diff \u22640.05s) on 300 clips and measure \u22651.3\u00d7 GPU speedup vs PyTorch.": "Hard", + "Deploy a multi-tenant RAG service that hot-loads per-tenant FAISS indices from S3, shares a reranker, and sustains 200 QPS with p95 latency <350 ms across 1,000 tenants; maintain recall@10 within 1% of a single-tenant baseline.": "Very hard" +} \ No newline at end of file diff --git a/eval/hf_agent_connector.py b/eval/hf_agent_connector.py new file mode 100644 index 0000000000000000000000000000000000000000..acf9ee39bb261e1071bbca81647335ff5c1ab2e9 --- /dev/null +++ b/eval/hf_agent_connector.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +import asyncio +import sys +from pathlib import Path +from typing import Any + +# Put the project root on sys.path before importing agent modules, so this +# connector also works when invoked from outside the repository root. +PROJECT_ROOT = Path(__file__).resolve().parents[1] +if str(PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(PROJECT_ROOT)) + +from agent.config import Config, load_config +from agent.core.agent_loop import Handlers +from agent.core.session import Session +from agent.core.tools import ToolRouter + + +def _resolve_project_path(path: str | Path) -> Path: + candidate = Path(path) + if candidate.is_absolute(): + return candidate + return (PROJECT_ROOT / candidate).resolve() + + +class AgentResponseGenerator: + """ + Thin async wrapper that executes the existing agent loop once and + returns the assistant's final message. 
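+ + A minimal usage sketch (the prompt is illustrative): + + >>> gen = AgentResponseGenerator("agent/config_mcp_example.json") + >>> reply = asyncio.run(gen.run("Find three ASR models on the Hub"))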
+ """ + + def __init__(self, config_path: str | Path, max_iterations: int = 300) -> None: + self.config_path = _resolve_project_path(config_path) + self.config: Config = load_config(str(self.config_path)) + self.max_iterations = max_iterations + + @property + def model_name(self) -> str: + """Expose the agent model name for downstream logging.""" + return self.config.model_name + + async def run(self, prompt: str) -> str: + """ + Execute the agent loop for a single prompt and return the assistant reply. + """ + tool_router = ToolRouter(self.config.mcpServers) + + async with tool_router: + session = Session(asyncio.Queue(), config=self.config) + session.tool_router = tool_router + await Handlers.run_agent( + session, + prompt, + max_iterations=self.max_iterations, + ) + return self._latest_assistant_response(session) + + def _latest_assistant_response(self, session: Session) -> str: + """ + Extract the final assistant response from the session history. + """ + for message in reversed(session.context_manager.items): + if getattr(message, "role", None) == "assistant": + return _content_to_text(getattr(message, "content", "")) + + raise RuntimeError("Agent did not produce an assistant message.") + + +def _content_to_text(content: Any) -> str: + """ + Convert LiteLLM content payloads (str or list[dict]) into plain text. + """ + if isinstance(content, str): + return content + + if isinstance(content, list): + parts: list[str] = [] + for block in content: + if isinstance(block, dict): + text = block.get("text") + if text: + parts.append(str(text)) + else: + text = getattr(block, "text", None) + if text: + parts.append(str(text)) + return "\n".join(parts) + + return str(content) diff --git a/eval/hf_io.py b/eval/hf_io.py new file mode 100644 index 0000000000000000000000000000000000000000..0f26899ce70ed5490757fca69a12b802bfa76b35 --- /dev/null +++ b/eval/hf_io.py @@ -0,0 +1,215 @@ +""" +HuggingFace Dataset I/O Utilities + +Reusable functions for uploading and downloading JSONL data to/from HuggingFace Hub. +Supports the dataset_name@config_name notation for managing multiple configurations. +""" + +from typing import List, Optional + +import pandas as pd +from datasets import Dataset, load_dataset + + +def list_dataset_configs(dataset_name: str) -> Optional[List[str]]: + """ + List all available configs for a dataset on HuggingFace Hub. + + Args: + dataset_name: Name of the dataset (e.g., "username/my-dataset") + + Returns: + List of config names, or None if unable to retrieve + + Example: + >>> configs = list_dataset_configs("username/hf-agent-benchmark") + >>> print(configs) + ['default', 'rubrics', 'evaluations'] + """ + try: + from datasets import get_dataset_config_names + + configs = get_dataset_config_names(dataset_name) + return configs + except Exception as e: + print(f"βœ— Failed to list configs: {type(e).__name__}: {str(e)}") + return None + + +def df_to_hub( + df: pd.DataFrame, + dataset_spec: str, + split: str = "train", + private: bool = False, +) -> bool: + """ + Upload a pandas DataFrame directly to HuggingFace Hub as a dataset. + + This function converts a pandas DataFrame to a HuggingFace Dataset and uploads + it to the Hub. This is useful for uploading data directly without creating an + intermediate JSONL file. + + Args: + df: pandas DataFrame to upload. All column types should be serializable. + Example DataFrame: + ``` + | question | solution | rubric | + |----------|----------|--------| + | "How..." | "You..." 
| {...} | + ``` + + dataset_spec: Dataset specification in the format "dataset_name" or + "dataset_name@config_name". Examples: + - "username/my-dataset" (uses "default" config) + - "username/my-dataset@rubrics" (uses "rubrics" config) + - "username/my-dataset@evaluations" (uses "evaluations" config) + + split: The dataset split name. Defaults to "train". Common values: + - "train": Training or main data + - "validation": Validation data + - "test": Test data + + private: Whether to create a private dataset. Defaults to False (public). + + Returns: + bool: True if upload succeeded, False otherwise + + Raises: + ValueError: If DataFrame is empty + Exception: For HuggingFace Hub upload errors + + Example: + >>> import pandas as pd + >>> df = pd.DataFrame({ + ... "question": ["How to train?", "What is fine-tuning?"], + ... "solution": ["Use trainer...", "Fine-tuning is..."], + ... "rubric": ['[{"title": "...", ...}]', '[{"title": "...", ...}]'] + ... }) + >>> df_to_hub(df, "username/dataset@rubrics") + + Notes: + - Requires authentication via `huggingface-cli login` or HF_TOKEN env var + - DataFrame columns with complex objects should be serialized first (e.g., to JSON strings) + - If the dataset doesn't exist, it will be created automatically + - Empty DataFrames will raise ValueError to prevent uploading invalid data + """ + # Validate DataFrame + if df.empty: + raise ValueError("DataFrame is empty") + + # Parse dataset specification + if "@" in dataset_spec: + dataset_name, config_name = dataset_spec.split("@", 1) + else: + dataset_name = dataset_spec + config_name = "default" + + try: + print("\nUploading DataFrame to HuggingFace Hub...") + print(f" Dataset: {dataset_name}") + print(f" Config: {config_name}") + print(f" Split: {split}") + print(f" Rows: {len(df)}") + print(f" Columns: {list(df.columns)}") + + # Convert DataFrame to HuggingFace Dataset + dataset = Dataset.from_pandas(df) + + # Upload to HuggingFace Hub + dataset.push_to_hub( + dataset_name, + config_name=config_name, + split=split, + private=private, + ) + + print( + f"✓ Successfully uploaded to {dataset_name}@{config_name} (split: {split})" + ) + return True + + except Exception as e: + print(f"✗ Failed to upload to HuggingFace: {type(e).__name__}: {str(e)}") + return False
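+ + +# Round-trip sketch (illustrative repo id; assumes Hub authentication is set up): +# df_to_hub(df, "username/my-dataset@rubrics") +# df2 = hub_to_df("username/my-dataset@rubrics") +# assert list(df2.columns) == list(df.columns)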
"username/hf-agent-benchmark@evaluations", + ... split="test" + ... ) + + Notes: + - Requires authentication for private datasets via `huggingface-cli login` + - Downloaded data will be in the same format as uploaded (preserves structure) + - Large datasets may take time to download and consume significant memory + - For very large datasets, consider using streaming or download_hf_to_jsonl + """ + # Parse dataset specification + if "@" in dataset_spec: + dataset_name, config_name = dataset_spec.split("@", 1) + else: + dataset_name = dataset_spec + config_name = "default" + + try: + print("\nDownloading from HuggingFace Hub...") + print(f" Dataset: {dataset_name}") + print(f" Config: {config_name}") + print(f" Split: {split}") + + # Download dataset from HuggingFace Hub + dataset = load_dataset( + dataset_name, + name=config_name, + split=split, + ) + + print(f" Downloaded {len(dataset)} records") + + # Convert to pandas DataFrame + df = dataset.to_pandas() + + print("βœ“ Successfully loaded as DataFrame") + print(f" Shape: {df.shape}") + print(f" Columns: {list(df.columns)}") + return df + + except Exception as e: + print(f"βœ— Failed to download from HuggingFace: {type(e).__name__}: {str(e)}") + return None diff --git a/eval/leaderboard.py b/eval/leaderboard.py new file mode 100644 index 0000000000000000000000000000000000000000..00444bc342df6b398365270961d9987e2aad981e --- /dev/null +++ b/eval/leaderboard.py @@ -0,0 +1,172 @@ +""" +Utilities for logging solver scores to a Hugging Face dataset. +""" + +from __future__ import annotations + +import json +import re +import shutil +import subprocess +import tempfile +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +from huggingface_hub import HfApi, hf_hub_download + +AVERAGE_RE = re.compile(r"Average normalized score:\s*([0-9.]+)") +DEFAULT_FILENAME = "records.jsonl" + + +def _hydra_join(*parts: str | None) -> str: + tokens = [str(part).strip().replace(" ", "_") for part in parts if part] + return "/".join(tokens) if tokens else "default" + + +def detect_agent_version(config_path: str = "agent/config_mcp_example.json") -> str: + """ + Returns a short string identifying the current agent version: + -. + """ + + try: + commit = ( + subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]) + .decode() + .strip() + ) + except Exception: + commit = "unknown" + + config_file = Path(config_path) + config_stem = config_file.stem or "config" + parent_name = config_file.parent.name if config_file.parent.name else None + return _hydra_join(parent_name, config_stem, commit) + + +def parse_average_score(text: str) -> float | None: + """Extracts the 'Average normalized score' value from Inspect logs.""" + + match = AVERAGE_RE.search(text) + if match: + try: + return float(match.group(1)) + except ValueError: + return None + return None + + +def latest_log_file( + log_dir: Path, extensions: tuple[str, ...] 
= (".eval", ".json") +) -> Path | None: + """Returns the most recent log file in log_dir matching the provided extensions.""" + + if not log_dir.exists(): + return None + + files: list[Path] = [] + for ext in extensions: + files.extend(log_dir.glob(f"*{ext}")) + + if not files: + return None + + files.sort(key=lambda path: path.stat().st_mtime) + return files[-1] + + +@dataclass +class LeaderboardClient: + """Simple helper to append JSONL rows to a HF dataset.""" + + repo_id: str + token: str + filename: str = DEFAULT_FILENAME + + def append_record(self, record: dict[str, Any]) -> None: + tmp_dir = Path(tempfile.mkdtemp(prefix="leaderboard_")) + local_file = tmp_dir / self.filename + + self._download_existing(local_file) + if not local_file.exists(): + local_file.write_text("", encoding="utf-8") + + with local_file.open("a", encoding="utf-8") as fh: + fh.write(json.dumps(record) + "\n") + + HfApi(token=self.token).upload_file( + path_or_fileobj=str(local_file), + path_in_repo=self.filename, + repo_id=self.repo_id, + repo_type="dataset", + ) + + try: + local_file.unlink() + tmp_dir.rmdir() + except OSError: + pass + + def _download_existing(self, destination: Path) -> None: + destination.parent.mkdir(parents=True, exist_ok=True) + + try: + downloaded = hf_hub_download( + repo_id=self.repo_id, + filename=self.filename, + repo_type="dataset", + token=self.token, + ) + shutil.copy(Path(downloaded), destination) + except Exception: + destination.write_text("", encoding="utf-8") + + +def build_record( + solver_name: str, + solver_kwargs: dict[str, Any], + dataset_name: str, + dataset_split: str, + limit: int | None, + score: float, + command: list[str], + log_path: Path | None, + criterion_checks: list[dict[str, Any]] | None = None, +) -> dict[str, Any]: + """Assembles a JSON-serialisable record for the leaderboard dataset.""" + + record = { + "timestamp": datetime.now(timezone.utc).isoformat(), + "solver": solver_name, + "solver_kwargs": solver_kwargs, + "dataset_name": dataset_name, + "dataset_split": dataset_split, + "limit": limit, + "score": score, + "command": command, + } + + if solver_name == "hf_agent": + record["solver_version"] = detect_agent_version( + solver_kwargs.get("config_path", "agent/config_mcp_example.json") + ) + else: + version_spec = solver_kwargs.get("version") + if isinstance(version_spec, (list, tuple)): + record["solver_version"] = _hydra_join(*version_spec) + elif isinstance(version_spec, dict): + record["solver_version"] = _hydra_join( + *[f"{k}={v}" for k, v in version_spec.items()] + ) + elif isinstance(version_spec, str): + record["solver_version"] = version_spec + else: + record["solver_version"] = _hydra_join(solver_name, "default") + + if log_path: + record["log_artifact"] = str(log_path) + record["criterion_checks"] = criterion_checks or [] + + return record diff --git a/eval/models.py b/eval/models.py new file mode 100644 index 0000000000000000000000000000000000000000..d58f7f2478ebba8cb6cdbfb48cd2700f2322f77e --- /dev/null +++ b/eval/models.py @@ -0,0 +1,63 @@ +"""Shared data models for the HF agent project""" + +from datetime import datetime +from enum import Enum + +from pydantic import BaseModel, Field + + +class Discussion(BaseModel): + """Model for a discussion thread""" + + title: str + url: str + topic_id: int + category: int + created_at: datetime + + +class QuestionAndSolution(BaseModel): + """Model for a QA pair from a discussion""" + + discussion_title: str + discussion_url: str + discussion_topic_id: int + discussion_category: int + 
discussion_created_at: datetime + thread: list[dict] + question: str + solution: str + + +class Correctness(str, Enum): + yes = "yes" + no = "no" + + +class JudgementResult(BaseModel): + """Structured output for LLM judge evaluation""" + + extracted_final_answer: str = Field( + description="The final exact/snippet answer extracted from the response" + ) + reasoning: str = Field( + description="Explanation of why the answer is correct or incorrect" + ) + correct: Correctness = Field(description="'yes' if correct, 'no' if incorrect") + confidence: int = Field( + description="Confidence score between 0 and 100", ge=0, le=100 + ) + + +class EvaluationResult(BaseModel): + """Model for evaluation results including metadata""" + + success: bool + judgement: JudgementResult | None = None + error: str | None = None + + +class EvaluatedQuestionAndSolution(QuestionAndSolution): + """Model for a QA pair with its evaluation result""" + + evaluation: JudgementResult diff --git a/eval/rubric_eval.py b/eval/rubric_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..ce706de91c531f862d50af8f39cfd9a4f32e01da --- /dev/null +++ b/eval/rubric_eval.py @@ -0,0 +1,142 @@ +""" +Rubric-based evaluation following the "Rubrics as Rewards" paper. + +Implements RaR-Explicit: Weighted sum of individual criterion scores (Equation 1) +""" + +from typing import List, Optional + +import litellm +from pydantic import BaseModel + + +class CriterionCheck(BaseModel): + """Result of checking a single rubric criterion.""" + + title: str + description: str + weight: int + satisfied: bool + reasoning: Optional[str] = None + + +class RubricEvaluation(BaseModel): + """Complete rubric-based evaluation result.""" + + criterion_checks: List[CriterionCheck] + raw_score: float # Unnormalized score + normalized_score: float # Score normalized to [0, 1] + + +CRITERION_PROMPT = """You are evaluating whether a response satisfies a specific evaluation criterion. + +Question: {question} + +Response to evaluate: {response} + +Evaluation Criterion: +{criterion_description} + +Your task: Determine if the response satisfies this criterion. + +Output a JSON object with: +- "satisfied": true or false +- "reasoning": Brief explanation (1-2 sentences) of why it does or doesn't satisfy the criterion + +Be strict but fair. The criterion must be clearly satisfied for you to answer true.""" + + +class RubricData(BaseModel): + """Rubric data loaded from file.""" + + title: str + description: str + weight: int + + +def check_criterion( + question: str, response: str, criterion: RubricData, model: str = "gpt-4o-mini" +) -> CriterionCheck: + """ + Check if response satisfies a single criterion. 
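+ + Each criterion is judged independently; evaluate_with_rubrics then combines + the checks as a weighted sum, e.g. weights {2, 1} with only the first + criterion satisfied give a normalized score of 2/3. A hypothetical call: + + >>> crit = RubricData(title="Mentions vLLM", description="The response recommends vLLM", weight=2) + >>> check_criterion("How do I serve Llama quickly?", "Use vLLM.", crit).satisfied # LLM-judged; output may vary +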
+ Args: + question: The question being answered + response: The response to evaluate + criterion: The rubric criterion to check + model: LLM model for judging + + Returns: + CriterionCheck with satisfaction result + """ + prompt = CRITERION_PROMPT.format( + question=question, + response=response, + criterion_description=criterion.description, + ) + + llm_response = litellm.completion( + model=model, + messages=[ + { + "role": "system", + "content": "You are an expert evaluator for rubric-based assessment.", + }, + {"role": "user", "content": prompt}, + ], + temperature=0.0, + response_format=CriterionCheck, + ) + + result = CriterionCheck.model_validate_json(llm_response.choices[0].message.content) + + return result + + +def evaluate_with_rubrics( + question: str, + response: str, + rubrics: List[RubricData], + model: str = "gpt-5-nano", +) -> RubricEvaluation: + """ + Evaluate response using RaR-Explicit method (weighted sum). + + Implements Equation 1 from the paper: + r(x, ŷ) = Σ(w_j * c_j(x, ŷ)) / Σ(w_j) + + Args: + question: The question + response: Response to evaluate + rubrics: List of rubric criteria + model: LLM model for judging + + Returns: + RubricEvaluation with normalized score + """ + # Check each criterion independently + checks = [] + for rubric in rubrics: + check = check_criterion(question, response, rubric, model) + checks.append(check) + + # Calculate weighted score (Equation 1) + # Only positive weights contribute to the denominator + positive_weights = sum(abs(r.weight) for r in rubrics if r.weight > 0) + + raw_score = 0.0 + for check in checks: + if check.satisfied: + raw_score += check.weight + + # Normalize to [0, 1] + normalized_score = raw_score / positive_weights if positive_weights > 0 else 0.0 + # Clip to [0, 1] in case negatively weighted criteria (pitfalls) push it below zero + normalized_score = max(0.0, min(1.0, normalized_score)) + + return RubricEvaluation( + raw_score=raw_score, + normalized_score=normalized_score, + criterion_checks=checks, + ) diff --git a/eval/run_eval_with_leaderboard.py b/eval/run_eval_with_leaderboard.py new file mode 100644 index 0000000000000000000000000000000000000000..031246179bf2f2eb8ae1b84b64985e82fd01bebe --- /dev/null +++ b/eval/run_eval_with_leaderboard.py @@ -0,0 +1,215 @@ +from __future__ import annotations + +import argparse +import json +import os +import re +import subprocess +import sys +from pathlib import Path +from typing import Any + +from dotenv import load_dotenv +from leaderboard import LeaderboardClient, build_record, latest_log_file + +load_dotenv() + + +def run_command(cmd: list[str]) -> subprocess.CompletedProcess[str]: + print(f"[leaderboard] running: {' '.join(cmd)}") + return subprocess.run(cmd, capture_output=True, text=True) + + +def build_inspect_command(args: argparse.Namespace) -> list[str]: + cmd = [] + cmd.extend(args.inspect_launch) + cmd.append(args.inspect_task) + + def add_task_arg(key: str, value: Any) -> None: + if value is None: + return + cmd.extend(["-T", f"{key}={value}"]) + + add_task_arg("solver_name", args.solver_name) + add_task_arg("solver_kwargs", json.dumps(args.solver_kwargs)) + add_task_arg("dataset_name", args.dataset) + if args.limit is not None: + add_task_arg("limit", args.limit) + + cmd.extend(["--log-dir", args.log_dir]) + if args.log_format: + cmd.extend(["--log-format", args.log_format]) + + if args.extra_inspect_args: + cmd.extend(args.extra_inspect_args) + + return cmd + + +def parse_score_from_outputs(log_dir: Path) -> 
tuple[float, Path, list[dict[str, Any]]]: + log_path = latest_log_file(log_dir) + if not log_path: + raise RuntimeError("Inspect log file not found.") + + # Sanitization + content = log_path.read_text(encoding="utf-8") + # Regex to match hf_ followed by 34 alphanumeric chars + sanitized_content = re.sub(r"hf_[a-zA-Z0-9]{34}", "", content) + + if content != sanitized_content: + log_path.write_text(sanitized_content, encoding="utf-8") + print(f"[leaderboard] Redacted HF tokens in {log_path}") + content = sanitized_content + + data = json.loads(content) + results = data.get("results", {}) + scores = results.get("scores", []) + score_value = None + criterion_checks: list[dict[str, Any]] = [] + + for score_entry in scores: + metrics = score_entry.get("metrics", {}) + for metric in metrics.values(): + value = metric.get("value") + if isinstance(value, (int, float)): + score_value = float(value) + break + if score_value is not None: + break + + if score_value is None: + raise RuntimeError("Could not find a numeric metric value in the Inspect log.") + + for sample in data.get("samples", []): + # Grab the question from metadata (fallback to input) + question = "Unknown Question" + if "metadata" in sample and "question" in sample["metadata"]: + question = sample["metadata"]["question"] + elif "input" in sample: + question = sample["input"] + + # Check if any scorer produced criterion_checks + for scorer in sample.get("scores", {}).values(): + metadata = scorer.get("metadata") or {} + checks = metadata.get("criterion_checks") + + if isinstance(checks, list) and checks: + # Create a grouped entry for this question/sample + grouped_entry = {"question": question, "checks": []} + for check in checks: + if isinstance(check, dict): + grouped_entry["checks"].append(check) + + if grouped_entry["checks"]: + criterion_checks.append(grouped_entry) + + return score_value, log_path, criterion_checks + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Run Inspect eval and append the resulting score to a HF dataset." + ) + parser.add_argument( + "--hf-dataset", + default="akseljoonas/hf-agent-leaderboard", + help="HF dataset repo id for the leaderboard (e.g. user/leaderboard).", + ) + + parser.add_argument( + "--solver-name", + required=True, + help="Solver name used in the Inspect task (e.g. 
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(
+        description="Run Inspect eval and append the resulting score to a HF dataset."
+    )
+    parser.add_argument(
+        "--hf-dataset",
+        default="akseljoonas/hf-agent-leaderboard",
+        help="HF dataset repo id for the leaderboard (e.g. user/leaderboard).",
+    )
+    parser.add_argument(
+        "--solver-name",
+        required=True,
+        help="Solver name used in the Inspect task (e.g. hf_agent).",
+    )
+    parser.add_argument(
+        "--solver-kwargs",
+        type=json.loads,
+        default="{}",
+        help="JSON string with solver kwargs passed to the Inspect task.",
+    )
+    parser.add_argument(
+        "--dataset",
+        default="akseljoonas/hf-agent-rubrics@train",
+        help="Dataset spec in the form author/dataset@split.",
+    )
+    parser.add_argument(
+        "--limit",
+        type=int,
+        default=None,
+        help="Optional sample limit passed to Inspect.",
+    )
+    parser.add_argument(
+        "--inspect-task",
+        default="eval/task.py@hf-benchmark-with-rubrics",
+        help="Inspect task reference.",
+    )
+    parser.add_argument(
+        "--inspect-launch",
+        nargs="+",
+        default=["uv", "run", "inspect", "eval"],
+        help="Command used to invoke Inspect (default: uv run inspect eval).",
+    )
+    parser.add_argument(
+        "--log-dir",
+        default="logs/leaderboard",
+        help="Directory where Inspect outputs .eval logs.",
+    )
+    parser.add_argument(
+        "--extra-inspect-args",
+        nargs="*",
+        help="Additional args forwarded to Inspect after the standard task arguments.",
+    )
+    parser.add_argument(
+        "--log-format",
+        default="json",
+        help="Log format passed to Inspect (default: json).",
+    )
+
+    args = parser.parse_args()
+
+    # argparse applies type= only to values given on the command line, so the
+    # default "{}" still arrives as a string and is parsed here.
+    if isinstance(args.solver_kwargs, str):
+        args.solver_kwargs = json.loads(args.solver_kwargs or "{}")
+
+    hf_token = os.getenv("HF_TOKEN")
+    if not hf_token:
+        print("ERROR: set HF_TOKEN in your environment.", file=sys.stderr)
+        sys.exit(1)
+
+    if "@" not in args.dataset:
+        raise ValueError("Dataset must be in the format 'author/dataset@split'.")
+    dataset_name, dataset_split = args.dataset.split("@", 1)
+
+    log_dir = Path(args.log_dir)
+    log_dir.mkdir(parents=True, exist_ok=True)
+
+    inspect_cmd = build_inspect_command(args)
+    result = run_command(inspect_cmd)
+
+    if result.returncode != 0:
+        print(result.stdout)
+        print(result.stderr, file=sys.stderr)
+        raise SystemExit(result.returncode)
+
+    score, log_path, criterion_checks = parse_score_from_outputs(log_dir)
+
+    client = LeaderboardClient(repo_id=args.hf_dataset, token=hf_token)
+    record = build_record(
+        solver_name=args.solver_name,
+        solver_kwargs=args.solver_kwargs,
+        dataset_name=dataset_name,
+        dataset_split=dataset_split,
+        limit=args.limit,
+        score=score,
+        command=inspect_cmd,
+        log_path=log_path,
+        criterion_checks=criterion_checks,
+    )
+    client.append_record(record)
+
+    print(
+        f"[leaderboard] recorded score {score:.3f} for solver '{args.solver_name}' to {args.hf_dataset}"
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/eval/scrape_discussions/discussions_scraper.py b/eval/scrape_discussions/discussions_scraper.py
new file mode 100644
index 0000000000000000000000000000000000000000..506cf97302e033e0e63fa3a125c0d6ba0b438f6b
--- /dev/null
+++ b/eval/scrape_discussions/discussions_scraper.py
@@ -0,0 +1,98 @@
+import sys
+import time
+from pathlib import Path
+
+import requests
+from tenacity import (
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+)
+
+# Add the parent directory to the path so `models` can be imported
+sys.path.insert(0, str(Path(__file__).parent.parent))
+from models import Discussion, QuestionAndSolution
+
+BASE_URL = "https://discuss.huggingface.co"
+
+
+# Retry transient HTTP errors with exponential backoff.
+@retry(
+    stop=stop_after_attempt(5),
+    wait=wait_exponential(multiplier=1, min=1, max=60),
+    retry=retry_if_exception_type(requests.HTTPError),
+)
+def safe_get(url, **kwargs):
+    resp = requests.get(url, **kwargs)
+    if resp.status_code == 429:
+        # Honor the Retry-After header if present
+        retry_after = resp.headers.get("Retry-After")
+        if retry_after:
+            delay = float(retry_after)
+        else:
+            # Fall back to a guess
+            delay = 30
+        print(f"429 hit - waiting {delay} seconds...")
+        time.sleep(delay)
+        resp.raise_for_status()
+    else:
+        resp.raise_for_status()
+    return resp
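+
+
+# The Discourse search endpoint used below returns JSON whose "topics" entries
+# carry the fields read here; an illustrative (trimmed) response:
+#   {"topics": [{"id": 123, "slug": "my-topic", "fancy_title": "My topic",
+#                "category_id": 5, "created_at": "2024-01-01T00:00:00.000Z"}]}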
+
+
+def get_solved_discussions(n_posts: int = 50):
+    page = 1
+    discussions = []
+    while len(discussions) < n_posts:
+        url = f"{BASE_URL}/search.json?q=status:solved+order:latest&page={page}"
+        resp = safe_get(url)
+        topics = resp.json()["topics"]
+        if not topics:
+            break
+        for post in topics:
+            discussions.append(
+                Discussion(
+                    title=post["fancy_title"],
+                    url=f"{BASE_URL}/t/{post['slug']}/{post['id']}",
+                    topic_id=post["id"],
+                    category=post["category_id"],
+                    created_at=post["created_at"],
+                )
+            )
+            if len(discussions) >= n_posts:
+                break
+        page += 1
+        time.sleep(0.5)  # simple pacing to avoid bursts
+    return discussions
+
+
+def get_qa_pair(discussions, start_idx: int = 0):
+    for discussion in discussions[start_idx:]:
+        resp = safe_get(discussion.url + ".json")
+        data = resp.json()
+        posts = data["post_stream"]["posts"]
+        # Clamp the accepted answer's 1-based post_number to a valid list index
+        accepted_nr = min(
+            max(data["accepted_answer"]["post_number"] - 1, 0), len(posts) - 1
+        )
+        question = posts[0]["cooked"]
+        solution = posts[accepted_nr]["cooked"]
+        yield QuestionAndSolution(
+            discussion_title=discussion.title,
+            discussion_url=discussion.url,
+            discussion_topic_id=discussion.topic_id,
+            discussion_category=discussion.category,
+            discussion_created_at=discussion.created_at,
+            question=question,
+            solution=solution,
+            thread=posts,
+        )
+        time.sleep(0.5)
+
+
+if __name__ == "__main__":
+    discussions = get_solved_discussions(n_posts=300)
+    print(f"Fetched {len(discussions)} discussions")
+    with open("qa_pairs.jsonl", "a") as f:
+        for qa_pair in get_qa_pair(discussions):
+            f.write(qa_pair.model_dump_json() + "\n")
diff --git a/eval/solvers.py b/eval/solvers.py
new file mode 100644
index 0000000000000000000000000000000000000000..43cb44d8b2fbe87785c7597700578bda4ae5ed59
--- /dev/null
+++ b/eval/solvers.py
@@ -0,0 +1,165 @@
+"""
+Collection of Inspect AI solvers used by the rubric task.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import json
+import os
+import tempfile
+from typing import Callable, Dict, List, Sequence
+
+from inspect_ai.model import ChatMessageAssistant, ModelOutput
+from inspect_ai.solver import Solver, solver
+from inspect_ai.solver._task_state import TaskState
+
+from eval.hf_agent_connector import AgentResponseGenerator
+
+
+async def _run_subprocess(command: Sequence[str]) -> str:
+    process = await asyncio.create_subprocess_exec(
+        *command,
+        stdout=asyncio.subprocess.PIPE,
+        stderr=asyncio.subprocess.PIPE,
+    )
+    stdout, stderr = await process.communicate()
+    if process.returncode != 0:
+        raise RuntimeError(
+            f"Command {' '.join(command)} failed with code {process.returncode}:\n"
+            f"{stderr.decode().strip()}"
+        )
+    return stdout.decode().strip()
+
+
+@solver(name="hf_agent")
+def hf_agent(
+    config_path: str = "agent/config_mcp_example.json",
+    max_iterations: int = 10,
+) -> Solver:
+    runner = AgentResponseGenerator(
+        config_path=config_path,
+        max_iterations=max_iterations,
+    )
+
+    async def solve(state: TaskState, generate) -> TaskState:
+        response = await runner.run(state.input_text)
+        assistant_message = ChatMessageAssistant(
+            content=response,
+            model=runner.model_name,
+            source="generate",
+        )
+        state.messages.append(assistant_message)
+        state.output = ModelOutput.from_message(assistant_message)
+        state.completed = True
+        return state
+
+    return solve
+
+
+@solver(name="claude_code")
+def claude_code(
+    output_format: str = "json",
+    mcp_config: str | None = None,
+) -> Solver:
+    if output_format not in {"text", "json", "stream-json"}:
+        raise ValueError("output_format must be one of: text, json, stream-json")
+
+    async def solve(state: TaskState, generate) -> TaskState:
+        prompt = state.input_text
+
+        cmd: List[str] = ["claude", "-p", prompt, "--output-format", output_format]
+        if mcp_config:
+            cmd += ["--mcp-config", mcp_config]
+
+        stdout = await _run_subprocess(cmd)
+        response_text = stdout
+        session_id = None
+
+        if output_format in {"json", "stream-json"}:
+            # stream-json may emit multiple JSON objects; take the last complete line
+            lines = stdout.strip().splitlines()
+            candidate_line = lines[-1] if lines else ""
+            try:
+                payload = json.loads(candidate_line)
+                response_text = (
+                    payload.get("result") or payload.get("message", "") or stdout
+                )
+                session_id = payload.get("session_id")
+            except (json.JSONDecodeError, AttributeError):
+                response_text = stdout
+
+        assistant_message = ChatMessageAssistant(
+            content=response_text,
+            model="claude-code",
+            source="generate",
+            metadata={"session_id": session_id} if session_id else None,
+        )
+        state.messages.append(assistant_message)
+        state.output = ModelOutput.from_message(assistant_message)
+        state.completed = True
+        return state
+
+    return solve
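+
+
+# With --output-format json, the last line of `claude -p ...` output is
+# expected to be a JSON object roughly like (illustrative):
+#   {"result": "final answer text", "session_id": "abc123"}
+# claude_code_hf_mcp below reuses the claude_code solver with an MCP config
+# pointing at the Hugging Face MCP server.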
+ ) + + # Construct the MCP configuration for Hugging Face + mcp_config = { + "mcpServers": { + "huggingface": { + "type": "http", + "url": "https://huggingface.co/mcp", + "headers": {"Authorization": f"Bearer {token}"}, + } + } + } + + async def solve(state: TaskState, generate) -> TaskState: + # Write config to a temporary file + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmp: + json.dump(mcp_config, tmp, indent=2) + tmp_path = tmp.name + + try: + # Delegate to the base claude_code solver + delegate = claude_code(output_format=output_format, mcp_config=tmp_path) + return await delegate(state, generate) + finally: + # Clean up the temporary file + if os.path.exists(tmp_path): + os.remove(tmp_path) + + return solve + + +SOLVER_REGISTRY: Dict[str, Callable[..., Solver]] = { + "hf_agent": hf_agent, + "claude_code": claude_code, + "claude_code+hf_mcp": claude_code_hf_mcp, +} + + +def get_solver(name: str, **kwargs) -> Solver: + try: + factory = SOLVER_REGISTRY[name] + except KeyError as exc: + available = ", ".join(sorted(SOLVER_REGISTRY)) + raise ValueError(f"Unknown solver '{name}'. Available: {available}") from exc + + return factory(**kwargs) diff --git a/eval/task.py b/eval/task.py new file mode 100644 index 0000000000000000000000000000000000000000..134e2119d597ceb9c97666bae161549f766cd4cd --- /dev/null +++ b/eval/task.py @@ -0,0 +1,121 @@ +""" +Inspect AI task definition that runs the existing agent and reuses the rubric scorer. +""" + +from __future__ import annotations + +import asyncio +import json +import sys +from pathlib import Path +from typing import Any, Sequence + +from inspect_ai import Task, task +from inspect_ai.dataset import Sample, hf_dataset +from inspect_ai.scorer import Score, Target, mean, scorer +from inspect_ai.solver._task_state import TaskState +import litellm + +PROJECT_ROOT = Path(__file__).resolve().parents[1] +if str(PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(PROJECT_ROOT)) + +from eval.rubric_eval import RubricData, evaluate_with_rubrics # noqa: E402 +from eval.solvers import get_solver # noqa: E402 + + +def _record_to_sample(record: dict[str, Any]) -> Sample: + rubric_payload = json.loads(record["rubric"]) + rubrics = rubric_payload.get("rubrics", []) + + metadata = { + "question": record["question"], + "discussion_title": record.get("discussion_title"), + "discussion_url": record.get("discussion_url"), + "rubric_title": rubric_payload.get("title"), + "rubric_description": rubric_payload.get("description"), + "rubrics": rubrics, + } + + return Sample( + input=record["question"], + target=record["solution"], + id=record.get("discussion_topic_id"), + metadata=metadata, + ) + + +def _load_dataset(dataset_name: str, split: str, limit: int | None) -> Sequence[Sample]: + return hf_dataset( + dataset_name, sample_fields=_record_to_sample, split=split, limit=limit + ) + + +def _metadata_to_rubrics(metadata: dict[str, Any]) -> list[RubricData]: + raw_rubrics = metadata.get("rubrics", []) + return [RubricData(**rubric) for rubric in raw_rubrics] + + +@scorer(metrics=[mean()], name="rubric_scorer") +def rubric_scorer(judge_model: str = "gpt-5-mini"): + async def score(state: TaskState, target: Target) -> Score: + response_text = state.output.completion or state.output.message.text + question = state.metadata.get("question", state.input_text) + rubrics = _metadata_to_rubrics(state.metadata) + + evaluation = await asyncio.to_thread( + evaluate_with_rubrics, + question, + response_text, + rubrics, + judge_model, + ) + + 
+
+
+@scorer(metrics=[mean()], name="rubric_scorer")
+def rubric_scorer(judge_model: str = "gpt-5-mini"):
+    async def score(state: TaskState, target: Target) -> Score:
+        response_text = state.output.completion or state.output.message.text
+        question = state.metadata.get("question", state.input_text)
+        rubrics = _metadata_to_rubrics(state.metadata)
+
+        # evaluate_with_rubrics makes blocking litellm calls, so run it off
+        # the event loop.
+        evaluation = await asyncio.to_thread(
+            evaluate_with_rubrics,
+            question,
+            response_text,
+            rubrics,
+            judge_model,
+        )
+
+        score_metadata = {
+            "raw_score": evaluation.raw_score,
+            "criterion_checks": [
+                check.model_dump() for check in evaluation.criterion_checks
+            ],
+            "discussion_title": state.metadata.get("discussion_title"),
+            "discussion_url": state.metadata.get("discussion_url"),
+            "reference_answer": target.text,
+        }
+
+        return Score(
+            value=evaluation.normalized_score,
+            answer=response_text,
+            explanation=f"Normalized score {evaluation.normalized_score:.3f}",
+            metadata=score_metadata,
+        )
+
+    return score
+
+
+@task(name="hf-benchmark-with-rubrics")
+def hf_benchmark_with_rubrics(
+    solver_name: str = "hf_agent",
+    solver_kwargs: dict[str, Any] | None = None,
+    dataset_name: str = "akseljoonas/hf-agent-rubrics@train",
+    limit: int | None = None,
+    judge_model: str = "gpt-5-mini",
+) -> Task:
+    # Avoid a shared mutable default for solver_kwargs.
+    if solver_kwargs is None:
+        solver_kwargs = {
+            "max_iterations": 10,
+            "config_path": "agent/config_mcp_example.json",
+        }
+    litellm.drop_params = True
+    if "@" not in dataset_name:
+        raise ValueError("Dataset name must be in the format 'author/dataset@split'")
+    dataset_name, dataset_split = dataset_name.split("@", 1)
+    dataset = _load_dataset(dataset_name, dataset_split, limit=limit)
+
+    return Task(
+        dataset=dataset,
+        solver=get_solver(solver_name, **solver_kwargs),
+        scorer=rubric_scorer(judge_model=judge_model),
+        metadata={
+            "dataset_name": dataset_name,
+            "dataset_split": dataset_split,
+            "solver_name": solver_name,
+            "judge_model": judge_model,
+        },
+    )
diff --git a/frontend/index.html b/frontend/index.html
index e5acd2f134c91f0715d7b259ba2eb842a655757b..6d7b512a1076c222ddf20efbf45893c43f5b33d5 100644
--- a/frontend/index.html
+++ b/frontend/index.html
@@ -2,9 +2,9 @@
-    <title>ML Intern</title>
+    <title>HF Agent</title>
diff --git a/frontend/package-lock.json b/frontend/package-lock.json
index 0d7602b5b273bc63a4864ce026a37e98c28d10e4..714022bf0e85692e2e8621dfbe1eec6396c13887 100644
--- a/frontend/package-lock.json
+++ b/frontend/package-lock.json
@@ -1437,9 +1437,9 @@
       "license": "MIT"
     },
     "node_modules/@rollup/rollup-android-arm-eabi": {
-      "version": "4.60.1",
-      "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.60.1.tgz",
-      "integrity": "sha512-d6FinEBLdIiK+1uACUttJKfgZREXrF0Qc2SmLII7W2AD8FfiZ9Wjd+rD/iRuf5s5dWrr1GgwXCvPqOuDquOowA==",
+      "version": "4.55.1",
+      "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.55.1.tgz",
+      "integrity": "sha512-9R0DM/ykwfGIlNu6+2U09ga0WXeZ9MRC2Ter8jnz8415VbuIykVuc6bhdrbORFZANDmTDvq26mJrEVTl8TdnDg==",
       "cpu": [
         "arm"
       ],
@@ -1451,9 +1451,9 @@
       ]
     },
     "node_modules/@rollup/rollup-android-arm64": {
-      "version": "4.60.1",
-      "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm64/-/rollup-android-arm64-4.60.1.tgz",
-      "integrity": "sha512-YjG/EwIDvvYI1YvYbHvDz/BYHtkY4ygUIXHnTdLhG+hKIQFBiosfWiACWortsKPKU/+dUwQQCKQM3qrDe8c9BA==",
+      "version": "4.55.1",
+      "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm64/-/rollup-android-arm64-4.55.1.tgz",
+      "integrity": "sha512-eFZCb1YUqhTysgW3sj/55du5cG57S7UTNtdMjCW7LwVcj3dTTcowCsC8p7uBdzKsZYa8J7IDE8lhMI+HX1vQvg==",
       "cpu": [
         "arm64"
       ],
@@ -1465,9 +1465,9 @@
       ]
     },
     "node_modules/@rollup/rollup-darwin-arm64": {
-      "version": "4.60.1",
-      "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.60.1.tgz",
-      "integrity": "sha512-mjCpF7GmkRtSJwon+Rq1N8+pI+8l7w5g9Z3vWj4T7abguC4Czwi3Yu/pFaLvA3TTeMVjnu3ctigusqWUfjZzvw==",
+      "version": "4.55.1",
+      "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.55.1.tgz",
+      "integrity":
"sha512-p3grE2PHcQm2e8PSGZdzIhCKbMCw/xi9XvMPErPhwO17vxtvCN5FEA2mSLgmKlCjHGMQTP6phuQTYWUnKewwGg==", "cpu": [ "arm64" ], @@ -1479,9 +1479,9 @@ ] }, "node_modules/@rollup/rollup-darwin-x64": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-x64/-/rollup-darwin-x64-4.60.1.tgz", - "integrity": "sha512-haZ7hJ1JT4e9hqkoT9R/19XW2QKqjfJVv+i5AGg57S+nLk9lQnJ1F/eZloRO3o9Scy9CM3wQ9l+dkXtcBgN5Ew==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-x64/-/rollup-darwin-x64-4.55.1.tgz", + "integrity": "sha512-rDUjG25C9qoTm+e02Esi+aqTKSBYwVTaoS1wxcN47/Luqef57Vgp96xNANwt5npq9GDxsH7kXxNkJVEsWEOEaQ==", "cpu": [ "x64" ], @@ -1493,9 +1493,9 @@ ] }, "node_modules/@rollup/rollup-freebsd-arm64": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-arm64/-/rollup-freebsd-arm64-4.60.1.tgz", - "integrity": "sha512-czw90wpQq3ZsAVBlinZjAYTKduOjTywlG7fEeWKUA7oCmpA8xdTkxZZlwNJKWqILlq0wehoZcJYfBvOyhPTQ6w==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-arm64/-/rollup-freebsd-arm64-4.55.1.tgz", + "integrity": "sha512-+JiU7Jbp5cdxekIgdte0jfcu5oqw4GCKr6i3PJTlXTCU5H5Fvtkpbs4XJHRmWNXF+hKmn4v7ogI5OQPaupJgOg==", "cpu": [ "arm64" ], @@ -1507,9 +1507,9 @@ ] }, "node_modules/@rollup/rollup-freebsd-x64": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-x64/-/rollup-freebsd-x64-4.60.1.tgz", - "integrity": "sha512-KVB2rqsxTHuBtfOeySEyzEOB7ltlB/ux38iu2rBQzkjbwRVlkhAGIEDiiYnO2kFOkJp+Z7pUXKyrRRFuFUKt+g==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-x64/-/rollup-freebsd-x64-4.55.1.tgz", + "integrity": "sha512-V5xC1tOVWtLLmr3YUk2f6EJK4qksksOYiz/TCsFHu/R+woubcLWdC9nZQmwjOAbmExBIVKsm1/wKmEy4z4u4Bw==", "cpu": [ "x64" ], @@ -1521,16 +1521,13 @@ ] }, "node_modules/@rollup/rollup-linux-arm-gnueabihf": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-gnueabihf/-/rollup-linux-arm-gnueabihf-4.60.1.tgz", - "integrity": "sha512-L+34Qqil+v5uC0zEubW7uByo78WOCIrBvci69E7sFASRl0X7b/MB6Cqd1lky/CtcSVTydWa2WZwFuWexjS5o6g==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-gnueabihf/-/rollup-linux-arm-gnueabihf-4.55.1.tgz", + "integrity": "sha512-Rn3n+FUk2J5VWx+ywrG/HGPTD9jXNbicRtTM11e/uorplArnXZYsVifnPPqNNP5BsO3roI4n8332ukpY/zN7rQ==", "cpu": [ "arm" ], "dev": true, - "libc": [ - "glibc" - ], "license": "MIT", "optional": true, "os": [ @@ -1538,16 +1535,13 @@ ] }, "node_modules/@rollup/rollup-linux-arm-musleabihf": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-musleabihf/-/rollup-linux-arm-musleabihf-4.60.1.tgz", - "integrity": "sha512-n83O8rt4v34hgFzlkb1ycniJh7IR5RCIqt6mz1VRJD6pmhRi0CXdmfnLu9dIUS6buzh60IvACM842Ffb3xd6Gg==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-musleabihf/-/rollup-linux-arm-musleabihf-4.55.1.tgz", + "integrity": "sha512-grPNWydeKtc1aEdrJDWk4opD7nFtQbMmV7769hiAaYyUKCT1faPRm2av8CX1YJsZ4TLAZcg9gTR1KvEzoLjXkg==", "cpu": [ "arm" ], "dev": true, - "libc": [ - "musl" - ], "license": "MIT", "optional": true, "os": [ @@ -1555,16 +1549,13 @@ ] }, "node_modules/@rollup/rollup-linux-arm64-gnu": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-gnu/-/rollup-linux-arm64-gnu-4.60.1.tgz", - "integrity": 
"sha512-Nql7sTeAzhTAja3QXeAI48+/+GjBJ+QmAH13snn0AJSNL50JsDqotyudHyMbO2RbJkskbMbFJfIJKWA6R1LCJQ==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-gnu/-/rollup-linux-arm64-gnu-4.55.1.tgz", + "integrity": "sha512-a59mwd1k6x8tXKcUxSyISiquLwB5pX+fJW9TkWU46lCqD/GRDe9uDN31jrMmVP3feI3mhAdvcCClhV8V5MhJFQ==", "cpu": [ "arm64" ], "dev": true, - "libc": [ - "glibc" - ], "license": "MIT", "optional": true, "os": [ @@ -1572,16 +1563,13 @@ ] }, "node_modules/@rollup/rollup-linux-arm64-musl": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-musl/-/rollup-linux-arm64-musl-4.60.1.tgz", - "integrity": "sha512-+pUymDhd0ys9GcKZPPWlFiZ67sTWV5UU6zOJat02M1+PiuSGDziyRuI/pPue3hoUwm2uGfxdL+trT6Z9rxnlMA==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-musl/-/rollup-linux-arm64-musl-4.55.1.tgz", + "integrity": "sha512-puS1MEgWX5GsHSoiAsF0TYrpomdvkaXm0CofIMG5uVkP6IBV+ZO9xhC5YEN49nsgYo1DuuMquF9+7EDBVYu4uA==", "cpu": [ "arm64" ], "dev": true, - "libc": [ - "musl" - ], "license": "MIT", "optional": true, "os": [ @@ -1589,16 +1577,13 @@ ] }, "node_modules/@rollup/rollup-linux-loong64-gnu": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-gnu/-/rollup-linux-loong64-gnu-4.60.1.tgz", - "integrity": "sha512-VSvgvQeIcsEvY4bKDHEDWcpW4Yw7BtlKG1GUT4FzBUlEKQK0rWHYBqQt6Fm2taXS+1bXvJT6kICu5ZwqKCnvlQ==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-gnu/-/rollup-linux-loong64-gnu-4.55.1.tgz", + "integrity": "sha512-r3Wv40in+lTsULSb6nnoudVbARdOwb2u5fpeoOAZjFLznp6tDU8kd+GTHmJoqZ9lt6/Sys33KdIHUaQihFcu7g==", "cpu": [ "loong64" ], "dev": true, - "libc": [ - "glibc" - ], "license": "MIT", "optional": true, "os": [ @@ -1606,16 +1591,13 @@ ] }, "node_modules/@rollup/rollup-linux-loong64-musl": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-musl/-/rollup-linux-loong64-musl-4.60.1.tgz", - "integrity": "sha512-4LqhUomJqwe641gsPp6xLfhqWMbQV04KtPp7/dIp0nzPxAkNY1AbwL5W0MQpcalLYk07vaW9Kp1PBhdpZYYcEw==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-musl/-/rollup-linux-loong64-musl-4.55.1.tgz", + "integrity": "sha512-MR8c0+UxAlB22Fq4R+aQSPBayvYa3+9DrwG/i1TKQXFYEaoW3B5b/rkSRIypcZDdWjWnpcvxbNaAJDcSbJU3Lw==", "cpu": [ "loong64" ], "dev": true, - "libc": [ - "musl" - ], "license": "MIT", "optional": true, "os": [ @@ -1623,16 +1605,13 @@ ] }, "node_modules/@rollup/rollup-linux-ppc64-gnu": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-gnu/-/rollup-linux-ppc64-gnu-4.60.1.tgz", - "integrity": "sha512-tLQQ9aPvkBxOc/EUT6j3pyeMD6Hb8QF2BTBnCQWP/uu1lhc9AIrIjKnLYMEroIz/JvtGYgI9dF3AxHZNaEH0rw==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-gnu/-/rollup-linux-ppc64-gnu-4.55.1.tgz", + "integrity": "sha512-3KhoECe1BRlSYpMTeVrD4sh2Pw2xgt4jzNSZIIPLFEsnQn9gAnZagW9+VqDqAHgm1Xc77LzJOo2LdigS5qZ+gw==", "cpu": [ "ppc64" ], "dev": true, - "libc": [ - "glibc" - ], "license": "MIT", "optional": true, "os": [ @@ -1640,16 +1619,13 @@ ] }, "node_modules/@rollup/rollup-linux-ppc64-musl": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-musl/-/rollup-linux-ppc64-musl-4.60.1.tgz", - "integrity": "sha512-RMxFhJwc9fSXP6PqmAz4cbv3kAyvD1etJFjTx4ONqFP9DkTkXsAMU4v3Vyc5BgzC+anz7nS/9tp4obsKfqkDHg==", + "version": "4.55.1", + 
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-musl/-/rollup-linux-ppc64-musl-4.55.1.tgz", + "integrity": "sha512-ziR1OuZx0vdYZZ30vueNZTg73alF59DicYrPViG0NEgDVN8/Jl87zkAPu4u6VjZST2llgEUjaiNl9JM6HH1Vdw==", "cpu": [ "ppc64" ], "dev": true, - "libc": [ - "musl" - ], "license": "MIT", "optional": true, "os": [ @@ -1657,16 +1633,13 @@ ] }, "node_modules/@rollup/rollup-linux-riscv64-gnu": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-gnu/-/rollup-linux-riscv64-gnu-4.60.1.tgz", - "integrity": "sha512-QKgFl+Yc1eEk6MmOBfRHYF6lTxiiiV3/z/BRrbSiW2I7AFTXoBFvdMEyglohPj//2mZS4hDOqeB0H1ACh3sBbg==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-gnu/-/rollup-linux-riscv64-gnu-4.55.1.tgz", + "integrity": "sha512-uW0Y12ih2XJRERZ4jAfKamTyIHVMPQnTZcQjme2HMVDAHY4amf5u414OqNYC+x+LzRdRcnIG1YodLrrtA8xsxw==", "cpu": [ "riscv64" ], "dev": true, - "libc": [ - "glibc" - ], "license": "MIT", "optional": true, "os": [ @@ -1674,16 +1647,13 @@ ] }, "node_modules/@rollup/rollup-linux-riscv64-musl": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-musl/-/rollup-linux-riscv64-musl-4.60.1.tgz", - "integrity": "sha512-RAjXjP/8c6ZtzatZcA1RaQr6O1TRhzC+adn8YZDnChliZHviqIjmvFwHcxi4JKPSDAt6Uhf/7vqcBzQJy0PDJg==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-musl/-/rollup-linux-riscv64-musl-4.55.1.tgz", + "integrity": "sha512-u9yZ0jUkOED1BFrqu3BwMQoixvGHGZ+JhJNkNKY/hyoEgOwlqKb62qu+7UjbPSHYjiVy8kKJHvXKv5coH4wDeg==", "cpu": [ "riscv64" ], "dev": true, - "libc": [ - "musl" - ], "license": "MIT", "optional": true, "os": [ @@ -1691,16 +1661,13 @@ ] }, "node_modules/@rollup/rollup-linux-s390x-gnu": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-s390x-gnu/-/rollup-linux-s390x-gnu-4.60.1.tgz", - "integrity": "sha512-wcuocpaOlaL1COBYiA89O6yfjlp3RwKDeTIA0hM7OpmhR1Bjo9j31G1uQVpDlTvwxGn2nQs65fBFL5UFd76FcQ==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-s390x-gnu/-/rollup-linux-s390x-gnu-4.55.1.tgz", + "integrity": "sha512-/0PenBCmqM4ZUd0190j7J0UsQ/1nsi735iPRakO8iPciE7BQ495Y6msPzaOmvx0/pn+eJVVlZrNrSh4WSYLxNg==", "cpu": [ "s390x" ], "dev": true, - "libc": [ - "glibc" - ], "license": "MIT", "optional": true, "os": [ @@ -1708,16 +1675,13 @@ ] }, "node_modules/@rollup/rollup-linux-x64-gnu": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.60.1.tgz", - "integrity": "sha512-77PpsFQUCOiZR9+LQEFg9GClyfkNXj1MP6wRnzYs0EeWbPcHs02AXu4xuUbM1zhwn3wqaizle3AEYg5aeoohhg==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.55.1.tgz", + "integrity": "sha512-a8G4wiQxQG2BAvo+gU6XrReRRqj+pLS2NGXKm8io19goR+K8lw269eTrPkSdDTALwMmJp4th2Uh0D8J9bEV1vg==", "cpu": [ "x64" ], "dev": true, - "libc": [ - "glibc" - ], "license": "MIT", "optional": true, "os": [ @@ -1725,16 +1689,13 @@ ] }, "node_modules/@rollup/rollup-linux-x64-musl": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-musl/-/rollup-linux-x64-musl-4.60.1.tgz", - "integrity": "sha512-5cIATbk5vynAjqqmyBjlciMJl1+R/CwX9oLk/EyiFXDWd95KpHdrOJT//rnUl4cUcskrd0jCCw3wpZnhIHdD9w==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-musl/-/rollup-linux-x64-musl-4.55.1.tgz", + "integrity": 
"sha512-bD+zjpFrMpP/hqkfEcnjXWHMw5BIghGisOKPj+2NaNDuVT+8Ds4mPf3XcPHuat1tz89WRL+1wbcxKY3WSbiT7w==", "cpu": [ "x64" ], "dev": true, - "libc": [ - "musl" - ], "license": "MIT", "optional": true, "os": [ @@ -1742,9 +1703,9 @@ ] }, "node_modules/@rollup/rollup-openbsd-x64": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-openbsd-x64/-/rollup-openbsd-x64-4.60.1.tgz", - "integrity": "sha512-cl0w09WsCi17mcmWqqglez9Gk8isgeWvoUZ3WiJFYSR3zjBQc2J5/ihSjpl+VLjPqjQ/1hJRcqBfLjssREQILw==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-openbsd-x64/-/rollup-openbsd-x64-4.55.1.tgz", + "integrity": "sha512-eLXw0dOiqE4QmvikfQ6yjgkg/xDM+MdU9YJuP4ySTibXU0oAvnEWXt7UDJmD4UkYialMfOGFPJnIHSe/kdzPxg==", "cpu": [ "x64" ], @@ -1756,9 +1717,9 @@ ] }, "node_modules/@rollup/rollup-openharmony-arm64": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-openharmony-arm64/-/rollup-openharmony-arm64-4.60.1.tgz", - "integrity": "sha512-4Cv23ZrONRbNtbZa37mLSueXUCtN7MXccChtKpUnQNgF010rjrjfHx3QxkS2PI7LqGT5xXyYs1a7LbzAwT0iCA==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-openharmony-arm64/-/rollup-openharmony-arm64-4.55.1.tgz", + "integrity": "sha512-xzm44KgEP11te3S2HCSyYf5zIzWmx3n8HDCc7EE59+lTcswEWNpvMLfd9uJvVX8LCg9QWG67Xt75AuHn4vgsXw==", "cpu": [ "arm64" ], @@ -1770,9 +1731,9 @@ ] }, "node_modules/@rollup/rollup-win32-arm64-msvc": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-arm64-msvc/-/rollup-win32-arm64-msvc-4.60.1.tgz", - "integrity": "sha512-i1okWYkA4FJICtr7KpYzFpRTHgy5jdDbZiWfvny21iIKky5YExiDXP+zbXzm3dUcFpkEeYNHgQ5fuG236JPq0g==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-arm64-msvc/-/rollup-win32-arm64-msvc-4.55.1.tgz", + "integrity": "sha512-yR6Bl3tMC/gBok5cz/Qi0xYnVbIxGx5Fcf/ca0eB6/6JwOY+SRUcJfI0OpeTpPls7f194as62thCt/2BjxYN8g==", "cpu": [ "arm64" ], @@ -1784,9 +1745,9 @@ ] }, "node_modules/@rollup/rollup-win32-ia32-msvc": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-ia32-msvc/-/rollup-win32-ia32-msvc-4.60.1.tgz", - "integrity": "sha512-u09m3CuwLzShA0EYKMNiFgcjjzwqtUMLmuCJLeZWjjOYA3IT2Di09KaxGBTP9xVztWyIWjVdsB2E9goMjZvTQg==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-ia32-msvc/-/rollup-win32-ia32-msvc-4.55.1.tgz", + "integrity": "sha512-3fZBidchE0eY0oFZBnekYCfg+5wAB0mbpCBuofh5mZuzIU/4jIVkbESmd2dOsFNS78b53CYv3OAtwqkZZmU5nA==", "cpu": [ "ia32" ], @@ -1798,9 +1759,9 @@ ] }, "node_modules/@rollup/rollup-win32-x64-gnu": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-gnu/-/rollup-win32-x64-gnu-4.60.1.tgz", - "integrity": "sha512-k+600V9Zl1CM7eZxJgMyTUzmrmhB/0XZnF4pRypKAlAgxmedUA+1v9R+XOFv56W4SlHEzfeMtzujLJD22Uz5zg==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-gnu/-/rollup-win32-x64-gnu-4.55.1.tgz", + "integrity": "sha512-xGGY5pXj69IxKb4yv/POoocPy/qmEGhimy/FoTpTSVju3FYXUQQMFCaZZXJVidsmGxRioZAwpThl/4zX41gRKg==", "cpu": [ "x64" ], @@ -1812,9 +1773,9 @@ ] }, "node_modules/@rollup/rollup-win32-x64-msvc": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-msvc/-/rollup-win32-x64-msvc-4.60.1.tgz", - "integrity": "sha512-lWMnixq/QzxyhTV6NjQJ4SFo1J6PvOX8vUx5Wb4bBPsEb+8xZ89Bz6kOXpfXj9ak9AHTQVQzlgzBEc1SyM27xQ==", + "version": "4.55.1", + "resolved": 
"https://registry.npmjs.org/@rollup/rollup-win32-x64-msvc/-/rollup-win32-x64-msvc-4.55.1.tgz", + "integrity": "sha512-SPEpaL6DX4rmcXtnhdrQYgzQ5W2uW3SCJch88lB2zImhJRhIIK44fkUrgIV/Q8yUNfw5oyZ5vkeQsZLhCb06lw==", "cpu": [ "x64" ], @@ -2183,9 +2144,9 @@ } }, "node_modules/@typescript-eslint/typescript-estree/node_modules/brace-expansion": { - "version": "2.0.3", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.3.tgz", - "integrity": "sha512-MCV/fYJEbqx68aE58kv2cA/kiky1G8vux3OR6/jbS+jIMe/6fJWa0DTzJU7dqijOWYwHi1t29FlfYI9uytqlpA==", + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", + "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", "dev": true, "license": "MIT", "dependencies": { @@ -2193,13 +2154,13 @@ } }, "node_modules/@typescript-eslint/typescript-estree/node_modules/minimatch": { - "version": "9.0.9", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.9.tgz", - "integrity": "sha512-OBwBN9AL4dqmETlpS2zasx+vTeWclWzkblfZk7KTA5j3jeOONz/tRCnZomUyvNg83wL5Zv9Ss6HMJXAgL8R2Yg==", + "version": "9.0.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", + "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", "dev": true, "license": "ISC", "dependencies": { - "brace-expansion": "^2.0.2" + "brace-expansion": "^2.0.1" }, "engines": { "node": ">=16 || 14 >=14.17" @@ -2341,9 +2302,9 @@ } }, "node_modules/ajv": { - "version": "6.14.0", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.14.0.tgz", - "integrity": "sha512-IWrosm/yrn43eiKqkfkHis7QioDleaXQHdDVPKg0FSwwd/DuvyX79TZnFOnYpB7dcsFAMmtFztZuXPDvSePkFw==", + "version": "6.12.6", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", + "integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==", "dev": true, "license": "MIT", "dependencies": { @@ -2423,9 +2384,9 @@ } }, "node_modules/brace-expansion": { - "version": "1.1.13", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.13.tgz", - "integrity": "sha512-9ZLprWS6EENmhEOpjCYW2c8VkmOvckIJZfkr7rBW6dObmfgJ/L1GpSYW5Hpo9lDz4D1+n0Ckz8rU7FwHDQiG/w==", + "version": "1.1.12", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz", + "integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==", "dev": true, "license": "MIT", "dependencies": { @@ -3104,9 +3065,9 @@ } }, "node_modules/flatted": { - "version": "3.4.2", - "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.4.2.tgz", - "integrity": "sha512-PjDse7RzhcPkIJwy5t7KPWQSZ9cAbzQXcafsetQoD7sOJRQlGikNbx7yZp2OotDnJyrDcbyRq3Ttb18iYOqkxA==", + "version": "3.3.3", + "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.3.3.tgz", + "integrity": "sha512-GX+ysw4PBCz0PzosHDepZGANEuFCMLrnRTiEy9McGjmkCQYwRq4A/X786G/fjM/+OjsWSU1ZrY5qyARZmO/uwg==", "dev": true, "license": "ISC" }, @@ -4491,9 +4452,9 @@ "license": "MIT" }, "node_modules/minimatch": { - "version": "3.1.5", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.5.tgz", - "integrity": "sha512-VgjWUsnnT6n+NUk6eZq77zeFdpW2LWDzP6zFGrCbHXiYNul5Dzqk2HHQ5uFH2DNW5Xbp8+jVzaeNt94ssEEl4w==", + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", + "integrity": 
"sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", "dev": true, "license": "ISC", "dependencies": { @@ -4698,9 +4659,9 @@ "license": "ISC" }, "node_modules/picomatch": { - "version": "4.0.4", - "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.4.tgz", - "integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==", + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz", + "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", "dev": true, "license": "MIT", "engines": { @@ -5011,9 +4972,9 @@ } }, "node_modules/rollup": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/rollup/-/rollup-4.60.1.tgz", - "integrity": "sha512-VmtB2rFU/GroZ4oL8+ZqXgSA38O6GR8KSIvWmEFv63pQ0G6KaBH9s07PO8XTXP4vI+3UJUEypOfjkGfmSBBR0w==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/rollup/-/rollup-4.55.1.tgz", + "integrity": "sha512-wDv/Ht1BNHB4upNbK74s9usvl7hObDnvVzknxqY/E/O3X6rW1U1rV1aENEfJ54eFZDTNo7zv1f5N4edCluH7+A==", "dev": true, "license": "MIT", "dependencies": { @@ -5027,31 +4988,31 @@ "npm": ">=8.0.0" }, "optionalDependencies": { - "@rollup/rollup-android-arm-eabi": "4.60.1", - "@rollup/rollup-android-arm64": "4.60.1", - "@rollup/rollup-darwin-arm64": "4.60.1", - "@rollup/rollup-darwin-x64": "4.60.1", - "@rollup/rollup-freebsd-arm64": "4.60.1", - "@rollup/rollup-freebsd-x64": "4.60.1", - "@rollup/rollup-linux-arm-gnueabihf": "4.60.1", - "@rollup/rollup-linux-arm-musleabihf": "4.60.1", - "@rollup/rollup-linux-arm64-gnu": "4.60.1", - "@rollup/rollup-linux-arm64-musl": "4.60.1", - "@rollup/rollup-linux-loong64-gnu": "4.60.1", - "@rollup/rollup-linux-loong64-musl": "4.60.1", - "@rollup/rollup-linux-ppc64-gnu": "4.60.1", - "@rollup/rollup-linux-ppc64-musl": "4.60.1", - "@rollup/rollup-linux-riscv64-gnu": "4.60.1", - "@rollup/rollup-linux-riscv64-musl": "4.60.1", - "@rollup/rollup-linux-s390x-gnu": "4.60.1", - "@rollup/rollup-linux-x64-gnu": "4.60.1", - "@rollup/rollup-linux-x64-musl": "4.60.1", - "@rollup/rollup-openbsd-x64": "4.60.1", - "@rollup/rollup-openharmony-arm64": "4.60.1", - "@rollup/rollup-win32-arm64-msvc": "4.60.1", - "@rollup/rollup-win32-ia32-msvc": "4.60.1", - "@rollup/rollup-win32-x64-gnu": "4.60.1", - "@rollup/rollup-win32-x64-msvc": "4.60.1", + "@rollup/rollup-android-arm-eabi": "4.55.1", + "@rollup/rollup-android-arm64": "4.55.1", + "@rollup/rollup-darwin-arm64": "4.55.1", + "@rollup/rollup-darwin-x64": "4.55.1", + "@rollup/rollup-freebsd-arm64": "4.55.1", + "@rollup/rollup-freebsd-x64": "4.55.1", + "@rollup/rollup-linux-arm-gnueabihf": "4.55.1", + "@rollup/rollup-linux-arm-musleabihf": "4.55.1", + "@rollup/rollup-linux-arm64-gnu": "4.55.1", + "@rollup/rollup-linux-arm64-musl": "4.55.1", + "@rollup/rollup-linux-loong64-gnu": "4.55.1", + "@rollup/rollup-linux-loong64-musl": "4.55.1", + "@rollup/rollup-linux-ppc64-gnu": "4.55.1", + "@rollup/rollup-linux-ppc64-musl": "4.55.1", + "@rollup/rollup-linux-riscv64-gnu": "4.55.1", + "@rollup/rollup-linux-riscv64-musl": "4.55.1", + "@rollup/rollup-linux-s390x-gnu": "4.55.1", + "@rollup/rollup-linux-x64-gnu": "4.55.1", + "@rollup/rollup-linux-x64-musl": "4.55.1", + "@rollup/rollup-openbsd-x64": "4.55.1", + "@rollup/rollup-openharmony-arm64": "4.55.1", + "@rollup/rollup-win32-arm64-msvc": "4.55.1", + "@rollup/rollup-win32-ia32-msvc": "4.55.1", + "@rollup/rollup-win32-x64-gnu": "4.55.1", + 
"@rollup/rollup-win32-x64-msvc": "4.55.1", "fsevents": "~2.3.2" } }, @@ -5587,9 +5548,9 @@ "license": "ISC" }, "node_modules/yaml": { - "version": "1.10.3", - "resolved": "https://registry.npmjs.org/yaml/-/yaml-1.10.3.tgz", - "integrity": "sha512-vIYeF1u3CjlhAFekPPAk2h/Kv4T3mAkMox5OymRiJQB0spDP10LHvt+K7G9Ny6NuuMAb25/6n1qyUjAcGNf/AA==", + "version": "1.10.2", + "resolved": "https://registry.npmjs.org/yaml/-/yaml-1.10.2.tgz", + "integrity": "sha512-r3vXyErRCYJ7wg28yvBY5VSoAF8ZvlcW9/BwUzEtUsjvX/DKs24dIkuwjtuprwJJHsbyUbLApepYTR1BN4uHrg==", "license": "ISC", "engines": { "node": ">= 6" diff --git a/frontend/public/smolagents.webp b/frontend/public/smolagents.webp deleted file mode 100644 index 4be2c482082e0d2f08c88336d89a71f2b1f2f55e..0000000000000000000000000000000000000000 Binary files a/frontend/public/smolagents.webp and /dev/null differ diff --git a/frontend/src/components/Chat/ActivityStatusBar.tsx b/frontend/src/components/Chat/ActivityStatusBar.tsx index 3dd0af534ec1b7fa5861116b865b388f76a42eaa..3a266e1777cac7b53d5a3a92646e7220cf48d9ad 100644 --- a/frontend/src/components/Chat/ActivityStatusBar.tsx +++ b/frontend/src/components/Chat/ActivityStatusBar.tsx @@ -34,15 +34,9 @@ function formatResearchStatus(raw: string): string { if (typeof v === 'string') args[k] = v; } } catch { - // JSON is likely truncated β€” extract complete "key": "value" pairs for (const m of jsonStr.matchAll(/"(\w+)":\s*"([^"]*)"/g)) { args[m[1]] = m[2]; } - // Also try to extract a truncated value for known keys if not found yet - if (!args.query && !args.arxiv_id) { - const partial = jsonStr.match(/"(query|arxiv_id)":\s*"([^"]*)/); - if (partial) args[partial[1]] = partial[2]; - } } } @@ -68,15 +62,12 @@ function formatResearchStatus(raw: string): string { } if (toolName === 'hf_papers') { const op = args.operation as string; - const detail = (args.query) || (args.arxiv_id) || (args.positive_ids); + const detail = (args.query) || (args.arxiv_id); const opLabels: Record = { trending: 'Browsing trending papers', search: 'Searching papers', paper_details: 'Reading paper details', read_paper: 'Reading paper', - citation_graph: 'Tracing citations', - snippet_search: 'Searching paper passages', - recommend: 'Finding similar papers', find_datasets: 'Finding paper datasets', find_models: 'Finding paper models', find_collections: 'Finding paper collections', diff --git a/frontend/src/components/Chat/AssistantMessage.tsx b/frontend/src/components/Chat/AssistantMessage.tsx index 91c7b8c1012bf1513ca141999d1acc7cfa23284f..83bd8cae505808781908a2292eaa8acc1242536b 100644 --- a/frontend/src/components/Chat/AssistantMessage.tsx +++ b/frontend/src/components/Chat/AssistantMessage.tsx @@ -1,19 +1,13 @@ -import { useMemo, useState } from 'react'; -import { Box, IconButton, Stack, Tooltip, Typography } from '@mui/material'; -import ThumbUpOutlined from '@mui/icons-material/ThumbUpOutlined'; -import ThumbUp from '@mui/icons-material/ThumbUp'; -import ThumbDownOutlined from '@mui/icons-material/ThumbDownOutlined'; -import ThumbDown from '@mui/icons-material/ThumbDown'; +import { useMemo } from 'react'; +import { Box, Stack, Typography } from '@mui/material'; import MarkdownContent from './MarkdownContent'; import ToolCallGroup from './ToolCallGroup'; -import { apiFetch } from '@/utils/api'; import type { UIMessage } from 'ai'; import type { MessageMeta } from '@/types/agent'; interface AssistantMessageProps { message: UIMessage; isStreaming?: boolean; - sessionId?: string | null; approveTools: (approvals: Array<{ tool_call_id: string; 
approved: boolean; feedback?: string | null }>) => Promise; } @@ -49,27 +43,8 @@ function groupParts(parts: UIMessage['parts']) { return groups; } -export default function AssistantMessage({ message, isStreaming = false, sessionId, approveTools }: AssistantMessageProps) { +export default function AssistantMessage({ message, isStreaming = false, approveTools }: AssistantMessageProps) { const groups = useMemo(() => groupParts(message.parts), [message.parts]); - const [feedback, setFeedback] = useState<'up' | 'down' | null>(null); - const [feedbackBusy, setFeedbackBusy] = useState(false); - - const sendFeedback = async (rating: 'up' | 'down') => { - if (!sessionId || feedbackBusy) return; - setFeedbackBusy(true); - // Optimistic toggle β€” feedback is observability, not a hard requirement. - setFeedback(rating); - try { - await apiFetch(`/api/feedback/${sessionId}`, { - method: 'POST', - body: JSON.stringify({ rating, message_id: message.id }), - }); - } catch { - // Silently swallow β€” don't block chat UX on a telemetry write. - } finally { - setFeedbackBusy(false); - } - }; // Find the last text group index for streaming cursor let lastTextIdx = -1; @@ -139,24 +114,6 @@ export default function AssistantMessage({ message, isStreaming = false, session return null; })} - {!isStreaming && sessionId && ( - - - sendFeedback('up')}> - {feedback === 'up' ? : } - - - - sendFeedback('down')}> - {feedback === 'down' ? : } - - - - )} ); } diff --git a/frontend/src/components/Chat/ChatInput.tsx b/frontend/src/components/Chat/ChatInput.tsx index 8a8810eac905be0639b450d11d35a8ca9c675252..df8bacb899a43dfd45bd653442366c67e8ef822a 100644 --- a/frontend/src/components/Chat/ChatInput.tsx +++ b/frontend/src/components/Chat/ChatInput.tsx @@ -1,34 +1,9 @@ import { useState, useCallback, useEffect, useRef, KeyboardEvent } from 'react'; -import { - Alert, - Box, - TextField, - IconButton, - CircularProgress, - Typography, - Menu, - MenuItem, - ListItemIcon, - ListItemText, - Chip, - Snackbar, -} from '@mui/material'; +import { Box, TextField, IconButton, CircularProgress, Typography, Menu, MenuItem, ListItemIcon, ListItemText, Chip } from '@mui/material'; import ArrowUpwardIcon from '@mui/icons-material/ArrowUpward'; import ArrowDropDownIcon from '@mui/icons-material/ArrowDropDown'; -import StopIcon from '@mui/icons-material/Stop'; +import CloseIcon from '@mui/icons-material/Close'; import { apiFetch } from '@/utils/api'; -import { useUserQuota } from '@/hooks/useUserQuota'; -import ClaudeCapDialog from '@/components/ClaudeCapDialog'; -import JobsUpgradeDialog from '@/components/JobsUpgradeDialog'; -import { useAgentStore } from '@/store/agentStore'; -import { useSessionStore } from '@/store/sessionStore'; -import { - CLAUDE_MODEL_PATH, - FIRST_FREE_MODEL_PATH, - GPT_55_MODEL_PATH, - isClaudePath, - isPremiumPath, -} from '@/utils/model'; // Model configuration interface ModelOption { @@ -45,77 +20,44 @@ const getHfAvatarUrl = (modelId: string) => { return `https://huggingface.co/api/avatars/${org}`; }; -const DEFAULT_MODEL_OPTIONS: ModelOption[] = [ - { - id: 'kimi-k2.6', - name: 'Kimi K2.6', - description: 'Novita', - modelPath: 'moonshotai/Kimi-K2.6', - avatarUrl: getHfAvatarUrl('moonshotai/Kimi-K2.6'), - recommended: true, - }, +const MODEL_OPTIONS: ModelOption[] = [ { id: 'claude-opus', name: 'Claude Opus 4.6', description: 'Anthropic', - modelPath: CLAUDE_MODEL_PATH, + modelPath: 'anthropic/claude-opus-4-6', avatarUrl: 'https://huggingface.co/api/avatars/Anthropic', recommended: true, }, { - id: 'gpt-5.5', - 
name: 'GPT-5.5', - description: 'OpenAI', - modelPath: GPT_55_MODEL_PATH, - avatarUrl: 'https://huggingface.co/api/avatars/openai', - }, - { - id: 'minimax-m2.7', - name: 'MiniMax M2.7', - description: 'Novita', - modelPath: 'MiniMaxAI/MiniMax-M2.7', - avatarUrl: getHfAvatarUrl('MiniMaxAI/MiniMax-M2.7'), + id: 'minimax-m2.5', + name: 'MiniMax M2.5', + description: 'Via Fireworks', + modelPath: 'huggingface/fireworks-ai/MiniMaxAI/MiniMax-M2.5', + avatarUrl: getHfAvatarUrl('MiniMaxAI/MiniMax-M2.5'), + recommended: true, }, { - id: 'glm-5.1', - name: 'GLM 5.1', - description: 'Together', - modelPath: 'zai-org/GLM-5.1', - avatarUrl: getHfAvatarUrl('zai-org/GLM-5.1'), + id: 'kimi-k2.5', + name: 'Kimi K2.5', + description: 'Via Novita', + modelPath: 'huggingface/novita/moonshotai/kimi-k2.5', + avatarUrl: getHfAvatarUrl('moonshotai/Kimi-K2.5'), }, { - id: 'deepseek-v4-pro', - name: 'DeepSeek V4 Pro', - description: 'DeepInfra', - modelPath: 'deepseek-ai/DeepSeek-V4-Pro:deepinfra', - avatarUrl: getHfAvatarUrl('deepseek-ai/DeepSeek-V4-Pro'), + id: 'glm-5', + name: 'GLM 5', + description: 'Via Novita', + modelPath: 'huggingface/novita/zai-org/glm-5', + avatarUrl: getHfAvatarUrl('zai-org/GLM-5'), }, ]; -const findModelByPath = (path: string, options: ModelOption[]): ModelOption | undefined => { - if (isClaudePath(path)) { - const claude = options.find(isClaudeModel); - if (claude) return claude; - } - return options.find(m => m.modelPath === path || path?.includes(m.id)); -}; - -const readApiErrorMessage = async (res: Response, fallback: string): Promise => { - try { - const data = await res.json(); - const detail = data?.detail; - if (typeof detail === 'string') return detail; - if (detail && typeof detail.message === 'string') return detail.message; - if (detail && typeof detail.error === 'string') return detail.error; - } catch { - /* ignore malformed error bodies */ - } - return fallback; +const findModelByPath = (path: string): ModelOption | undefined => { + return MODEL_OPTIONS.find(m => m.modelPath === path || path?.includes(m.id)); }; interface ChatInputProps { - sessionId?: string; - initialModelPath?: string | null; onSend: (text: string) => void; onStop?: () => void; isProcessing?: boolean; @@ -123,90 +65,36 @@ interface ChatInputProps { placeholder?: string; } -const isClaudeModel = (m: ModelOption) => isClaudePath(m.modelPath); -const isPremiumModel = (m: ModelOption) => isPremiumPath(m.modelPath); -const firstFreeModel = (options: ModelOption[]) => options.find(m => !isPremiumModel(m)) ?? options[0]; - -export default function ChatInput({ sessionId, initialModelPath, onSend, onStop, isProcessing = false, disabled = false, placeholder = 'Ask anything...' }: ChatInputProps) { +export default function ChatInput({ onSend, onStop, isProcessing = false, disabled = false, placeholder = 'Ask anything...' }: ChatInputProps) { const [input, setInput] = useState(''); + const [stopHovered, setStopHovered] = useState(false); const inputRef = useRef(null); - const [modelOptions, setModelOptions] = useState(DEFAULT_MODEL_OPTIONS); - const modelOptionsRef = useRef(DEFAULT_MODEL_OPTIONS); - const sessionIdRef = useRef(sessionId); - const [selectedModelId, setSelectedModelId] = useState( - () => findModelByPath(initialModelPath ?? '', DEFAULT_MODEL_OPTIONS)?.id ?? 
DEFAULT_MODEL_OPTIONS[0].id, - ); + const [selectedModelId, setSelectedModelId] = useState(() => { + try { + const stored = localStorage.getItem('hf-agent-model'); + if (stored && MODEL_OPTIONS.some(m => m.id === stored)) return stored; + } catch { /* localStorage unavailable */ } + return MODEL_OPTIONS[0].id; + }); const [modelAnchorEl, setModelAnchorEl] = useState(null); - const { quota, refresh: refreshQuota } = useUserQuota(); - // The daily-cap dialog is triggered from two places: (a) a 429 returned - // from the chat transport when the user tries to send on a premium model over cap β€” - // surfaced via the agent-store flag β€” and (b) nothing else right now - // (switching models is free). Keeping the open state in the store means - // the hook layer can flip it without threading props through. - const claudeQuotaExhausted = useAgentStore((s) => s.claudeQuotaExhausted); - const setClaudeQuotaExhausted = useAgentStore((s) => s.setClaudeQuotaExhausted); - const jobsUpgradeRequired = useAgentStore((s) => s.jobsUpgradeRequired); - const setJobsUpgradeRequired = useAgentStore((s) => s.setJobsUpgradeRequired); - const updateSessionModel = useSessionStore((s) => s.updateSessionModel); - const [awaitingTopUp, setAwaitingTopUp] = useState(false); - const [modelSwitchError, setModelSwitchError] = useState(null); - const lastSentRef = useRef(''); + // Sync with backend on mount (backend is source of truth, localStorage is just a cache) useEffect(() => { - modelOptionsRef.current = modelOptions; - }, [modelOptions]); - - useEffect(() => { - sessionIdRef.current = sessionId; - }, [sessionId]); - - useEffect(() => { - let cancelled = false; - apiFetch('/api/config/model') + fetch('/api/config/model') .then((res) => (res.ok ? res.json() : null)) .then((data) => { - if (cancelled || !data?.available) return; - const claude = data.available.find((m: { provider?: string; id?: string }) => ( - m.provider === 'anthropic' && m.id - )); - if (!claude?.id) return; - - const next = DEFAULT_MODEL_OPTIONS.map((option) => ( - isClaudeModel(option) - ? { ...option, modelPath: claude.id, name: claude.label ?? option.name } - : option - )); - modelOptionsRef.current = next; - setModelOptions(next); - if (!sessionIdRef.current) { - const current = data.current ? findModelByPath(data.current, next) : null; - if (current) setSelectedModelId(current.id); + if (data?.current) { + const model = findModelByPath(data.current); + if (model) { + setSelectedModelId(model.id); + try { localStorage.setItem('hf-agent-model', model.id); } catch { /* ignore */ } + } } }) .catch(() => { /* ignore */ }); - return () => { cancelled = true; }; }, []); - // Model is per-session: fetch this tab's current model every time the - // session changes. Other tabs keep their own selections independently. - useEffect(() => { - if (!sessionId) return; - let cancelled = false; - apiFetch(`/api/session/${sessionId}`) - .then((res) => (res.ok ? 
res.json() : null)) - .then((data) => { - if (cancelled) return; - if (data?.model) { - const model = findModelByPath(data.model, modelOptionsRef.current); - if (model) setSelectedModelId(model.id); - updateSessionModel(sessionId, data.model); - } - }) - .catch(() => { /* ignore */ }); - return () => { cancelled = true; }; - }, [sessionId, updateSessionModel]); - - const selectedModel = modelOptions.find(m => m.id === selectedModelId) || modelOptions[0]; + const selectedModel = MODEL_OPTIONS.find(m => m.id === selectedModelId) || MODEL_OPTIONS[0]; // Auto-focus the textarea when the session becomes ready useEffect(() => { @@ -217,27 +105,11 @@ export default function ChatInput({ sessionId, initialModelPath, onSend, onStop, const handleSend = useCallback(() => { if (input.trim() && !disabled) { - lastSentRef.current = input; onSend(input); setInput(''); } }, [input, disabled, onSend]); - // When the chat transport reports a premium-model quota 429, restore the typed - // text so the user doesn't lose their message. - useEffect(() => { - if (claudeQuotaExhausted && lastSentRef.current) { - setInput(lastSentRef.current); - } - }, [claudeQuotaExhausted]); - - // Refresh the quota display whenever the session changes (user might - // have started another tab that spent quota). - useEffect(() => { - if (sessionId) refreshQuota(); - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [sessionId]); - const handleKeyDown = useCallback( (e: KeyboardEvent) => { if (e.key === 'Enter' && !e.shiftKey) { @@ -258,116 +130,17 @@ export default function ChatInput({ sessionId, initialModelPath, onSend, onStop, const handleSelectModel = async (model: ModelOption) => { handleModelClose(); - if (!sessionId) return; try { - const res = await apiFetch(`/api/session/${sessionId}/model`, { + const res = await apiFetch('/api/config/model', { method: 'POST', body: JSON.stringify({ model: model.modelPath }), }); if (res.ok) { setSelectedModelId(model.id); - updateSessionModel(sessionId, model.modelPath); - setModelSwitchError(null); - return; - } - setModelSwitchError(await readApiErrorMessage(res, 'Could not switch model.')); - } catch (error) { - setModelSwitchError(error instanceof Error ? error.message : 'Could not switch model.'); - } - }; - - // Dialog close: just clear the flag. The typed text is already restored. - const handleCapDialogClose = useCallback(() => { - setClaudeQuotaExhausted(false); - }, [setClaudeQuotaExhausted]); - - // "Use a free model" β€” switch the current session to Kimi (or the first - // non-premium option) and auto-retry the send that tripped the cap. - const handleUseFreeModel = useCallback(async () => { - setClaudeQuotaExhausted(false); - if (!sessionId) return; - const free = modelOptions.find(m => m.modelPath === FIRST_FREE_MODEL_PATH) - ?? 
firstFreeModel(modelOptions); - try { - const res = await apiFetch(`/api/session/${sessionId}/model`, { - method: 'POST', - body: JSON.stringify({ model: free.modelPath }), - }); - if (res.ok) { - setSelectedModelId(free.id); - updateSessionModel(sessionId, free.modelPath); - const retryText = lastSentRef.current; - if (retryText) { - onSend(retryText); - setInput(''); - lastSentRef.current = ''; - } + try { localStorage.setItem('hf-agent-model', model.id); } catch { /* ignore */ } } } catch { /* ignore */ } - }, [sessionId, onSend, setClaudeQuotaExhausted, modelOptions, updateSessionModel]); - - const handlePremiumUpgradeClick = useCallback(async () => { - if (!sessionId) return; - try { - await apiFetch(`/api/pro-click/${sessionId}`, { - method: 'POST', - body: JSON.stringify({ source: 'premium_cap_dialog', target: 'pro_pricing' }), - }); - } catch { - /* tracking is best-effort */ - } - }, [sessionId]); - - const handleJobsUpgradeClose = useCallback(() => { - setJobsUpgradeRequired(null); - setAwaitingTopUp(false); - }, [setJobsUpgradeRequired]); - - const handleJobsUpgradeClick = useCallback(async () => { - setAwaitingTopUp(true); - if (!sessionId || !jobsUpgradeRequired) return; - try { - await apiFetch(`/api/pro-click/${sessionId}`, { - method: 'POST', - body: JSON.stringify({ source: 'hf_jobs_billing_dialog', target: 'hf_billing' }), - }); - } catch { - /* tracking is best-effort */ - } - }, [sessionId, jobsUpgradeRequired]); - - const handleJobsRetry = useCallback(() => { - const namespace = jobsUpgradeRequired?.namespace; - setJobsUpgradeRequired(null); - setAwaitingTopUp(false); - const msg = namespace - ? `I just added credits to the \`${namespace}\` namespace. Please retry the previous job.` - : "I just added credits. Please retry the previous job."; - onSend(msg); - }, [jobsUpgradeRequired, setJobsUpgradeRequired, onSend]); - - // Auto-retry when the user comes back to this tab after clicking "Add credits". - // Browsers fire visibilitychange when the tab regains focus from a sibling tab. - useEffect(() => { - if (!awaitingTopUp || !jobsUpgradeRequired) return; - const onVisible = () => { - if (document.visibilityState === 'visible') { - handleJobsRetry(); - } - }; - document.addEventListener('visibilitychange', onVisible); - return () => document.removeEventListener('visibilitychange', onVisible); - }, [awaitingTopUp, jobsUpgradeRequired, handleJobsRetry]); - - // Hide the chip until the user has actually burned quota; opening a - // premium-model session without sending should not populate a counter. - const premiumChip = (() => { - if (!quota || quota.premiumUsedToday === 0) return null; - if (quota.plan === 'free') { - return quota.premiumRemaining > 0 ? 'Free today' : 'Pro only'; - } - return `${quota.premiumUsedToday}/${quota.premiumDailyCap} today`; - })(); + }; return ( setStopHovered(true)} + onMouseLeave={() => setStopHovered(false)} sx={{ mt: 1, - p: 1.5, + p: 1, borderRadius: '10px', - color: 'var(--muted-text)', + color: stopHovered ? 'var(--accent-yellow)' : 'var(--muted-text)', transition: 'all 0.2s', - position: 'relative', '&:hover': { bgcolor: 'var(--hover-bg)', - color: 'var(--accent-red)', }, }} > - - - - + {stopHovered ? 
: } ) : ( - {modelOptions.map((model) => ( + {MODEL_OPTIONS.map((model) => ( handleSelectModel(model)} @@ -567,19 +337,6 @@ export default function ChatInput({ sessionId, initialModelPath, onSend, onStop, }} /> )} - {isPremiumModel(model) && premiumChip && ( - - )} } secondary={model.description} @@ -590,38 +347,6 @@ export default function ChatInput({ sessionId, initialModelPath, onSend, onStop, ))} - - - - setModelSwitchError(null)} - autoHideDuration={6000} - > - setModelSwitchError(null)} - sx={{ fontSize: '0.8rem', maxWidth: 480 }} - > - {modelSwitchError} - - ); diff --git a/frontend/src/components/Chat/ExpiredBanner.tsx b/frontend/src/components/Chat/ExpiredBanner.tsx deleted file mode 100644 index 32f638c245089fc1d26561c8f2706609a5a48345..0000000000000000000000000000000000000000 --- a/frontend/src/components/Chat/ExpiredBanner.tsx +++ /dev/null @@ -1,114 +0,0 @@ -/** - * Shown inline in a chat when the backend no longer recognizes the - * session id (typically: Space was restarted). Lets the user catch the - * agent up with a summary of the prior conversation, or start over. - */ -import { useState, useCallback } from 'react'; -import { Box, Button, CircularProgress, Typography } from '@mui/material'; -import { apiFetch } from '@/utils/api'; -import { useSessionStore } from '@/store/sessionStore'; -import { useAgentStore } from '@/store/agentStore'; -import { loadBackendMessages } from '@/lib/backend-message-store'; -import { loadMessages } from '@/lib/chat-message-store'; -import { uiMessagesToLLMMessages } from '@/lib/convert-llm-messages'; -import { logger } from '@/utils/logger'; - -interface Props { - sessionId: string; -} - -export default function ExpiredBanner({ sessionId }: Props) { - const { renameSession, deleteSession, updateSessionModel } = useSessionStore(); - const [busy, setBusy] = useState<'catch-up' | 'start-over' | null>(null); - const [error, setError] = useState(null); - - const handleCatchUp = useCallback(async () => { - setBusy('catch-up'); - setError(null); - try { - // Prefer the raw backend-message cache; fall back to reconstructing - // from UIMessages (for sessions that predate the backend cache). - let messages = loadBackendMessages(sessionId); - if (!messages || messages.length === 0) { - const uiMsgs = loadMessages(sessionId); - if (uiMsgs.length > 0) messages = uiMessagesToLLMMessages(uiMsgs); - } - if (!messages || messages.length === 0) { - setError('Nothing to summarize from this chat.'); - setBusy(null); - return; - } - - const res = await apiFetch('/api/session/restore-summary', { - method: 'POST', - body: JSON.stringify({ messages }), - }); - if (!res.ok) throw new Error(`restore-summary failed: ${res.status}`); - const data = await res.json(); - const newId = data.session_id as string | undefined; - if (!newId) throw new Error('no session_id in response'); - - useAgentStore.getState().clearSessionState(sessionId); - renameSession(sessionId, newId); - if (data.model) updateSessionModel(newId, data.model); - } catch (e) { - logger.warn('Catch-up failed:', e); - setError("Couldn't catch up β€” try starting over."); - setBusy(null); - } - }, [sessionId, renameSession, updateSessionModel]); - - const handleStartOver = useCallback(() => { - setBusy('start-over'); - useAgentStore.getState().clearSessionState(sessionId); - deleteSession(sessionId); - }, [sessionId, deleteSession]); - - return ( - - - Where were we? - - - Let me skim the conversation so far and pick up right where we left - off β€” or we can start something new. 
diff --git a/frontend/src/components/Chat/MarkdownContent.tsx b/frontend/src/components/Chat/MarkdownContent.tsx
index 0d1e69171d3955e998d78807006862bd95422c34..beb682720bf2b4d846b67a86d45607bc4544044b 100644
--- a/frontend/src/components/Chat/MarkdownContent.tsx
+++ b/frontend/src/components/Chat/MarkdownContent.tsx
@@ -1,4 +1,4 @@
-import { useMemo, useRef, useState, useEffect, type ComponentPropsWithoutRef } from 'react';
+import { useMemo, useRef, useState, useEffect } from 'react';
 import { Box } from '@mui/material';
 import ReactMarkdown from 'react-markdown';
 import remarkGfm from 'remark-gfm';
@@ -70,30 +70,16 @@ const markdownSx: SxProps = {
     width: '100%',
     my: 2,
     fontSize: '0.85rem',
-    display: 'block',
-    overflowX: 'auto',
-    WebkitOverflowScrolling: 'touch',
-  },
-  '& thead': {
-    position: 'sticky',
-    top: 0,
   },
   '& th': {
     borderBottom: '2px solid var(--border-hover)',
-    bgcolor: 'var(--hover-bg)',
     textAlign: 'left',
-    px: 1.5,
-    py: 0.75,
+    p: 1,
     fontWeight: 600,
-    whiteSpace: 'nowrap',
   },
   '& td': {
     borderBottom: '1px solid var(--tool-border)',
-    px: 1.5,
-    py: 0.75,
-  },
-  '& tr:nth-of-type(even) td': {
-    bgcolor: 'color-mix(in srgb, var(--hover-bg) 50%, transparent)',
+    p: 1,
   },
 
   '& hr': {
@@ -166,17 +152,9 @@ export default function MarkdownContent({ content, sx, isStreaming = false }: Ma
   const remarkPlugins = useMemo(() => [remarkGfm], []);
 
-  const components = useMemo(() => ({
-    a: ({ href, children, ...props }: ComponentPropsWithoutRef<'a'>) => (
-
-        {children}
-
-    ),
-  }), []);
-
   return (
-      {displayContent}
+      {displayContent}
   );
 }
diff --git a/frontend/src/components/Chat/MessageBubble.tsx b/frontend/src/components/Chat/MessageBubble.tsx
index ab971205c18a9c60bb23b398e83cf1090dcd5116..ede0f8a7ffff1c18e0ab95487236d17614a499a8 100644
--- a/frontend/src/components/Chat/MessageBubble.tsx
+++ b/frontend/src/components/Chat/MessageBubble.tsx
@@ -9,7 +9,6 @@ interface MessageBubbleProps {
   onEditAndRegenerate?: (messageId: string, newText: string) => void | Promise<void>;
   isProcessing?: boolean;
   isStreaming?: boolean;
-  sessionId?: string | null;
   approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null }>) => Promise<void>;
 }
 
@@ -20,7 +19,6 @@ export default function MessageBubble({
   onEditAndRegenerate,
   isProcessing = false,
   isStreaming = false,
-  sessionId,
   approveTools,
 }: MessageBubbleProps) {
   if (message.role === 'user') {
@@ -40,7 +38,6 @@
-
   );
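Of the `components` override removed from MarkdownContent.tsx, only the `a:` key and its `ComponentPropsWithoutRef<'a'>` typing are visible in the hunk; the anchor's attributes are not. A typical override of this shape opens links in a new tab, so the `target`/`rel` values below are an assumption, not a quote of the removed code:

```tsx
import type { ComponentPropsWithoutRef } from 'react';

const components = {
  // Assumed intent: external links open in a new tab without leaking
  // the opener; the actual attributes did not survive in the diff.
  a: ({ href, children, ...props }: ComponentPropsWithoutRef<'a'>) => (
    <a href={href} target="_blank" rel="noopener noreferrer" {...props}>
      {children}
    </a>
  ),
};
```

Such an object would have been passed as `<ReactMarkdown components={components}>`, which the simplified `return` above no longer does.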
diff --git a/frontend/src/components/Chat/MessageList.tsx b/frontend/src/components/Chat/MessageList.tsx
index 5e3efcaea901bf97970f7644fae162046e3382b2..b50a66626e3e0bd9b50d3936655cd6edc330cfb1 100644
--- a/frontend/src/components/Chat/MessageList.tsx
+++ b/frontend/src/components/Chat/MessageList.tsx
@@ -8,7 +8,6 @@ import type { UIMessage } from 'ai';
 interface MessageListProps {
   messages: UIMessage[];
   isProcessing: boolean;
-  sessionId?: string | null;
   approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null }>) => Promise<void>;
   onUndoLastTurn: () => void | Promise<void>;
   onEditAndRegenerate?: (messageId: string, newText: string) => void | Promise<void>;
@@ -58,7 +57,6 @@ function WelcomeGreeting() {
   );
 }
 
-export default function MessageList({ messages, isProcessing, sessionId, approveTools, onUndoLastTurn, onEditAndRegenerate }: MessageListProps) {
+export default function MessageList({ messages, isProcessing, approveTools, onUndoLastTurn, onEditAndRegenerate }: MessageListProps) {
   const scrollContainerRef = useRef<HTMLDivElement>(null);
   const stickToBottom = useRef(true);
@@ -140,7 +139,6 @@ export default function MessageList({ messages, isProcessing, sessionId, approve
             onEditAndRegenerate={onEditAndRegenerate}
             isProcessing={isProcessing}
             isStreaming={isProcessing && msg.id === lastAssistantId}
-            sessionId={sessionId}
             approveTools={approveTools}
           />
         ))
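MessageList's `scrollContainerRef` and `stickToBottom` refs appear in the hunk above without the scroll logic that uses them. For orientation only, here is one common stick-to-bottom implementation over two such refs; the 40px threshold, hook shape, and wiring are assumptions, not code from this file:

```tsx
import { useEffect, useRef } from 'react';

function useStickToBottom(messages: unknown[]) {
  const scrollContainerRef = useRef<HTMLDivElement>(null);
  const stickToBottom = useRef(true);

  // While the user is within ~40px of the bottom, keep following new
  // messages; once they scroll up, stop auto-scrolling.
  const onScroll = () => {
    const el = scrollContainerRef.current;
    if (!el) return;
    stickToBottom.current =
      el.scrollHeight - el.scrollTop - el.clientHeight < 40;
  };

  useEffect(() => {
    const el = scrollContainerRef.current;
    if (el && stickToBottom.current) el.scrollTop = el.scrollHeight;
  }, [messages]);

  return { scrollContainerRef, onScroll };
}
```

The component would attach `onScroll` to the scroll container and pass `scrollContainerRef` as its `ref`.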
diff --git a/frontend/src/components/Chat/ToolCallGroup.tsx b/frontend/src/components/Chat/ToolCallGroup.tsx
index b85de8b26c867e7b8f117333b367f79549ad9073..47396a0cb13e9cd39222b24d0c4b3e080ad29700 100644
--- a/frontend/src/components/Chat/ToolCallGroup.tsx
+++ b/frontend/src/components/Chat/ToolCallGroup.tsx
@@ -1,5 +1,5 @@
 import { useCallback, useEffect, useMemo, useRef, useState } from 'react';
-import { Alert, Box, Stack, Typography, Chip, Button, TextField, IconButton, Link, CircularProgress } from '@mui/material';
+import { Box, Stack, Typography, Chip, Button, TextField, IconButton, Link, CircularProgress } from '@mui/material';
 import CheckCircleOutlineIcon from '@mui/icons-material/CheckCircleOutline';
 import ErrorOutlineIcon from '@mui/icons-material/ErrorOutline';
 import OpenInNewIcon from '@mui/icons-material/OpenInNew';
@@ -7,10 +7,9 @@ import HourglassEmptyIcon from '@mui/icons-material/HourglassEmpty';
 import LaunchIcon from '@mui/icons-material/Launch';
 import SendIcon from '@mui/icons-material/Send';
 import BlockIcon from '@mui/icons-material/Block';
-import { useAgentStore, type ResearchAgentState } from '@/store/agentStore';
+import { useAgentStore } from '@/store/agentStore';
 import { useLayoutStore } from '@/store/layoutStore';
 import { logger } from '@/utils/logger';
-import { RESEARCH_MAX_STEPS } from '@/lib/research-store';
 import type { UIMessage } from 'ai';
 
 // ---------------------------------------------------------------------------
@@ -36,22 +35,16 @@ interface ToolCallGroupProps {
 // Research sub-steps (inline under the research tool row)
 // ---------------------------------------------------------------------------
 
-/** Hook that forces a re-render every second while enabled — used so each
- * research card can compute its own elapsed seconds synchronously from
- * Date.now() without needing its own timer. */
-function useSecondTick(enabled: boolean): void {
-  const [, setTick] = useState(0);
+/** Hook that ticks every second while startedAt is set, returning elapsed seconds. */
+function useElapsed(startedAt: number | null): number | null {
+  const [elapsed, setElapsed] = useState<number | null>(null);
   useEffect(() => {
-    if (!enabled) return;
-    const id = setInterval(() => setTick(t => t + 1), 1000);
+    if (startedAt === null) { setElapsed(null); return; }
+    setElapsed(Math.round((Date.now() - startedAt) / 1000));
+    const id = setInterval(() => setElapsed(Math.round((Date.now() - startedAt) / 1000)), 1000);
     return () => clearInterval(id);
-  }, [enabled]);
-}
-
-/** Compute elapsed seconds from startedAt (or null). Call under useSecondTick. */
-function computeElapsed(startedAt: number | null): number | null {
-  if (startedAt === null) return null;
-  return Math.round((Date.now() - startedAt) / 1000);
+  }, [startedAt]);
+  return elapsed;
 }
 
 /** Format token count like the CLI: "12.4k" or "800". */
@@ -95,17 +88,9 @@ function parseStepArgs(step: string): Record<string, string> {
   } catch {
     // JSON likely truncated — extract key-value pairs via regex
    const result: Record<string, string> = {};
-    // Match complete "key": "value" pairs
     for (const m of jsonStr.matchAll(/"(\w+)":\s*"([^"]*)"/g)) {
       result[m[1]] = m[2];
     }
-    // Match truncated trailing value: "key": "value... (no closing quote)
-    if (Object.keys(result).length === 0 || !result.query) {
-      const trunc = jsonStr.match(/"(\w+)":\s*"([^"]+)$/);
-      if (trunc && !result[trunc[1]]) {
-        result[trunc[1]] = trunc[2];
-      }
-    }
     return result;
   }
 }
@@ -146,9 +131,6 @@
     search: 'Searching papers',
     paper_details: 'Reading paper details',
     read_paper: 'Reading paper',
-    citation_graph: 'Tracing citations',
-    snippet_search: 'Searching paper snippets',
-    recommend: 'Finding related papers',
     find_datasets: 'Finding paper datasets',
     find_models: 'Finding paper models',
     find_collections: 'Finding paper collections',
@@ -178,9 +160,10 @@
   return { label: step.replace(/^▸\s*/, '') };
 }
 
-/** Rolling display of research sub-tool calls for a single agent. */
-function ResearchSteps({ steps }: { steps: string[] }) {
-  const visible = steps.slice(-RESEARCH_MAX_STEPS);
+/** Rolling 2-line display of research sub-tool calls — hidden when complete. */
+function ResearchSteps({ steps, isRunning }: { steps: string[]; isRunning: boolean }) {
+  if (!isRunning) return null;
+  const visible = steps.slice(-2);
   if (visible.length === 0) return null;
 
   return (
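The regex fallback kept in `parseStepArgs` is easier to see outside the diff. A standalone sketch (function name mine): streamed tool-call JSON can arrive truncated mid-string, in which case `JSON.parse` throws and any complete `"key": "value"` pairs are still recoverable. Note the hunk above drops the extra branch that also salvaged a truncated trailing value.

```ts
function parseTruncatedArgs(jsonStr: string): Record<string, string> {
  try {
    return JSON.parse(jsonStr);
  } catch {
    // JSON likely truncated: salvage every complete "key": "value" pair.
    const result: Record<string, string> = {};
    for (const m of jsonStr.matchAll(/"(\w+)":\s*"([^"]*)"/g)) {
      result[m[1]] = m[2];
    }
    return result;
  }
}

// parseTruncatedArgs('{"query": "llama 3", "limit": "10') returns { query: 'llama 3' }
```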
@@ -220,193 +203,8 @@
   );
 }
 
-// ---------------------------------------------------------------------------
-// Trackio dashboard embed
-// ---------------------------------------------------------------------------
-
-// HF repo IDs are `<owner>/<name>` where each segment is alphanumerics plus
-// `_`, `.`, `-`. Anything else (slashes, spaces, query params, missing owner)
-// would let an attacker-controlled string redirect the embed to a different
-// Space, so we refuse to render rather than build a malformed URL.
-const SPACE_ID_PATTERN = /^[a-zA-Z0-9_.-]+\/[a-zA-Z0-9_.-]+$/;
-
-function isValidSpaceId(spaceId: string): boolean {
-  return SPACE_ID_PATTERN.test(spaceId);
-}
-
-/** HF Space embed subdomain: 'user/space_name' → 'user-space-name'. */
-function spaceIdToSubdomain(spaceId: string): string {
-  return spaceId
-    .toLowerCase()
-    .replace(/[/_.]/g, '-')
-    .replace(/-+/g, '-')
-    .replace(/^-|-$/g, '');
-}
-
-function buildTrackioEmbedUrl(spaceId: string, project?: string): string {
-  // __theme=dark is gradio's standard query param to force the embedded
-  // dashboard into dark mode so it blends with the surrounding chat instead
-  // of flashing a bright white panel inside the dark UI.
-  const params = new URLSearchParams({
-    sidebar: 'hidden',
-    footer: 'false',
-    __theme: 'dark',
-  });
-  if (project) params.set('project', project);
-  return `https://${spaceIdToSubdomain(spaceId)}.hf.space/?${params.toString()}`;
-}
-
-function buildTrackioPageUrl(spaceId: string, project?: string): string {
-  const qs = project ? `?${new URLSearchParams({ project }).toString()}` : '';
-  return `https://huggingface.co/spaces/${spaceId}${qs}`;
-}
-
-function TrackioEmbed({ spaceId, project }: { spaceId: string; project?: string }) {
-  const [expanded, setExpanded] = useState(true);
-  const [iframeLoaded, setIframeLoaded] = useState(false);
-  const embedUrl = useMemo(() => buildTrackioEmbedUrl(spaceId, project), [spaceId, project]);
-  const pageUrl = useMemo(() => buildTrackioPageUrl(spaceId, project), [spaceId, project]);
-  const label = project ? `${spaceId} · ${project}` : spaceId;
-
-  if (!isValidSpaceId(spaceId)) return null;
-
-  return (
-
-
-      onClick={(e) => e.stopPropagation()}
-      sx={{
-        px: 1.25,
-        py: 0.5,
-        borderBottom: expanded ? '1px solid var(--tool-border)' : 'none',
-      }}
-    >
-
-        trackio
-
-
-        {label}
-
-
-        onClick={(e) => e.stopPropagation()}
-        sx={{
-          display: 'inline-flex',
-          alignItems: 'center',
-          gap: 0.4,
-          color: 'var(--accent-yellow)',
-          fontSize: '0.65rem',
-          textDecoration: 'none',
-          '&:hover': { textDecoration: 'underline' },
-        }}
-      >
-
-        Open
-
-
-
-      {expanded && (
-