diff --git a/.gitattributes b/.gitattributes index 5c1fa543a2dcf0e292a5151a6d696f7f59a1556b..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,2 +0,0 @@ -*.png filter=lfs diff=lfs merge=lfs -text -README.md merge=ours diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml deleted file mode 100644 index 5d79742fe97daa25e23740b7904a69439fd38368..0000000000000000000000000000000000000000 --- a/.github/workflows/ci.yml +++ /dev/null @@ -1,63 +0,0 @@ -name: CI - -on: - pull_request: - push: - branches: [main] - -permissions: - contents: read - -concurrency: - group: ci-${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -jobs: - ruff: - name: Ruff - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - - name: Install uv - uses: astral-sh/setup-uv@v5 - with: - enable-cache: true - cache-dependency-glob: uv.lock - - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: "3.12" - - - name: Install dependencies - run: uv sync --locked --extra dev - - - name: Run Ruff - run: uv run ruff check . - - - name: Check formatting - run: uv run ruff format --check . - - tests: - name: Tests - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - - name: Install uv - uses: astral-sh/setup-uv@v5 - with: - enable-cache: true - cache-dependency-glob: uv.lock - - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: "3.12" - - - name: Install dependencies - run: uv sync --locked --extra dev - - - name: Run tests - run: uv run pytest diff --git a/.github/workflows/claude-review.yml b/.github/workflows/claude-review.yml deleted file mode 100644 index 1304cfb9cf5efb059ae02ed071ef2030390802bf..0000000000000000000000000000000000000000 --- a/.github/workflows/claude-review.yml +++ /dev/null @@ -1,78 +0,0 @@ -name: Claude PR Review - -on: - pull_request_target: - types: [opened, synchronize, ready_for_review, reopened] - -permissions: - contents: read - pull-requests: write - issues: read - id-token: write - -concurrency: - group: claude-review-${{ github.event.pull_request.number }} - cancel-in-progress: true - -jobs: - review: - if: github.event.pull_request.draft == false - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - with: - fetch-depth: 0 - # On pull_request_target, keep checkout on the trusted base-repo ref. - # The Claude action can review the PR via GitHub context/API without - # executing untrusted fork code with repository secrets. - persist-credentials: false - - - name: Compose review prompt - id: compose - run: | - { - printf 'prompt<> "$GITHUB_OUTPUT" - - - name: Prepare Claude Code bin directory - run: mkdir -p "$HOME/.local/bin" - - - uses: anthropics/claude-code-action@v1 - with: - anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} - # Bypass the OIDC -> Claude GitHub App token exchange. That exchange - # rejects OIDC tokens minted for pull_request_target events with - # "401 Invalid OIDC token", which broke every review after the switch - # away from pull_request. Using the workflow's GITHUB_TOKEN works for - # both same-repo and fork PRs; comments post as github-actions[bot] - # instead of claude[bot], which is the documented trade-off. - github_token: ${{ secrets.GITHUB_TOKEN }} - track_progress: true - prompt: ${{ steps.compose.outputs.prompt }} diff --git a/.github/workflows/claude.yml b/.github/workflows/claude.yml deleted file mode 100644 index d3036a23259e41c48a7efe2aa72a2d7c3c77bebf..0000000000000000000000000000000000000000 --- a/.github/workflows/claude.yml +++ /dev/null @@ -1,35 +0,0 @@ -name: Claude on Mention - -on: - issue_comment: - types: [created] - pull_request_review_comment: - types: [created] - pull_request_review: - types: [submitted] - issues: - types: [opened, assigned] - -permissions: - contents: write - pull-requests: write - issues: write - id-token: write - -jobs: - claude: - if: | - (github.event_name == 'issue_comment' && contains(github.event.comment.body, '@claude')) || - (github.event_name == 'pull_request_review_comment' && contains(github.event.comment.body, '@claude')) || - (github.event_name == 'pull_request_review' && contains(github.event.review.body, '@claude')) || - (github.event_name == 'issues' && (contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude'))) - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - uses: anthropics/claude-code-action@v1 - with: - anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} - track_progress: true diff --git a/.gitignore b/.gitignore index c10ab3552f66fe78ba9248272546e98a050789be..71fc3082173c89ba24599c3429094701bd0987c2 100644 --- a/.gitignore +++ b/.gitignore @@ -52,11 +52,7 @@ frontend/yarn-error.log* # Docker .docker/ -# Eval (stale) -eval/ - # Project-specific -scratch/ session_logs/ /logs hf-agent-leaderboard/ diff --git a/AGENTS.md b/AGENTS.md deleted file mode 100644 index 03f3bd9d98fee110961befde5d4ac148421f4b28..0000000000000000000000000000000000000000 --- a/AGENTS.md +++ /dev/null @@ -1,47 +0,0 @@ -# Agent Notes - -## Local Dev Servers - -- Frontend: from `frontend/`, run `npm ci` if dependencies are missing, then `npm run dev`. -- Backend: from `backend/`, run `uv run uvicorn main:app --host ::1 --port 7860`. -- Frontend URL: http://localhost:5173/ -- Backend health check: `curl -g http://[::1]:7860/api` -- Frontend proxy health check: `curl http://localhost:5173/api` - -Notes: - -- Vite proxies `/api` and `/auth` to `http://localhost:7860`. -- If `127.0.0.1:7860` is already owned by another local process, binding the backend to `::1` lets the Vite proxy resolve `localhost` cleanly. -- Prefer `npm ci` over `npm install` for setup, since `npm install` may rewrite `frontend/package-lock.json` metadata depending on npm version. -- Production defaults to the Bedrock Claude model. For local development with a personal Anthropic key, set `ANTHROPIC_API_KEY` and `ML_INTERN_CLAUDE_MODEL_ID=anthropic/claude-opus-4-6` before starting the backend. Other models are selected through the app's model switcher. - -## Development Checks - -- Before every commit, run `uv run ruff check .` and `uv run ruff format --check .`. -- If formatting fails, run `uv run ruff format .`, then re-run the Ruff checks before committing. - -## GitHub CLI - -- For multiline PR descriptions, prefer `gh pr edit --body-file ` over inline `--body` so shell quoting, `$` env-var names, backticks, and newlines are preserved correctly. - -## GitHub PRs - -- Open code changes as GitHub PRs first. Do not push code changes directly to the Hugging Face Space deployment branch or Space remote before the PR has been opened, reviewed, and merged, unless the user explicitly asks to bypass the PR flow. - -## Hugging Face Space Deploys - -- The Space remote is `space` and points to `https://huggingface.co/spaces/smolagents/ml-intern`. -- Deploy GitHub `main` to the Space from the local `space-main` branch by merging `origin/main` into `space-main` with a single merge commit, then pushing `space-main:main` to the `space` remote. -- Keep the Space-only README frontmatter on `space-main`; `.gitattributes` should contain `README.md merge=ours` and the local repo config should include `merge.ours.driver=true`. -- Local dev commonly uses a personal `HF_TOKEN`, but the deployed Space uses HF OAuth tokens. When adding Hub features, make sure the Space README `hf_oauth_scopes` frontmatter and the backend OAuth request in `backend/routes/auth.py` include the scopes required by the Hub APIs being called. A feature can work locally with a broad PAT and still fail in production with 403s if OAuth scopes are missing; after changing scopes, users may need to log out and log in again to receive a fresh token. -- Recommended deploy flow: - -```bash -git pull --ff-only origin main -git switch space-main -git config merge.ours.driver true -git merge --no-ff origin/main -m "Deploy $(date +%Y-%m-%d)" \ - -m "Co-authored-by: OpenAI Codex " -git push space space-main:main -git switch main -``` diff --git a/Dockerfile b/Dockerfile index 264dd3d9f97d6d2353e96611a6af7c90680f17b1..c4a876d8366ad4495f7da922f827bdaf1b9594fc 100644 --- a/Dockerfile +++ b/Dockerfile @@ -28,7 +28,7 @@ COPY pyproject.toml uv.lock ./ # Install dependencies into /app/.venv # Use --frozen to ensure exact versions from uv.lock -RUN uv sync --no-dev --frozen +RUN uv sync --extra agent --no-dev --frozen # Copy application code COPY agent/ ./agent/ @@ -56,4 +56,4 @@ EXPOSE 7860 # Run the application from backend directory WORKDIR /app/backend -CMD ["bash", "start.sh"] +CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"] diff --git a/LICENSE b/LICENSE deleted file mode 100644 index 261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64..0000000000000000000000000000000000000000 --- a/LICENSE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/README.md b/README.md index 26b65da33edb1b3d882cdf9a6896a03030213195..fed2a689ba76144abbf89ce49d0b74a973325062 100644 --- a/README.md +++ b/README.md @@ -1,164 +1,57 @@ --- -title: ML Intern +title: HF Agent emoji: πŸ€– -colorFrom: yellow -colorTo: blue +colorFrom: blue +colorTo: purple sdk: docker app_port: 7860 hf_oauth: true -hf_oauth_expiration_minutes: 43200 hf_oauth_scopes: - read-repos - write-repos - contribute-repos - manage-repos - - write-collections - inference-api - jobs - write-discussions --- -

- smolagents logo -

+# HF Agent -# ML Intern +An MLE agent CLI with MCP (Model Context Protocol) integration and built-in tool support. -An ML intern that autonomously researches, writes, and ships good quality ML related code using the Hugging Face ecosystem β€” with deep access to docs, papers, datasets, and cloud compute. ## Quick Start ### Installation ```bash -git clone git@github.com:huggingface/ml-intern.git -cd ml-intern -uv sync -uv tool install -e . +# Clone the repository +git clone git@github.com:huggingface/hf_agent.git +cd hf_agent ``` -#### That's it. Now `ml-intern` works from any directory: - -```bash -ml-intern -``` - -Create a `.env` file in the project root (or export these in your shell): - -```bash -ANTHROPIC_API_KEY= # if using anthropic models -OPENAI_API_KEY= # if using openai models -HF_TOKEN= -GITHUB_TOKEN= -``` -If no `HF_TOKEN` is set, the CLI will prompt you to paste one on first launch. To get a GITHUB_TOKEN follow the tutorial [here](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens#creating-a-fine-grained-personal-access-token). - -### Usage - -**Interactive mode** (start a chat session): - +#### Install recommended dependencies ```bash -ml-intern +uv sync --extra agent # or uv sync --extra all ``` -**Headless mode** (single prompt, auto-approve): +### Interactive CLI ```bash -ml-intern "fine-tune llama on my dataset" -``` - -**Options:** - -```bash -ml-intern --model anthropic/claude-opus-4-6 "your prompt" -ml-intern --model openai/gpt-5.5 "your prompt" -ml-intern --max-iterations 100 "your prompt" -ml-intern --no-stream "your prompt" -``` - -## Sharing Traces - -Every session is auto-uploaded to your **own private Hugging Face dataset** -in [Claude Code JSONL format](https://huggingface.co/changelog/agent-trace-viewer), -which the HF Agent Trace Viewer auto-detects so you can browse turns, tool -calls, and model responses directly on the Hub. - -By default the dataset is named `{your-hf-username}/ml-intern-sessions` and is -**created private**. You can flip it to public from inside the CLI: - -```bash -/share-traces # show current visibility + dataset URL -/share-traces public # publish (anyone can view) -/share-traces private # lock it back down -``` - -You can also flip visibility from the dataset page on huggingface.co β€” the -agent honours whatever you set there for subsequent uploads. - -To opt out entirely, set in your CLI config (e.g. `configs/cli_agent_config.json` -or `~/.config/ml-intern/cli_agent_config.json`): - -```json -{ "share_traces": false } -``` - -To override the destination repo, set: - -```json -{ "personal_trace_repo_template": "{hf_user}/my-custom-traces" } +uv run python -m agent.main ``` +This starts an interactive chat session with the agent. Type your messages and the agent will respond, using tools as needed. -The shared `smolagents/ml-intern-sessions` dataset is unrelated and only -receives anonymized telemetry rows used by the backend KPI scheduler. +The agent will automatically discover and register all tools from configured MCP servers. -## Supported Gateways - -ML Intern currently supports one-way notification gateways from CLI sessions. -These gateways send out-of-band status updates; they do not accept inbound chat -messages. - -### Slack - -Slack notifications use the Slack Web API to post messages when the agent needs -approval, hits an error, or completes a turn. Create a Slack app with a bot token -that has `chat:write`, invite the bot to the target channel, then set: +### Env Setup ```bash -SLACK_BOT_TOKEN=xoxb-... -SLACK_CHANNEL_ID=C... -``` - -The CLI automatically creates a `slack.default` destination when both variables -are present. Optional environment variables for the env-only default: - -```bash -ML_INTERN_SLACK_NOTIFICATIONS=false -ML_INTERN_SLACK_DESTINATION=slack.ops -ML_INTERN_SLACK_AUTO_EVENTS=approval_required,error,turn_complete -ML_INTERN_SLACK_ALLOW_AGENT_TOOL=true -ML_INTERN_SLACK_ALLOW_AUTO_EVENTS=true -``` - -For a persistent user-level config, put overrides in -`~/.config/ml-intern/cli_agent_config.json` or point `ML_INTERN_CLI_CONFIG` at a -JSON file: - -```json -{ - "messaging": { - "enabled": true, - "auto_event_types": ["approval_required", "error", "turn_complete"], - "destinations": { - "slack.ops": { - "provider": "slack", - "token": "${SLACK_BOT_TOKEN}", - "channel": "${SLACK_CHANNEL_ID}", - "allow_agent_tool": true, - "allow_auto_events": true - } - } - } -} +ANTHROPIC_API_KEY= +HF_TOKEN= +GITHUB_TOKEN= +HF_NAMESPACE= ``` ## Architecture @@ -167,70 +60,62 @@ JSON file: ``` β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” -β”‚ User/CLI β”‚ -β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ - β”‚ Operations β”‚ Events - ↓ (user_input, exec_approval, ↑ - submission_queue interrupt, compact, ...) event_queue - β”‚ β”‚ - ↓ β”‚ -β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ -β”‚ submission_loop (agent_loop.py) β”‚ β”‚ -β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ -β”‚ β”‚ 1. Receive Operation from queue β”‚ β”‚ β”‚ -β”‚ β”‚ 2. Route to handler (run_agent/compact/...) β”‚ β”‚ β”‚ -β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ -β”‚ ↓ β”‚ β”‚ -β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ -β”‚ β”‚ Handlers.run_agent() β”‚ β”œβ”€β”€β”€ -β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ Agentic Loop (max 300 iterations) β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ Session β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ ContextManager β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ β€’ Message history β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ (litellm.Message[]) β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ β€’ Auto-compaction (170k) β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ β€’ Session upload to HF β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ ToolRouter β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€ HF docs & research β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€ HF repos, datasets, β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ jobs, papers β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€ GitHub code search β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€ Sandbox & local tools β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€ Planning β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ └─ MCP server tools β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ Doom Loop Detector β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β€’ Detects repeated tool patterns β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β€’ Injects corrective prompts β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ Loop: β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ 1. LLM call (litellm.acompletion) β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ ↓ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ 2. Parse tool_calls[] β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ ↓ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ 3. Approval check β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ (jobs, sandbox, destructive ops) β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ ↓ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ 4. Execute via ToolRouter β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ ↓ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ 5. Add results to ContextManager β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ ↓ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ 6. Repeat if tool_calls exist β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ β”‚ -β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ -β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”˜ +β”‚ User/CLI β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + β”‚ User request β”‚ Events + ↓ ↑ + submission_queue event_queue + β”‚ β”‚ + ↓ β”‚ +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ +β”‚ submission_loop (agent_loop.py) β”‚ β”‚ +β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ +β”‚ β”‚ 1. Receive Operation from queue β”‚ β”‚ β”‚ +β”‚ β”‚ 2. Route to Handler (run_agent/compact/...) β”‚ β”‚ β”‚ +β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ +β”‚ ↓ β”‚ β”‚ +β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ +β”‚ β”‚ Handlers.run_agent() β”‚ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ +β”‚ β”‚ β”‚ β”‚ Emit β”‚ +β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ Events β”‚ +β”‚ β”‚ β”‚ Agentic Loop (max 10 iterations) β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ Session β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ ContextManager β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ β€’ Message history β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ (litellm.Message[]) β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ β€’ Auto-compaction (180k) β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ ToolRouter β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€ explore_hf_docs β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€ fetch_hf_docs β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€ find_hf_api β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€ plan_tool β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€ hf_jobs* β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€ hf_private_repos* β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€ github_* (3 tools) β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ └─ MCP tools (e.g., β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ model_search, etc.) β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ Loop: β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ 1. LLM call (litellm.acompletion) β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ ↓ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ 2. Parse tool_calls[] β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ ↓ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ 3. Execute via ToolRouter β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ ↓ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ 4. Add results to ContextManager β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ ↓ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ 5. Repeat if tool_calls exist β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ β”‚ +β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ ``` ### Agentic Loop Flow @@ -240,49 +125,61 @@ User Message ↓ [Add to ContextManager] ↓ - ╔═══════════════════════════════════════════╗ - β•‘ Iteration Loop (max 300) β•‘ - β•‘ β•‘ - β•‘ Get messages + tool specs β•‘ - β•‘ ↓ β•‘ - β•‘ litellm.acompletion() β•‘ - β•‘ ↓ β•‘ - β•‘ Has tool_calls? ──No──> Done β•‘ - β•‘ β”‚ β•‘ - β•‘ Yes β•‘ - β•‘ ↓ β•‘ - β•‘ Add assistant msg (with tool_calls) β•‘ - β•‘ ↓ β•‘ - β•‘ Doom loop check β•‘ - β•‘ ↓ β•‘ - β•‘ For each tool_call: β•‘ - β•‘ β€’ Needs approval? ──Yes──> Wait for β•‘ - β•‘ β”‚ user confirm β•‘ - β•‘ No β•‘ - β•‘ ↓ β•‘ - β•‘ β€’ ToolRouter.execute_tool() β•‘ - β•‘ β€’ Add result to ContextManager β•‘ - β•‘ ↓ β•‘ - β•‘ Continue loop ─────────────────┐ β•‘ - β•‘ ↑ β”‚ β•‘ - β•‘ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β•‘ - β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β• + ╔═══════════════════════════════════════╗ + β•‘ Iteration Loop (max 10) β•‘ + β•‘ β•‘ + β•‘ Get messages + tool specs β•‘ + β•‘ ↓ β•‘ + β•‘ litellm.acompletion() β•‘ + β•‘ ↓ β•‘ + β•‘ Has tool_calls? ──No──> Done β•‘ + β•‘ β”‚ β•‘ + β•‘ Yes β•‘ + β•‘ ↓ β•‘ + β•‘ Add assistant msg (with tool_calls) β•‘ + β•‘ ↓ β•‘ + β•‘ For each tool_call: β•‘ + β•‘ β€’ ToolRouter.execute_tool() β•‘ + β•‘ β€’ Add result to ContextManager β•‘ + β•‘ ↓ β•‘ + β•‘ Continue loop ─────────────────┐ β•‘ + β•‘ ↑ β”‚ β•‘ + β•šβ•β•β•β•β•β•β•β•β•β•§β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•§β•β•β•β•β•β• +``` + +## Project Structure + +``` +agent/ +β”œβ”€β”€ config.py # Configuration models +β”œβ”€β”€ main.py # Interactive CLI entry point +β”œβ”€β”€ prompts/ +β”‚ └── system_prompt.yaml # Agent behavior and personality +β”œβ”€β”€ context_manager/ +β”‚ └── manager.py # Message history & auto-compaction +└── core/ + β”œβ”€β”€ agent_loop.py # Main agent loop and handlers + β”œβ”€β”€ session.py # Session management + β”œβ”€β”€ mcp_client.py # MCP SDK integration + └── tools.py # ToolRouter and built-in tools + +configs/ +└── main_agent_config.json # Model and MCP server configuration + +tests/ # Integration and unit tests +eval/ # Evaluation suite (see eval/README.md) ``` + ## Events The agent emits the following events via `event_queue`: - `processing` - Starting to process user input -- `ready` - Agent is ready for input -- `assistant_chunk` - Streaming token chunk -- `assistant_message` - Complete LLM response text -- `assistant_stream_end` - Token stream finished +- `assistant_message` - LLM response text - `tool_call` - Tool being called with arguments - `tool_output` - Tool execution result -- `tool_log` - Informational tool log message -- `tool_state_change` - Tool execution state transition -- `approval_required` - Requesting user approval for sensitive operations +- `approval_request` - Requesting user approval for sensitive operations - `turn_complete` - Agent finished processing - `error` - Error occurred during processing - `interrupted` - Agent was interrupted @@ -317,8 +214,7 @@ def create_builtin_tools() -> list[ToolSpec]: ### Adding MCP Servers -Edit `configs/cli_agent_config.json` for CLI defaults, or -`configs/frontend_agent_config.json` for web-session defaults: +Edit `configs/main_agent_config.json`: ```json { diff --git a/REVIEW.md b/REVIEW.md deleted file mode 100644 index 3f08c60a8a43022497a58e3d547433bd5ecb5c24..0000000000000000000000000000000000000000 --- a/REVIEW.md +++ /dev/null @@ -1,135 +0,0 @@ -# Review instructions - -These rules override the default review guidance. Treat them as the highest-priority -instruction block for any review of this repo. If something here contradicts a more -generic review habit, follow these. - -## Severity levels - -Every finding carries one of three priority labels: - -- **P0** β€” blocks merge. -- **P1** β€” worth fixing, not blocking. -- **P2** β€” informational. - -Write labels as plain text (`P0`, `P1`, `P2`) in finding headers. Do not use -emoji or colored markers. Use judgment on what belongs at which level β€” this -repo does not enumerate P0 cases; read the code and decide. - -## Default bias: rigor - -Reviews gate merges. This is an open-source repo that takes PRs from anyone; the -maintainer team is small and relies on the review to catch what they don't have -time to verify themselves. **Default bias is rigor, not speed.** When in doubt -on a P0-class concern, investigate further before deciding whether to flag β€” a -false negative ships a bug to production, a false positive costs the contributor -one round trip. - -Rigor is not nitpicking. The P1 cap, "do not report" skip list, and verification -bar all still apply. Rigor means going deep on a small number of real concerns, -not surfacing a large number of shallow ones. Prefer one well-investigated P0 -over three speculative P1s. - -**Hold the line on P0.** If the author pushes back on a P0 finding without a fix -that actually addresses the root cause, re-state the concern with added -citations. Only accept the pushback if the author points to code or behavior you -missed. Do not soften a P0 because the contributor is polite or new to the repo. - -For P1 and P2: if the author defers or pushes back without fixing, accept it -silently β€” do not re-flag on subsequent commits. P1/P2 are informational; the -author may defer to a follow-up issue at their discretion. - -If Claude and the author repeatedly disagree on the same class of finding, the -signal is that REVIEW.md is missing a rule; note it once in the PR summary as -`suggest-rule: ` and stop. - -## Investigate before posting - -The depth of your analysis determines the strength of your finding. For any -P0-class concern, before writing it up: - -- Read the relevant callers and callees, not just the diff. Use Read and Grep - to open files the diff doesn't touch but the changed code interacts with. -- Trace the full chain end-to-end for routing, auth, and agent-loop findings. - Cite each hop by `file:line`, not just the suspicious line. -- Check whether the codebase already has an established pattern for this kind - of change (`grep` for similar call sites, similar tool definitions, similar - route guards). If the PR introduces a new approach where an established - pattern exists, flag that β€” divergence from the existing pattern is usually a - regression vector even when the new code "works." -- Confirm the specific behavior you're claiming. "This breaks X" must be - grounded in either the code handling X or a test exercising X, not in - inference from naming or structure. - -A finding you "spotted" by scanning the diff is more likely to be a false -positive than a finding you verified by reading the code around it. - -## P1 cap - -Report at most **3** P1 findings per review. If you found more, say "plus N -similar items" in the summary. If everything you found is P1 or below, open the -summary with "No blocking issues." - -## Re-review convergence - -If this PR has already received a Claude review (there is a prior review comment -by the `claude` bot), suppress new P1 findings and post only P0 ones. Do not -re-post P1s that were already flagged on earlier commits. If the author pushed a -fix for a previously flagged issue, acknowledge it in one line rather than -re-flagging. - -## Do not report - -Anything in these paths β€” skip entirely: - -- `frontend/node_modules/**`, `**/*.lock`, `uv.lock`, `package-lock.json` -- `hf_agent.egg-info/**`, `.ruff_cache/**`, `.pytest_cache/**`, `.venv/**` -- `session_logs/**`, `reports/**` -- Anything under a `gen/` or `generated/` path - -Anything speculative β€” do not post: - -- "This might be slow" without a concrete complexity claim tied to a specific - input size -- Hypothetical race conditions without a concrete interleaving - -## Dependency PRs - -For PRs whose diff is only a lockfile bump, a `pyproject.toml` change, or a -new dependency, the code rules above don't apply β€” risks shift to provenance -and framing. Every claim in the title or body (CVE IDs, version numbers, -behavior fixes) must match what the diff actually does, and any new -transitive dep needs justification. A PR that lies in its framing is P0 -regardless of whether the code change is safe in isolation. - -## Verification bar - -Every behavior claim in a finding must cite `file:line`. "This breaks X" is not -actionable without a line reference. If you cannot cite a line, do not post -the finding. - -## Summary shape - -Open the review body with a single-line tally and an explicit merge verdict, on -two lines: - -``` -2 P0, 3 P1 -Verdict: changes requested -``` - -Valid verdicts: - -- **Verdict: ready to merge** β€” no P0 findings, contributor can merge as-is - once any CI passes -- **Verdict: changes requested** β€” at least one P0 that must be addressed - before merging -- **Verdict: needs discussion** β€” a design-level concern the maintainer should - weigh in on before the contributor iterates (use sparingly) - -If it's a clean review, write `LGTM` followed by `Verdict: ready to merge`. - -Then a **What I checked** bullet list β€” one line per major area you examined, -regardless of whether you found anything. This gives the maintainer visible -coverage at a glance and lets them decide whether to spot-check areas you -didn't touch. diff --git a/agent/__init__.py b/agent/__init__.py index 2e301c8d7b97df90efb932a3685a5c401326232e..3528882f8728ddce586748e8256755fe5b2ea6ad 100644 --- a/agent/__init__.py +++ b/agent/__init__.py @@ -2,20 +2,6 @@ HF Agent - Main agent module """ -import litellm - -# Global LiteLLM behavior β€” set once at package import so both CLI and -# backend entries share the same config. -# drop_params: quietly drop unsupported params rather than raising -# suppress_debug_info: hide the noisy "Give Feedback" banner on errors -# modify_params: let LiteLLM patch Anthropic's tool-call requirements -# (synthesize a dummy tool spec when we call completion on a history -# that contains tool_calls but aren't passing `tools=` β€” happens -# during summarization / session seeding). -litellm.drop_params = True -litellm.suppress_debug_info = True -litellm.modify_params = True - -from agent.core.agent_loop import submission_loop # noqa: E402 +from agent.core.agent_loop import submission_loop __all__ = ["submission_loop"] diff --git a/agent/config.py b/agent/config.py index 35b095c328fe64b53eb51ef5126ebec7e6f546e4..f2582b3f760e61ae97d7ec250dcbc465ece40d98 100644 --- a/agent/config.py +++ b/agent/config.py @@ -1,7 +1,6 @@ import json import os import re -from pathlib import Path from typing import Any, Union from dotenv import load_dotenv @@ -11,14 +10,9 @@ from fastmcp.mcp_config import ( ) from pydantic import BaseModel -from agent.messaging.models import MessagingConfig - # These two are the canonical server config types for MCP servers. MCPServerConfig = Union[StdioMCPServer, RemoteMCPServer] -# Project root: two levels up from this file (agent/config.py -> project root) -_PROJECT_ROOT = Path(__file__).resolve().parent.parent - class Config(BaseModel): """Configuration manager""" @@ -26,139 +20,14 @@ class Config(BaseModel): model_name: str mcpServers: dict[str, MCPServerConfig] = {} save_sessions: bool = True - session_dataset_repo: str = "smolagents/ml-intern-sessions" - # Per-user private dataset that mirrors each session in Claude Code JSONL - # format so the HF Agent Trace Viewer auto-renders it - # (https://huggingface.co/changelog/agent-trace-viewer). Created private - # on first use; user flips it public via /share-traces. ``{hf_user}`` is - # substituted at upload time from the authenticated HF username. - share_traces: bool = True - personal_trace_repo_template: str = "{hf_user}/ml-intern-sessions" - auto_save_interval: int = 1 # Save every N user turns (0 = disabled) - # Mid-turn heartbeat: save + upload every N seconds while events are being - # emitted. Guards against losing trace data on long-running turns that - # crash before turn_complete (e.g. a multi-hour hf_jobs wait that OOMs). - # 0 = disabled. Consumed by agent.core.telemetry.HeartbeatSaver. - heartbeat_interval_s: int = 60 + session_dataset_repo: str = "akseljoonas/hf-agent-sessions" + auto_save_interval: int = 3 # Save every N user turns (0 = disabled) yolo_mode: bool = False # Auto-approve all tool calls without confirmation - max_iterations: int = 300 # Max LLM calls per agent turn (-1 = unlimited) # Permission control parameters confirm_cpu_jobs: bool = True auto_file_upload: bool = False - # Reasoning effort *preference* β€” the ceiling the user wants. The probe - # on `/model` walks a cascade down from here (``max`` β†’ ``xhigh`` β†’ ``high`` - # β†’ …) and caches per-model what the provider actually accepted in - # ``Session.model_effective_effort``. Default ``max`` because we'd rather - # burn tokens thinking than ship a wrong ML recipe; the cascade lands on - # whichever level the model supports (``high`` for GPT-5 / HF router, - # ``xhigh`` or ``max`` for Anthropic 4.6 / 4.7). ``None`` = thinking off. - # Valid values: None | "minimal" | "low" | "medium" | "high" | "xhigh" | "max" - reasoning_effort: str | None = "max" - messaging: MessagingConfig = MessagingConfig() - - -USER_CONFIG_ENV_VAR = "ML_INTERN_CLI_CONFIG" -DEFAULT_USER_CONFIG_PATH = ( - Path.home() / ".config" / "ml-intern" / "cli_agent_config.json" -) -SLACK_DEFAULT_DESTINATION = "slack.default" -SLACK_DEFAULT_AUTO_EVENT_TYPES = ["approval_required", "error", "turn_complete"] - - -def _deep_merge_config( - base: dict[str, Any], override: dict[str, Any] -) -> dict[str, Any]: - merged = dict(base) - for key, value in override.items(): - current = merged.get(key) - if isinstance(current, dict) and isinstance(value, dict): - merged[key] = _deep_merge_config(current, value) - else: - merged[key] = value - return merged - - -def _load_json_config(path: Path) -> dict[str, Any]: - with open(path, "r", encoding="utf-8") as f: - data = json.load(f) - if not isinstance(data, dict): - raise ValueError(f"Config file {path} must contain a JSON object") - return data - - -def _load_user_config() -> dict[str, Any]: - raw_path = os.environ.get(USER_CONFIG_ENV_VAR) - if raw_path: - path = Path(raw_path).expanduser() - if not path.exists(): - raise FileNotFoundError( - f"{USER_CONFIG_ENV_VAR} points to missing config file: {path}" - ) - return _load_json_config(path) - - if DEFAULT_USER_CONFIG_PATH.exists(): - return _load_json_config(DEFAULT_USER_CONFIG_PATH) - return {} - - -def _env_bool(name: str, default: bool) -> bool: - value = os.environ.get(name) - if value is None: - return default - normalized = value.strip().lower() - if normalized in {"1", "true", "yes", "on"}: - return True - if normalized in {"0", "false", "no", "off"}: - return False - return default - - -def _env_list(name: str) -> list[str] | None: - value = os.environ.get(name) - if value is None: - return None - return [item.strip() for item in value.split(",") if item.strip()] - - -def apply_slack_user_defaults(raw_config: dict[str, Any]) -> dict[str, Any]: - """Enable a default Slack destination from user env vars, when present.""" - if not _env_bool("ML_INTERN_SLACK_NOTIFICATIONS", True): - return raw_config - - token = os.environ.get("SLACK_BOT_TOKEN") - channel = os.environ.get("SLACK_CHANNEL_ID") or os.environ.get("SLACK_CHANNEL") - if not token or not channel: - return raw_config - - config = dict(raw_config) - messaging = dict(config.get("messaging") or {}) - destinations = dict(messaging.get("destinations") or {}) - destination_name = ( - os.environ.get("ML_INTERN_SLACK_DESTINATION") or SLACK_DEFAULT_DESTINATION - ).strip() - - if destination_name not in destinations: - destinations[destination_name] = { - "provider": "slack", - "token": token, - "channel": channel, - "allow_agent_tool": _env_bool("ML_INTERN_SLACK_ALLOW_AGENT_TOOL", True), - "allow_auto_events": _env_bool("ML_INTERN_SLACK_ALLOW_AUTO_EVENTS", True), - } - - auto_events = _env_list("ML_INTERN_SLACK_AUTO_EVENTS") - if auto_events is not None: - messaging["auto_event_types"] = auto_events - elif "auto_event_types" not in messaging: - messaging["auto_event_types"] = SLACK_DEFAULT_AUTO_EVENT_TYPES - - messaging["enabled"] = True - messaging["destinations"] = destinations - config["messaging"] = messaging - return config - def substitute_env_vars(obj: Any) -> Any: """ @@ -197,25 +66,18 @@ def substitute_env_vars(obj: Any) -> Any: return obj -def load_config( - config_path: str = "config.json", - include_user_defaults: bool = False, -) -> Config: +def load_config(config_path: str = "config.json") -> Config: """ Load configuration with environment variable substitution. Use ${VAR_NAME} in your JSON for any secret. Automatically loads from .env file. """ - # Load .env from project root first (so it works from any directory), - # then CWD .env can override if present - load_dotenv(_PROJECT_ROOT / ".env") - load_dotenv(override=False) - - raw_config = _load_json_config(Path(config_path)) - if include_user_defaults: - raw_config = _deep_merge_config(raw_config, _load_user_config()) - raw_config = apply_slack_user_defaults(raw_config) + # Load environment variables from .env file + load_dotenv() + + with open(config_path, "r") as f: + raw_config = json.load(f) config_with_env = substitute_env_vars(raw_config) return Config.model_validate(config_with_env) diff --git a/agent/context_manager/manager.py b/agent/context_manager/manager.py index 85e96af0f6f3fa6d0426acddcd308281d502558b..1f74edc025aa864de75534efba6b863f44678842 100644 --- a/agent/context_manager/manager.py +++ b/agent/context_manager/manager.py @@ -3,7 +3,7 @@ Context management for conversation history """ import logging -import time +import os import zoneinfo from datetime import datetime from pathlib import Path @@ -13,16 +13,17 @@ import yaml from jinja2 import Template from litellm import Message, acompletion -from agent.core.prompt_caching import with_prompt_caching - logger = logging.getLogger(__name__) +# Module-level cache for HF username β€” avoids repeating the slow whoami() call +_hf_username_cache: str | None = None + _HF_WHOAMI_URL = "https://huggingface.co/api/whoami-v2" _HF_WHOAMI_TIMEOUT = 5 # seconds -def _get_hf_username(hf_token: str | None = None) -> str: - """Return the HF username for the given token. +def _get_hf_username() -> str: + """Return the HF username, cached after the first call. Uses subprocess + curl to avoid Python HTTP client IPv6 issues that cause 40+ second hangs (httpx/urllib try IPv6 first which times out @@ -32,9 +33,15 @@ def _get_hf_username(hf_token: str | None = None) -> str: import subprocess import time as _t + global _hf_username_cache + if _hf_username_cache is not None: + return _hf_username_cache + + hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN") if not hf_token: - logger.warning("No hf_token provided, using 'unknown' as username") - return "unknown" + logger.warning("No HF_TOKEN set, using 'unknown' as username") + _hf_username_cache = "unknown" + return _hf_username_cache t0 = _t.monotonic() try: @@ -56,119 +63,21 @@ def _get_hf_username(hf_token: str | None = None) -> str: t1 = _t.monotonic() if result.returncode == 0 and result.stdout: data = json.loads(result.stdout) - username = data.get("name", "unknown") - logger.info(f"HF username resolved to '{username}' in {t1 - t0:.2f}s") - return username + _hf_username_cache = data.get("name", "unknown") + logger.info( + f"HF username resolved to '{_hf_username_cache}' in {t1 - t0:.2f}s" + ) else: logger.warning( f"curl whoami failed (rc={result.returncode}) in {t1 - t0:.2f}s" ) - return "unknown" + _hf_username_cache = "unknown" except Exception as e: t1 = _t.monotonic() logger.warning(f"HF whoami failed in {t1 - t0:.2f}s: {e}") - return "unknown" - - -_COMPACT_PROMPT = ( - "Please provide a concise summary of the conversation above, focusing on " - "key decisions, the 'why' behind the decisions, problems solved, and " - "important context needed for developing further. Your summary will be " - "given to someone who has never worked on this project before and they " - "will be have to be filled in." -) - -# Per-message ceiling. If a single message in the "untouched" tail is larger -# than this, compaction can't recover even after summarizing the middle β€” -# producing the infinite compaction loop seen 2026-05-03 in pod logs (200k -# context shrinks to 200k+ because one tool output is 80k tokens). We replace -# such messages with a placeholder before compaction runs. -_MAX_TOKENS_PER_MESSAGE = 50_000 - - -class CompactionFailedError(Exception): - """Raised when compaction can't reduce context below the threshold. - - Typically means an individual preserved message (system, first user, or - untouched tail) exceeds what truncation can fix in one pass. The caller - must terminate the session β€” retrying produces an infinite loop that - burns Bedrock budget for free (~$3 per re-attempt on Opus). - """ - - -# Used when seeding a brand-new session from prior browser-cached messages. -# Here we're writing a note to *ourselves* β€” so preserve the tool-call trail, -# files produced, and planned next steps in first person. Optimized for -# continuity, not brevity. -_RESTORE_PROMPT = ( - "You're about to be restored into a fresh session with no memory of the " - "conversation above. Write a first-person note to your future self so " - "you can continue right where you left off. Include:\n" - " β€’ What the user originally asked for and what progress you've made.\n" - " β€’ Every tool you called, with arguments and a one-line result summary.\n" - " β€’ Any code, files, scripts, or artifacts you produced (with paths).\n" - " β€’ Key decisions and the reasoning behind them.\n" - " β€’ What you were planning to do next.\n\n" - "Don't be cute. Be specific. This is the only context you'll have." -) - - -async def summarize_messages( - messages: list[Message], - model_name: str, - hf_token: str | None = None, - max_tokens: int = 2000, - tool_specs: list[dict] | None = None, - prompt: str = _COMPACT_PROMPT, - session: Any = None, - kind: str = "compaction", -) -> tuple[str, int]: - """Run a summarization prompt against a list of messages. - - ``prompt`` defaults to the compaction prompt (terse, decision-focused). - Callers seeding a new session after a restart should pass ``_RESTORE_PROMPT`` - instead β€” it preserves the tool-call trail so the agent can answer - follow-up questions about what it did. - - ``session`` is optional; when provided, the call is recorded via - ``telemetry.record_llm_call`` so its cost lands in the session's - ``total_cost_usd``. Without it, the call still happens but is - invisible in telemetry β€” which used to be the case for every - compaction call until 2026-04-29 (~30-50% of Bedrock spend was - attributed to this single source of dark cost). - - Returns ``(summary_text, completion_tokens)``. - """ - from agent.core.llm_params import _resolve_llm_params - - prompt_messages = list(messages) + [Message(role="user", content=prompt)] - llm_params = _resolve_llm_params(model_name, hf_token, reasoning_effort="high") - prompt_messages, tool_specs = with_prompt_caching( - prompt_messages, tool_specs, llm_params.get("model") - ) - _t0 = time.monotonic() - response = await acompletion( - messages=prompt_messages, - max_completion_tokens=max_tokens, - tools=tool_specs, - **llm_params, - ) - if session is not None: - from agent.core import telemetry + _hf_username_cache = "unknown" - await telemetry.record_llm_call( - session, - model=model_name, - response=response, - latency_ms=int((time.monotonic() - _t0) * 1000), - finish_reason=response.choices[0].finish_reason - if response.choices - else None, - kind=kind, - ) - summary = response.choices[0].message.content or "" - completion_tokens = response.usage.completion_tokens if response.usage else 0 - return summary, completion_tokens + return _hf_username_cache class ContextManager: @@ -176,39 +85,26 @@ class ContextManager: def __init__( self, - model_max_tokens: int = 180_000, + max_context: int = 180_000, compact_size: float = 0.1, untouched_messages: int = 5, tool_specs: list[dict[str, Any]] | None = None, - prompt_file_suffix: str = "system_prompt_v3.yaml", - hf_token: str | None = None, - local_mode: bool = False, + prompt_file_suffix: str = "system_prompt_v2.yaml", ): self.system_prompt = self._load_system_prompt( tool_specs or [], - prompt_file_suffix="system_prompt_v3.yaml", - hf_token=hf_token, - local_mode=local_mode, + prompt_file_suffix="system_prompt_v2.yaml", ) - # The model's real input-token ceiling (from litellm.get_model_info). - # Compaction triggers at _COMPACT_THRESHOLD_RATIO below it β€” see - # the compaction_threshold property. - self.model_max_tokens = model_max_tokens - self.compact_size = int(model_max_tokens * compact_size) - # Running count of tokens the last LLM call reported. Drives the - # compaction gate; updated in add_message() with each response's - # usage.total_tokens. - self.running_context_usage = 0 + self.max_context = max_context + self.compact_size = int(max_context * compact_size) + self.context_length = len(self.system_prompt) // 4 self.untouched_messages = untouched_messages self.items: list[Message] = [Message(role="system", content=self.system_prompt)] - self.on_message_added = None def _load_system_prompt( self, tool_specs: list[dict[str, Any]], prompt_file_suffix: str = "system_prompt.yaml", - hf_token: str | None = None, - local_mode: bool = False, ): """Load and render the system prompt from YAML file with Jinja2""" prompt_file = Path(__file__).parent.parent / "prompts" / f"{prompt_file_suffix}" @@ -224,374 +120,78 @@ class ContextManager: current_time = now.strftime("%H:%M:%S.%f")[:-3] current_timezone = f"{now.strftime('%Z')} (UTC{now.strftime('%z')[:3]}:{now.strftime('%z')[3:]})" - # Get HF user info from OAuth token - hf_user_info = _get_hf_username(hf_token) + # Get HF user info (cached after the first call) + hf_user_info = _get_hf_username() template = Template(template_str) - static_prompt = template.render( + return template.render( tools=tool_specs, num_tools=len(tool_specs), - ) - - # CLI-specific context for local mode - if local_mode: - import os - - cwd = os.getcwd() - local_context = ( - f"\n\n# CLI / Local mode\n\n" - f"You are running as a local CLI tool on the user's machine. " - f"There is NO sandbox β€” bash, read, write, and edit operate directly " - f"on the local filesystem.\n\n" - f"Working directory: {cwd}\n" - f"Use absolute paths or paths relative to the working directory. " - f"Do NOT use /app/ paths β€” that is a sandbox convention that does not apply here.\n" - f"The sandbox_create tool is NOT available. Run code directly with bash." - ) - static_prompt += local_context - - return ( - f"{static_prompt}\n\n" - f"[Session context: Date={current_date}, Time={current_time}, " - f"Timezone={current_timezone}, User={hf_user_info}, " - f"Tools={len(tool_specs)}]" + current_date=current_date, + current_time=current_time, + current_timezone=current_timezone, + hf_user_info=hf_user_info, ) def add_message(self, message: Message, token_count: int = None) -> None: """Add a message to the history""" if token_count: - self.running_context_usage = token_count + self.context_length = token_count self.items.append(message) - if self.on_message_added: - self.on_message_added(message) def get_messages(self) -> list[Message]: - """Get all messages for sending to LLM. - - Patches any dangling tool_calls (assistant messages with tool_calls - that have no matching tool-result message) so the LLM API doesn't - reject the request. - """ - self._patch_dangling_tool_calls() + """Get all messages for sending to LLM""" return self.items - @staticmethod - def _normalize_tool_calls(msg: Message) -> None: - """Ensure msg.tool_calls contains proper ToolCall objects, not dicts. - - litellm's Message has validate_assignment=False (Pydantic v2 default), - so direct attribute assignment (e.g. inside litellm's streaming handler) - can leave raw dicts. Re-assigning via the constructor fixes this. - """ - from litellm import ChatCompletionMessageToolCall as ToolCall - - tool_calls = getattr(msg, "tool_calls", None) - if not tool_calls: - return - needs_fix = any(isinstance(tc, dict) for tc in tool_calls) - if not needs_fix: - return - msg.tool_calls = [ - tc if not isinstance(tc, dict) else ToolCall(**tc) for tc in tool_calls - ] - - def _patch_dangling_tool_calls(self) -> None: - """Add stub tool results for any tool_calls that lack a matching result. - - Ensures each assistant message's tool_calls are followed immediately - by matching tool-result messages. This has to work across the whole - history, not just the most recent turn, because a cancelled tool use - in an earlier turn can still poison the next provider request. - """ - if not self.items: - return - - i = 0 - while i < len(self.items): - msg = self.items[i] - if getattr(msg, "role", None) != "assistant" or not getattr( - msg, "tool_calls", None - ): - i += 1 - continue - - self._normalize_tool_calls(msg) - - # Consume the contiguous tool-result block that immediately follows - # this assistant message. Any missing tool ids must be inserted - # before the next non-tool message to satisfy provider ordering. - j = i + 1 - immediate_ids: set[str | None] = set() - while ( - j < len(self.items) and getattr(self.items[j], "role", None) == "tool" - ): - immediate_ids.add(getattr(self.items[j], "tool_call_id", None)) - j += 1 - - missing: list[Message] = [] - for tc in msg.tool_calls: - if tc.id not in immediate_ids: - missing.append( - Message( - role="tool", - content="Tool was not executed (interrupted or error).", - tool_call_id=tc.id, - name=tc.function.name, - ) - ) - - if missing: - self.items[j:j] = missing - j += len(missing) - - i = j - - def undo_last_turn(self) -> bool: - """Remove the last complete turn (user msg + all assistant/tool msgs that follow). - - Pops from the end until the last user message is removed, keeping the - tool_use/tool_result pairing valid. Never removes the system message. - - Returns True if a user message was found and removed. - """ - if len(self.items) <= 1: - return False - - while len(self.items) > 1: - msg = self.items.pop() - if getattr(msg, "role", None) == "user": - return True - - return False - - def truncate_to_user_message(self, user_message_index: int) -> bool: - """Truncate history to just before the Nth user message (0-indexed). - - Removes that user message and everything after it. - System message (index 0) is never removed. - - Returns True if the target user message was found and removed. - """ - count = 0 - for i, msg in enumerate(self.items): - if i == 0: - continue # skip system message - if getattr(msg, "role", None) == "user": - if count == user_message_index: - self.items = self.items[:i] - return True - count += 1 - return False - - # Compaction fires at 90% of model_max_tokens so there's headroom for - # the next turn's prompt + response before we actually hit the ceiling. - _COMPACT_THRESHOLD_RATIO = 0.9 - - @property - def compaction_threshold(self) -> int: - """Token count at which `compact()` kicks in.""" - return int(self.model_max_tokens * self._COMPACT_THRESHOLD_RATIO) - - @property - def needs_compaction(self) -> bool: - return self.running_context_usage > self.compaction_threshold and bool( - self.items - ) - - def _truncate_oversized( - self, messages: list[Message], model_name: str - ) -> list[Message]: - """Replace any message > _MAX_TOKENS_PER_MESSAGE with a placeholder. - - These are typically tool outputs (CSV dumps, file contents) sitting in - the untouched tail or first-user position that compaction can't shrink - β€” they pass through verbatim, keeping context above threshold and - triggering an infinite compaction retry loop. - """ - from litellm import token_counter - - out: list[Message] = [] - for msg in messages: - # System messages are sacred β€” they're the agent's instructions. - # In edge cases (items < untouched_messages), the slice math in - # compact() can let items[0] (the system message) leak into the - # recent_messages list. Defense-in-depth: never truncate it. - if msg.role == "system": - out.append(msg) - continue - try: - n = token_counter(model=model_name, messages=[msg.model_dump()]) - except Exception: - # token_counter occasionally fails on edge-case content; - # don't drop the message, just keep it as-is. - out.append(msg) - continue - if n <= _MAX_TOKENS_PER_MESSAGE: - out.append(msg) - continue - placeholder = ( - f"[truncated for compaction β€” original was {n} tokens, " - f"removed to keep context under {self.compaction_threshold} tokens]" - ) - logger.warning( - "Truncating %s message: %d -> %d tokens for compaction", - msg.role, - n, - len(placeholder) // 4, - ) - # Preserve all known assistant-side fields (tool_calls, thinking_blocks, - # reasoning_content, provider_specific_fields) even when content is - # replaced. Anthropic extended-thinking models reject the next request - # with "Invalid signature in thinking block" if thinking_blocks is - # dropped from a prior assistant message. - kept = { - k: getattr(msg, k, None) - for k in ( - "tool_call_id", - "tool_calls", - "name", - "thinking_blocks", - "reasoning_content", - "provider_specific_fields", - ) - if getattr(msg, k, None) is not None - } - out.append(Message(role=msg.role, content=placeholder, **kept)) - return out - - def _recompute_usage(self, model_name: str) -> None: - """Refresh ``running_context_usage`` from current items via real tokenizer.""" - from litellm import token_counter - - try: - self.running_context_usage = token_counter( - model=model_name, - messages=[m.model_dump() for m in self.items], - ) - except Exception as e: - logger.warning("token_counter failed (%s); rough estimate", e) - # Rough fallback: 4 chars per token. - self.running_context_usage = ( - sum(len(getattr(m, "content", "") or "") for m in self.items) // 4 - ) - - async def compact( - self, - model_name: str, - tool_specs: list[dict] | None = None, - hf_token: str | None = None, - session: Any = None, - ) -> None: - """Remove old messages to keep history under target size. - - ``session`` is optional β€” if passed, the underlying summarization - LLM call is recorded via ``telemetry.record_llm_call(kind= - "compaction")`` so its cost shows up in ``total_cost_usd``. - - Raises ``CompactionFailedError`` if the post-compact context is still - over the threshold. This happens when a preserved message (typically - a giant tool output stuck in the untouched tail) is too large for - truncation to fix. The caller must terminate the session β€” retrying - is what caused the 2026-05-03 infinite-compaction-loop pattern that - burned Bedrock budget invisibly. - """ - if not self.needs_compaction: + async def compact(self, model_name: str) -> None: + """Remove old messages to keep history under target size""" + if (self.context_length <= self.max_context) or not self.items: return system_msg = ( self.items[0] if self.items and self.items[0].role == "system" else None ) - # Preserve the first user message (task prompt) β€” never summarize it - first_user_msg = None - first_user_idx = 1 - for i in range(1, len(self.items)): - if getattr(self.items[i], "role", None) == "user": - first_user_msg = self.items[i] - first_user_idx = i - break - # Don't summarize a certain number of just-preceding messages # Walk back to find a user message to make sure we keep an assistant -> user -> # assistant general conversation structure idx = len(self.items) - self.untouched_messages while idx > 1 and self.items[idx].role != "user": idx -= 1 - # The real invariant is "idx must be strictly after first_user_idx, - # otherwise recent_messages overlaps with the messages we put in - # head". The walk-back's `idx > 1` guard is necessary (no system in - # recent) but insufficient (first_user is also in head and would be - # duplicated). Anthropic API rejects two consecutive user messages - # with a 400 β€” bot review on PR #213 caught this on the second clamp - # iteration. - if idx <= first_user_idx: - idx = first_user_idx + 1 recent_messages = self.items[idx:] - messages_to_summarize = self.items[first_user_idx + 1 : idx] - - # Truncate any message that's larger than _MAX_TOKENS_PER_MESSAGE in - # the parts we PRESERVE through compaction (first_user + recent_tail). - # These are the only places where individual messages can defeat - # compaction by being intrinsically too large. Messages in - # ``messages_to_summarize`` are folded into the summary, so their size - # doesn't matter on its own. - if first_user_msg is not None: - truncated = self._truncate_oversized([first_user_msg], model_name) - first_user_msg = truncated[0] - recent_messages = self._truncate_oversized(recent_messages, model_name) + messages_to_summarize = self.items[1:idx] - # If there's nothing to summarize but the preserved messages are now - # truncated and small, just rebuild and recompute. This is rare but - # avoids returning silently with the old (over-threshold) state. + # improbable, messages would have to very long if not messages_to_summarize: - head = [system_msg] if system_msg else [] - if first_user_msg: - head.append(first_user_msg) - self.items = head + recent_messages - self._recompute_usage(model_name) - if self.running_context_usage > self.compaction_threshold: - raise CompactionFailedError( - f"Nothing to summarize but context ({self.running_context_usage}) " - f"still over threshold ({self.compaction_threshold}) after truncation. " - f"System prompt or first user message likely exceeds the budget." - ) return - summary, completion_tokens = await summarize_messages( - messages_to_summarize, - model_name=model_name, - hf_token=hf_token, - max_tokens=self.compact_size, - tool_specs=tool_specs, - prompt=_COMPACT_PROMPT, - session=session, - kind="compaction", + messages_to_summarize.append( + Message( + role="user", + content="Please provide a concise summary of the conversation above, focusing on key decisions, code changes, problems solved, and important context needed for future turns.", + ) + ) + + hf_key = os.environ.get("INFERENCE_TOKEN") + response = await acompletion( + model=model_name, + messages=messages_to_summarize, + max_completion_tokens=self.compact_size, + api_key=hf_key + if hf_key and model_name.startswith("huggingface/") + else None, ) summarized_message = Message( - role="assistant", - content=summary, + role="assistant", content=response.choices[0].message.content ) - # Reconstruct: system + first user msg + summary + recent messages - head = [system_msg] if system_msg else [] - if first_user_msg: - head.append(first_user_msg) - self.items = head + [summarized_message] + recent_messages - - self._recompute_usage(model_name) + # Reconstruct: system + summary + recent messages (includes tools) + if system_msg: + self.items = [system_msg, summarized_message] + recent_messages + else: + self.items = [summarized_message] + recent_messages - # Hard verify: if compaction didn't bring us below the threshold even - # after truncating oversized preserved messages, retrying just burns - # Bedrock budget on the same useless compaction call. Raise so the - # caller can terminate the session cleanly. Pre-2026-05-04, the - # caller looped indefinitely (~$3/Opus retry) until the pod was - # killed β€” invisible to the dataset because the session never - # finished cleanly. - if self.running_context_usage > self.compaction_threshold: - raise CompactionFailedError( - f"Compaction ineffective: {self.running_context_usage} tokens " - f"still over threshold {self.compaction_threshold} after summarize " - f"and truncation. Likely the system prompt + first user + summary " - f"+ truncated tail still exceeds budget." - ) + self.context_length = ( + len(self.system_prompt) // 4 + response.usage.completion_tokens + ) diff --git a/agent/core/agent_loop.py b/agent/core/agent_loop.py index e32e4e4204fc812a4a1b451728033bbd93d66e16..335e735cfb5e4f99751b1327fa997d60f29e40cb 100644 --- a/agent/core/agent_loop.py +++ b/agent/core/agent_loop.py @@ -5,94 +5,22 @@ Main agent implementation with integrated tool system and MCP support import asyncio import json import logging -import time -from dataclasses import dataclass, field -from pathlib import Path -from typing import Any - -from litellm import ( - ChatCompletionMessageToolCall, - Message, - acompletion, - stream_chunk_builder, -) -from litellm.exceptions import ContextWindowExceededError +import os + +from litellm import ChatCompletionMessageToolCall, Message, acompletion +from lmnr import observe from agent.config import Config -from agent.core.approval_policy import ( - is_scheduled_operation, - normalize_tool_operation, -) -from agent.core.cost_estimation import CostEstimate, estimate_tool_cost -from agent.messaging.gateway import NotificationGateway -from agent.core import telemetry -from agent.core.doom_loop import check_for_doom_loop -from agent.core.llm_params import _resolve_llm_params -from agent.core.prompt_caching import with_prompt_caching -from agent.core.session import DEFAULT_SESSION_LOG_DIR, Event, OpType, Session +from agent.core.session import Event, OpType, Session from agent.core.tools import ToolRouter from agent.tools.jobs_tool import CPU_FLAVORS -from agent.tools.sandbox_tool import DEFAULT_CPU_SANDBOX_HARDWARE logger = logging.getLogger(__name__) ToolCall = ChatCompletionMessageToolCall - -_MALFORMED_TOOL_PREFIX = "ERROR: Tool call to '" -_MALFORMED_TOOL_SUFFIX = "' had malformed JSON arguments" - - -def _malformed_tool_name(message: Message) -> str | None: - """Return the tool name for malformed-json tool-result messages.""" - if getattr(message, "role", None) != "tool": - return None - content = getattr(message, "content", None) - if not isinstance(content, str): - return None - if not content.startswith(_MALFORMED_TOOL_PREFIX): - return None - end = content.find(_MALFORMED_TOOL_SUFFIX, len(_MALFORMED_TOOL_PREFIX)) - if end == -1: - return None - return content[len(_MALFORMED_TOOL_PREFIX) : end] - - -def _detect_repeated_malformed( - items: list[Message], - threshold: int = 2, -) -> str | None: - """Return the repeated malformed tool name if the tail contains a streak. - - Walk backward over the current conversation tail. A streak counts only - consecutive malformed tool-result messages for the same tool; any other - tool result breaks it. - """ - if threshold <= 0: - return None - - streak_tool: str | None = None - streak = 0 - - for item in reversed(items): - if getattr(item, "role", None) != "tool": - continue - - malformed_tool = _malformed_tool_name(item) - if malformed_tool is None: - break - - if streak_tool is None: - streak_tool = malformed_tool - streak = 1 - elif malformed_tool == streak_tool: - streak += 1 - else: - break - - if streak >= threshold: - return streak_tool - - return None +# Explicit inference token β€” needed because litellm checks HF_TOKEN before +# HUGGINGFACE_API_KEY, and HF_TOKEN (used for Hub ops) may lack inference permissions. +_INFERENCE_API_KEY = os.environ.get("INFERENCE_TOKEN") def _validate_tool_args(tool_args: dict) -> tuple[bool, str | None]: @@ -117,57 +45,22 @@ def _validate_tool_args(tool_args: dict) -> tuple[bool, str | None]: return True, None -_IMMEDIATE_HF_JOB_RUNS = {"run", "uv"} - - -@dataclass(frozen=True) -class ApprovalDecision: - requires_approval: bool - auto_approved: bool = False - auto_approval_blocked: bool = False - block_reason: str | None = None - estimated_cost_usd: float | None = None - remaining_cap_usd: float | None = None - billable: bool = False - - -def _operation(tool_args: dict) -> str: - return normalize_tool_operation(tool_args.get("operation")) - - -def _is_immediate_hf_job_run(tool_name: str, tool_args: dict) -> bool: - return tool_name == "hf_jobs" and _operation(tool_args) in _IMMEDIATE_HF_JOB_RUNS - - -def _is_scheduled_hf_job_run(tool_name: str, tool_args: dict) -> bool: - return tool_name == "hf_jobs" and is_scheduled_operation(_operation(tool_args)) - - -def _is_budgeted_auto_approval_target(tool_name: str, tool_args: dict) -> bool: - return tool_name == "sandbox_create" or _is_immediate_hf_job_run( - tool_name, tool_args - ) - - -def _base_needs_approval( +def _needs_approval( tool_name: str, tool_args: dict, config: Config | None = None ) -> bool: - """Check if a tool call requires approval before YOLO policy is applied.""" + """Check if a tool call requires user approval before execution.""" + # Yolo mode: skip all approvals + if config and config.yolo_mode: + return False # If args are malformed, skip approval (validation error will be shown later) args_valid, _ = _validate_tool_args(tool_args) if not args_valid: return False - if tool_name == "sandbox_create": - hardware = tool_args.get("hardware") or DEFAULT_CPU_SANDBOX_HARDWARE - return hardware != DEFAULT_CPU_SANDBOX_HARDWARE - if tool_name == "hf_jobs": - operation = _operation(tool_args) - if is_scheduled_operation(operation): - return True - if operation not in _IMMEDIATE_HF_JOB_RUNS: + operation = tool_args.get("operation", "") + if operation not in ["run", "uv", "scheduled run", "scheduled uv"]: return False # Check if this is a CPU-only job @@ -219,924 +112,23 @@ def _base_needs_approval( return False -def _needs_approval( - tool_name: str, tool_args: dict, config: Config | None = None -) -> bool: - """Legacy sync approval predicate used by tests and CLI display helpers.""" - if _is_scheduled_hf_job_run(tool_name, tool_args): - return True - if config and config.yolo_mode: - return False - return _base_needs_approval(tool_name, tool_args, config) - - -def _session_auto_approval_enabled(session: Session | None) -> bool: - return bool(session and getattr(session, "auto_approval_enabled", False)) - - -def _effective_yolo_enabled(session: Session | None, config: Config | None) -> bool: - return bool( - (config and config.yolo_mode) or _session_auto_approval_enabled(session) - ) - - -def _remaining_budget_after_reservations( - session: Session | None, reserved_spend_usd: float -) -> float | None: - if not session or getattr(session, "auto_approval_cost_cap_usd", None) is None: - return None - cap = float(getattr(session, "auto_approval_cost_cap_usd") or 0.0) - spent = float(getattr(session, "auto_approval_estimated_spend_usd", 0.0) or 0.0) - return round(max(0.0, cap - spent - reserved_spend_usd), 4) - - -def _budget_block_reason( - estimate: CostEstimate, - *, - remaining_cap_usd: float | None, -) -> str | None: - if estimate.estimated_cost_usd is None: - return estimate.block_reason or "Could not estimate the cost safely." - if ( - remaining_cap_usd is not None - and estimate.estimated_cost_usd > remaining_cap_usd - ): - return ( - f"Estimated cost ${estimate.estimated_cost_usd:.2f} exceeds " - f"remaining YOLO cap ${remaining_cap_usd:.2f}." - ) - return None - - -async def _approval_decision( - tool_name: str, - tool_args: dict, - session: Session, - *, - reserved_spend_usd: float = 0.0, -) -> ApprovalDecision: - """Return the approval decision for one parsed tool call.""" - config = session.config - base_requires_approval = _base_needs_approval(tool_name, tool_args, config) - - # Scheduled jobs are recurring/unbounded enough that YOLO never bypasses - # the human confirmation, including legacy config.yolo_mode. - if _is_scheduled_hf_job_run(tool_name, tool_args): - return ApprovalDecision( - requires_approval=True, - auto_approval_blocked=_effective_yolo_enabled(session, config), - block_reason="Scheduled HF jobs always require manual approval.", - ) - - yolo_enabled = _effective_yolo_enabled(session, config) - budgeted_target = _is_budgeted_auto_approval_target(tool_name, tool_args) - - # Cost caps are a session-scoped web policy. Legacy config.yolo_mode - # remains uncapped for CLI/headless, except for scheduled jobs above. - session_yolo_enabled = _session_auto_approval_enabled(session) - if yolo_enabled and budgeted_target and session_yolo_enabled: - estimate = await estimate_tool_cost(tool_name, tool_args, session=session) - remaining = _remaining_budget_after_reservations(session, reserved_spend_usd) - reason = _budget_block_reason(estimate, remaining_cap_usd=remaining) - if reason: - return ApprovalDecision( - requires_approval=True, - auto_approval_blocked=True, - block_reason=reason, - estimated_cost_usd=estimate.estimated_cost_usd, - remaining_cap_usd=remaining, - billable=estimate.billable, - ) - if base_requires_approval: - return ApprovalDecision( - requires_approval=False, - auto_approved=True, - estimated_cost_usd=estimate.estimated_cost_usd, - remaining_cap_usd=remaining, - billable=estimate.billable, - ) - return ApprovalDecision( - requires_approval=False, - estimated_cost_usd=estimate.estimated_cost_usd, - remaining_cap_usd=remaining, - billable=estimate.billable, - ) - - if base_requires_approval and yolo_enabled: - return ApprovalDecision(requires_approval=False, auto_approved=True) - - return ApprovalDecision(requires_approval=base_requires_approval) - - -def _record_estimated_spend(session: Session, decision: ApprovalDecision) -> None: - if not decision.billable or decision.estimated_cost_usd is None: - return - if hasattr(session, "add_auto_approval_estimated_spend"): - session.add_auto_approval_estimated_spend(decision.estimated_cost_usd) - else: - session.auto_approval_estimated_spend_usd = round( - float(getattr(session, "auto_approval_estimated_spend_usd", 0.0) or 0.0) - + float(decision.estimated_cost_usd), - 4, - ) - - -async def _record_manual_approved_spend_if_needed( - session: Session, - tool_name: str, - tool_args: dict, -) -> None: - if not _session_auto_approval_enabled(session): - return - if not _is_budgeted_auto_approval_target(tool_name, tool_args): - return - estimate = await estimate_tool_cost(tool_name, tool_args, session=session) - _record_estimated_spend( - session, - ApprovalDecision( - requires_approval=False, - billable=estimate.billable, - estimated_cost_usd=estimate.estimated_cost_usd, - ), - ) - - -# -- LLM retry constants -------------------------------------------------- -_MAX_LLM_RETRIES = 3 -_LLM_RETRY_DELAYS = [5, 15, 30] # seconds between retries -_LLM_RATE_LIMIT_RETRY_DELAYS = [30, 60] # exceed Bedrock's ~60s TPM bucket window - - -def _is_rate_limit_error(error: Exception) -> bool: - """Return True for rate-limit / quota-bucket style provider errors.""" - err_str = str(error).lower() - rate_limit_patterns = [ - "429", - "rate limit", - "rate_limit", - "too many requests", - "too many tokens", - "request limit", - "throttl", - ] - return any(pattern in err_str for pattern in rate_limit_patterns) - - -def _is_context_overflow_error(error: Exception) -> bool: - """Return True when the prompt exceeded the model's context window.""" - if isinstance(error, ContextWindowExceededError): - return True - - err_str = str(error).lower() - overflow_patterns = [ - "context window exceeded", - "maximum context length", - "max context length", - "prompt is too long", - "context length exceeded", - "too many input tokens", - "input is too long", - ] - return any(pattern in err_str for pattern in overflow_patterns) - - -def _retry_delay_for(error: Exception, attempt_index: int) -> int | None: - """Return the delay for this retry attempt, or None if it should not retry.""" - if _is_rate_limit_error(error): - schedule = _LLM_RATE_LIMIT_RETRY_DELAYS - elif _is_transient_error(error): - schedule = _LLM_RETRY_DELAYS - else: - return None - - if attempt_index >= len(schedule): - return None - return schedule[attempt_index] - - -def _is_transient_error(error: Exception) -> bool: - """Return True for errors that are likely transient and worth retrying.""" - err_str = str(error).lower() - transient_patterns = [ - "timeout", - "timed out", - "503", - "service unavailable", - "502", - "bad gateway", - "500", - "internal server error", - "overloaded", - "capacity", - "connection reset", - "connection refused", - "connection error", - "eof", - "broken pipe", - ] - return _is_rate_limit_error(error) or any( - pattern in err_str for pattern in transient_patterns - ) - - -def _is_effort_config_error(error: Exception) -> bool: - """Catch the two 400s the effort probe also handles β€” thinking - unsupported for this model, or the specific effort level invalid. - - This is our safety net for the case where ``/effort`` was changed - mid-conversation (which clears the probe cache) and the new level - doesn't work for the current model. We heal the cache and retry once. - """ - from agent.core.effort_probe import _is_invalid_effort, _is_thinking_unsupported - - return _is_thinking_unsupported(error) or _is_invalid_effort(error) - - -async def _heal_effort_and_rebuild_params( - session: Session, - error: Exception, - llm_params: dict, -) -> dict: - """Update the session's effort cache based on ``error`` and return new - llm_params. Called only when ``_is_effort_config_error(error)`` is True. - - Two branches: - β€’ thinking-unsupported β†’ cache ``None`` for this model, next call - strips thinking entirely - β€’ invalid-effort β†’ re-run the full cascade probe; the result lands - in the cache - """ - from agent.core.effort_probe import ( - ProbeInconclusive, - _is_thinking_unsupported, - probe_effort, - ) - - model = session.config.model_name - if _is_thinking_unsupported(error): - session.model_effective_effort[model] = None - logger.info("healed: %s doesn't support thinking β€” stripped", model) - else: - try: - outcome = await probe_effort( - model, - session.config.reasoning_effort, - session.hf_token, - session=session, - ) - session.model_effective_effort[model] = outcome.effective_effort - logger.info( - "healed: %s effort cascade β†’ %s", - model, - outcome.effective_effort, - ) - except ProbeInconclusive: - # Transient during healing β€” strip thinking for safety, next - # call will either succeed or surface the real error. - session.model_effective_effort[model] = None - logger.info("healed: %s probe inconclusive β€” stripped", model) - - return _resolve_llm_params( - model, - session.hf_token, - reasoning_effort=session.effective_effort_for(model), - ) - - -def _friendly_error_message(error: Exception) -> str | None: - """Return a user-friendly message for known error types, or None to fall back to traceback.""" - err_str = str(error).lower() - - if ( - "authentication" in err_str - or "unauthorized" in err_str - or "invalid x-api-key" in err_str - ): - return ( - "Authentication failed β€” your API key is missing or invalid.\n\n" - "To fix this, set the API key for your model provider:\n" - " β€’ Anthropic: export ANTHROPIC_API_KEY=sk-...\n" - " β€’ OpenAI: export OPENAI_API_KEY=sk-...\n" - " β€’ HF Router: export HF_TOKEN=hf_...\n\n" - "You can also add it to a .env file in the project root.\n" - "To switch models, use the /model command." - ) - - if "insufficient" in err_str and "credit" in err_str: - return ( - "Insufficient API credits. Please check your account balance " - "at your model provider's dashboard." - ) - - if "not supported by provider" in err_str or "no provider supports" in err_str: - return ( - "The model isn't served by the provider you pinned.\n\n" - "Drop the ':' suffix to let the HF router auto-pick a " - "provider, or use '/model' (no arg) to see which providers host " - "which models." - ) - - if "model_not_found" in err_str or ( - "model" in err_str and ("not found" in err_str or "does not exist" in err_str) - ): - return ( - "Model not found. Use '/model' to list suggestions, or paste an " - "HF model id like 'MiniMaxAI/MiniMax-M2.7'. Availability is shown " - "when you switch." - ) - - return None - - -async def _compact_and_notify(session: Session) -> None: - """Run compaction and send event if context was reduced. - - Catches ``CompactionFailedError`` and ends the session cleanly instead - of letting the caller retry. Pre-2026-05-04 the caller looped on - ContextWindowExceededError β†’ compact β†’ re-trigger, burning Bedrock - budget at ~$3/Opus retry while the session never reached the upload - path (so the cost was invisible in the dataset). - """ - from agent.context_manager.manager import CompactionFailedError - - cm = session.context_manager - old_usage = cm.running_context_usage - logger.debug( - "Compaction check: usage=%d, max=%d, threshold=%d, needs_compact=%s", - old_usage, - cm.model_max_tokens, - cm.compaction_threshold, - cm.needs_compaction, - ) - try: - await cm.compact( - model_name=session.config.model_name, - tool_specs=session.tool_router.get_tool_specs_for_llm(), - hf_token=session.hf_token, - session=session, - ) - except CompactionFailedError as e: - logger.error( - "Compaction failed for session %s: %s β€” terminating session", - session.session_id, - e, - ) - # Persist the failure event so the dataset has a record of WHY this - # session ended (and the cost it incurred up to that point) even if - # save_and_upload_detached has issues downstream. - await session.send_event( - Event( - event_type="session_terminated", - data={ - "reason": "compaction_failed", - "context_usage": cm.running_context_usage, - "context_threshold": cm.compaction_threshold, - "error": str(e)[:300], - "user_message": ( - "Your conversation has grown too large to continue. " - "The work you've done is saved β€” start a new session to keep going." - ), - }, - ) - ) - # Stop the agent loop; the finally in _run_session will fire - # cleanup_sandbox + save_trajectory so the dataset captures - # everything that did happen. - session.is_running = False - return - - new_usage = cm.running_context_usage - if new_usage != old_usage: - logger.warning( - "Context compacted: %d -> %d tokens (max=%d, %d messages)", - old_usage, - new_usage, - cm.model_max_tokens, - len(cm.items), - ) - await session.send_event( - Event( - event_type="compacted", - data={"old_tokens": old_usage, "new_tokens": new_usage}, - ) - ) - - -async def _cleanup_on_cancel(session: Session) -> None: - """Kill sandbox processes and cancel HF jobs when the user interrupts.""" - # Kill active sandbox processes - sandbox = getattr(session, "sandbox", None) - if sandbox: - try: - await asyncio.to_thread(sandbox.kill_all) - logger.info("Killed sandbox processes on cancel") - except Exception as e: - logger.warning("Failed to kill sandbox processes: %s", e) - - # Cancel running HF jobs - job_ids = list(session._running_job_ids) - if job_ids: - from huggingface_hub import HfApi - - api = HfApi(token=session.hf_token) - for job_id in job_ids: - try: - await asyncio.to_thread(api.cancel_job, job_id=job_id) - logger.info("Cancelled HF job %s on interrupt", job_id) - except Exception as e: - logger.warning("Failed to cancel HF job %s: %s", job_id, e) - session._running_job_ids.clear() - - -@dataclass -class LLMResult: - """Result from an LLM call (streaming or non-streaming).""" - - content: str | None - tool_calls_acc: dict[int, dict] - token_count: int - finish_reason: str | None - usage: dict = field(default_factory=dict) - thinking_blocks: list[dict[str, Any]] | None = None - reasoning_content: str | None = None - - -def _extract_thinking_state( - message: Any, -) -> tuple[list[dict[str, Any]] | None, str | None]: - """Return provider reasoning fields that must be replayed after tool calls.""" - provider_fields = getattr(message, "provider_specific_fields", None) - if not isinstance(provider_fields, dict): - provider_fields = {} - - thinking_blocks = ( - getattr(message, "thinking_blocks", None) - or provider_fields.get("thinking_blocks") - or None - ) - reasoning_content = ( - getattr(message, "reasoning_content", None) - or provider_fields.get("reasoning_content") - or None - ) - return thinking_blocks, reasoning_content - - -def _should_replay_thinking_state(model_name: str | None) -> bool: - """Only Anthropic's native adapter accepts replayed thinking metadata.""" - return bool(model_name and model_name.startswith("anthropic/")) - - -def _is_invalid_thinking_signature_error(exc: Exception) -> bool: - """Return True when Anthropic rejected replayed extended-thinking state.""" - text = str(exc) - return ( - "Invalid `signature` in `thinking` block" in text - or "Invalid signature in thinking block" in text - ) - - -def _strip_thinking_state_from_messages(messages: list[Any]) -> int: - """Remove replayed thinking metadata from assistant history messages.""" - stripped = 0 - - for message in messages: - role = ( - message.get("role") - if isinstance(message, dict) - else getattr(message, "role", None) - ) - if role != "assistant": - continue - - if isinstance(message, dict): - if message.pop("thinking_blocks", None) is not None: - stripped += 1 - if message.pop("reasoning_content", None) is not None: - stripped += 1 - provider_fields = message.get("provider_specific_fields") - content = message.get("content") - else: - if getattr(message, "thinking_blocks", None) is not None: - message.thinking_blocks = None - stripped += 1 - if getattr(message, "reasoning_content", None) is not None: - message.reasoning_content = None - stripped += 1 - provider_fields = getattr(message, "provider_specific_fields", None) - content = getattr(message, "content", None) - - if isinstance(provider_fields, dict): - cleaned_fields = dict(provider_fields) - if cleaned_fields.pop("thinking_blocks", None) is not None: - stripped += 1 - if cleaned_fields.pop("reasoning_content", None) is not None: - stripped += 1 - if cleaned_fields != provider_fields: - if isinstance(message, dict): - message["provider_specific_fields"] = cleaned_fields - else: - message.provider_specific_fields = cleaned_fields - - if isinstance(content, list): - cleaned_content = [ - block - for block in content - if not ( - isinstance(block, dict) - and block.get("type") in {"thinking", "redacted_thinking"} - ) - ] - if len(cleaned_content) != len(content): - stripped += len(content) - len(cleaned_content) - if isinstance(message, dict): - message["content"] = cleaned_content - else: - message.content = cleaned_content - - return stripped - - -async def _maybe_heal_invalid_thinking_signature( - session: Session, - messages: list[Any], - exc: Exception, - *, - already_healed: bool, -) -> bool: - if already_healed or not _is_invalid_thinking_signature_error(exc): - return False - - stripped = _strip_thinking_state_from_messages(messages) - if not stripped: - return False - - await session.send_event( - Event( - event_type="tool_log", - data={ - "tool": "system", - "log": ( - "Anthropic rejected stale thinking signatures; retrying " - "without replayed thinking metadata." - ), - }, - ) - ) - return True - - -def _assistant_message_from_result( - llm_result: LLMResult, - *, - model_name: str | None, - tool_calls: list[ToolCall] | None = None, -) -> Message: - """Build an assistant history message without dropping reasoning state.""" - kwargs: dict[str, Any] = { - "role": "assistant", - "content": llm_result.content, - } - if tool_calls is not None: - kwargs["tool_calls"] = tool_calls - if _should_replay_thinking_state(model_name): - if llm_result.thinking_blocks: - kwargs["thinking_blocks"] = llm_result.thinking_blocks - if llm_result.reasoning_content: - kwargs["reasoning_content"] = llm_result.reasoning_content - return Message(**kwargs) - - -async def _call_llm_streaming( - session: Session, messages, tools, llm_params -) -> LLMResult: - """Call the LLM with streaming, emitting assistant_chunk events.""" - response = None - _healed_effort = False # one-shot safety net per call - _healed_thinking_signature = False - messages, tools = with_prompt_caching(messages, tools, llm_params.get("model")) - t_start = time.monotonic() - for _llm_attempt in range(_MAX_LLM_RETRIES): - try: - response = await acompletion( - messages=messages, - tools=tools, - tool_choice="auto", - stream=True, - stream_options={"include_usage": True}, - timeout=600, - **llm_params, - ) - break - except ContextWindowExceededError: - raise - except Exception as e: - if _is_context_overflow_error(e): - raise ContextWindowExceededError(str(e)) from e - if not _healed_effort and _is_effort_config_error(e): - _healed_effort = True - llm_params = await _heal_effort_and_rebuild_params( - session, e, llm_params - ) - await session.send_event( - Event( - event_type="tool_log", - data={ - "tool": "system", - "log": "Reasoning effort not supported for this model β€” adjusting and retrying.", - }, - ) - ) - continue - if await _maybe_heal_invalid_thinking_signature( - session, - messages, - e, - already_healed=_healed_thinking_signature, - ): - _healed_thinking_signature = True - continue - _delay = _retry_delay_for(e, _llm_attempt) - if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None: - logger.warning( - "Transient LLM error (attempt %d/%d): %s β€” retrying in %ds", - _llm_attempt + 1, - _MAX_LLM_RETRIES, - e, - _delay, - ) - await session.send_event( - Event( - event_type="tool_log", - data={ - "tool": "system", - "log": f"LLM connection error, retrying in {_delay}s...", - }, - ) - ) - await asyncio.sleep(_delay) - continue - raise - - full_content = "" - tool_calls_acc: dict[int, dict] = {} - token_count = 0 - finish_reason = None - final_usage_chunk = None - chunks = [] - should_replay_thinking = _should_replay_thinking_state(llm_params.get("model")) - - async for chunk in response: - chunks.append(chunk) - if session.is_cancelled: - tool_calls_acc.clear() - break - - choice = chunk.choices[0] if chunk.choices else None - if not choice: - if hasattr(chunk, "usage") and chunk.usage: - token_count = chunk.usage.total_tokens - final_usage_chunk = chunk - continue - - delta = choice.delta - if choice.finish_reason: - finish_reason = choice.finish_reason - - if delta.content: - full_content += delta.content - await session.send_event( - Event(event_type="assistant_chunk", data={"content": delta.content}) - ) - - if delta.tool_calls: - for tc_delta in delta.tool_calls: - idx = tc_delta.index - if idx not in tool_calls_acc: - tool_calls_acc[idx] = { - "id": "", - "type": "function", - "function": {"name": "", "arguments": ""}, - } - if tc_delta.id: - tool_calls_acc[idx]["id"] = tc_delta.id - if tc_delta.function: - if tc_delta.function.name: - tool_calls_acc[idx]["function"]["name"] += ( - tc_delta.function.name - ) - if tc_delta.function.arguments: - tool_calls_acc[idx]["function"]["arguments"] += ( - tc_delta.function.arguments - ) - - if hasattr(chunk, "usage") and chunk.usage: - token_count = chunk.usage.total_tokens - final_usage_chunk = chunk - - usage = await telemetry.record_llm_call( - session, - model=llm_params.get("model", session.config.model_name), - response=final_usage_chunk, - latency_ms=int((time.monotonic() - t_start) * 1000), - finish_reason=finish_reason, - ) - thinking_blocks = None - reasoning_content = None - if chunks and should_replay_thinking: - try: - rebuilt = stream_chunk_builder(chunks, messages=messages) - if rebuilt and getattr(rebuilt, "choices", None): - rebuilt_msg = rebuilt.choices[0].message - thinking_blocks, reasoning_content = _extract_thinking_state( - rebuilt_msg - ) - except Exception: - logger.debug("Failed to rebuild streaming thinking state", exc_info=True) - - return LLMResult( - content=full_content or None, - tool_calls_acc=tool_calls_acc, - token_count=token_count, - finish_reason=finish_reason, - usage=usage, - thinking_blocks=thinking_blocks, - reasoning_content=reasoning_content, - ) - - -async def _call_llm_non_streaming( - session: Session, messages, tools, llm_params -) -> LLMResult: - """Call the LLM without streaming, emit assistant_message at the end.""" - response = None - _healed_effort = False - _healed_thinking_signature = False - messages, tools = with_prompt_caching(messages, tools, llm_params.get("model")) - t_start = time.monotonic() - for _llm_attempt in range(_MAX_LLM_RETRIES): - try: - response = await acompletion( - messages=messages, - tools=tools, - tool_choice="auto", - stream=False, - timeout=600, - **llm_params, - ) - break - except ContextWindowExceededError: - raise - except Exception as e: - if _is_context_overflow_error(e): - raise ContextWindowExceededError(str(e)) from e - if not _healed_effort and _is_effort_config_error(e): - _healed_effort = True - llm_params = await _heal_effort_and_rebuild_params( - session, e, llm_params - ) - await session.send_event( - Event( - event_type="tool_log", - data={ - "tool": "system", - "log": "Reasoning effort not supported for this model β€” adjusting and retrying.", - }, - ) - ) - continue - if await _maybe_heal_invalid_thinking_signature( - session, - messages, - e, - already_healed=_healed_thinking_signature, - ): - _healed_thinking_signature = True - continue - _delay = _retry_delay_for(e, _llm_attempt) - if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None: - logger.warning( - "Transient LLM error (attempt %d/%d): %s β€” retrying in %ds", - _llm_attempt + 1, - _MAX_LLM_RETRIES, - e, - _delay, - ) - await session.send_event( - Event( - event_type="tool_log", - data={ - "tool": "system", - "log": f"LLM connection error, retrying in {_delay}s...", - }, - ) - ) - await asyncio.sleep(_delay) - continue - raise - - choice = response.choices[0] - message = choice.message - content = message.content or None - finish_reason = choice.finish_reason - token_count = response.usage.total_tokens if response.usage else 0 - thinking_blocks, reasoning_content = _extract_thinking_state(message) - - # Build tool_calls_acc in the same format as streaming - tool_calls_acc: dict[int, dict] = {} - if message.tool_calls: - for idx, tc in enumerate(message.tool_calls): - tool_calls_acc[idx] = { - "id": tc.id, - "type": "function", - "function": { - "name": tc.function.name, - "arguments": tc.function.arguments, - }, - } - - # Emit the full message as a single event - if content: - await session.send_event( - Event(event_type="assistant_message", data={"content": content}) - ) - - usage = await telemetry.record_llm_call( - session, - model=llm_params.get("model", session.config.model_name), - response=response, - latency_ms=int((time.monotonic() - t_start) * 1000), - finish_reason=finish_reason, - ) - - return LLMResult( - content=content, - tool_calls_acc=tool_calls_acc, - token_count=token_count, - finish_reason=finish_reason, - usage=usage, - thinking_blocks=thinking_blocks, - reasoning_content=reasoning_content, - ) - - class Handlers: """Handler functions for each operation type""" @staticmethod - async def _abandon_pending_approval(session: Session) -> None: - """Cancel pending approval tools when the user continues the conversation. - - Injects rejection tool-result messages into the LLM context (so the - history stays valid) and notifies the frontend that those tools were - abandoned. - """ - tool_calls = session.pending_approval.get("tool_calls", []) - for tc in tool_calls: - tool_name = tc.function.name - abandon_msg = ( - "Task abandoned β€” user continued the conversation without approving." - ) - - # Keep LLM context valid: every tool_call needs a tool result - tool_msg = Message( - role="tool", - content=abandon_msg, - tool_call_id=tc.id, - name=tool_name, - ) - session.context_manager.add_message(tool_msg) - - await session.send_event( - Event( - event_type="tool_state_change", - data={ - "tool_call_id": tc.id, - "tool": tool_name, - "state": "abandoned", - }, - ) - ) - - session.pending_approval = None - logger.info("Abandoned %d pending approval tool(s)", len(tool_calls)) - - @staticmethod + @observe(name="run_agent") async def run_agent( - session: Session, - text: str, + session: Session, text: str, max_iterations: int = 10 ) -> str | None: """ Handle user input (like user_input_or_turn in codex.rs:1291) Returns the final assistant response content, if any. """ - # Clear any stale cancellation flag from a previous run - session.reset_cancel() + # Set session ID for this trace + if hasattr(session, "session_id"): + from lmnr import Laminar - # If there's a pending approval and the user sent a new message, - # abandon the pending tools so the LLM context stays valid. - if text and session.pending_approval: - await Handlers._abandon_pending_approval(session) + Laminar.set_trace_session_id(session_id=session.session_id) # Add user message to history only if there's actual content if text: @@ -1151,132 +143,77 @@ class Handlers: # Agentic loop - continue until model doesn't call tools or max iterations is reached iteration = 0 final_response = None - errored = False - max_iterations = session.config.max_iterations - - while max_iterations == -1 or iteration < max_iterations: - # ── Cancellation check: before LLM call ── - if session.is_cancelled: - break - - # Compact before calling the LLM if context is near the limit. - # When _compact_and_notify catches CompactionFailedError it sets - # session.is_running = False; we MUST exit the loop here, otherwise - # the LLM call below fires with an over-threshold context, hits - # ContextWindowExceededError, and we end up looping again on the - # except path β€” exactly the bug this PR is supposed to fix. - await _compact_and_notify(session) - if not session.is_running: - break - - # Doom-loop detection: break out of repeated tool call patterns - doom_prompt = check_for_doom_loop(session.context_manager.items) - if doom_prompt: - session.context_manager.add_message( - Message(role="user", content=doom_prompt) - ) - - malformed_tool = _detect_repeated_malformed(session.context_manager.items) - if malformed_tool: - recovery_prompt = ( - "[SYSTEM: Repeated malformed tool arguments detected for " - f"'{malformed_tool}'. Stop retrying the same tool call shape. " - "Use a different strategy that produces smaller, valid JSON. " - "For large file writes, prefer bash with a heredoc or split the " - "edit into multiple smaller tool calls.]" - ) - session.context_manager.add_message( - Message(role="user", content=recovery_prompt) - ) - await session.send_event( - Event( - event_type="tool_log", - data={ - "tool": "system", - "log": ( - "Repeated malformed tool arguments detected β€” " - f"forcing a different strategy for {malformed_tool}" - ), - }, - ) - ) + while iteration < max_iterations: messages = session.context_manager.get_messages() tools = session.tool_router.get_tool_specs_for_llm() try: - # ── Call the LLM (streaming or non-streaming) ── - # Pull the per-model probed effort from the session cache when - # available; fall back to the raw preference for models we - # haven't probed yet (e.g. research sub-model). - llm_params = _resolve_llm_params( - session.config.model_name, - session.hf_token, - reasoning_effort=session.effective_effort_for( - session.config.model_name - ), - ) - if session.stream: - llm_result = await _call_llm_streaming( - session, messages, tools, llm_params - ) - else: - llm_result = await _call_llm_non_streaming( - session, messages, tools, llm_params - ) - - content = llm_result.content - tool_calls_acc = llm_result.tool_calls_acc - token_count = llm_result.token_count - finish_reason = llm_result.finish_reason - - # If output was truncated, all tool call args are garbage. - # Inject a system hint so the LLM retries with smaller content. - if finish_reason == "length" and tool_calls_acc: - dropped_names = [ - tc["function"]["name"] - for tc in tool_calls_acc.values() - if tc["function"]["name"] - ] - logger.warning( - "Output truncated (finish_reason=length) β€” dropping tool calls: %s", - dropped_names, - ) - tool_calls_acc.clear() - - # Tell the agent what happened so it can retry differently - truncation_hint = ( - "Your previous response was truncated because the output hit the " - "token limit. The following tool calls were lost: " - f"{dropped_names}. " - "IMPORTANT: Do NOT retry with the same large content. Instead:\n" - " β€’ For 'write': use bash with cat<<'HEREDOC' to write the file, " - "or split into several smaller edit calls.\n" - " β€’ For other tools: reduce the size of your arguments or use bash." - ) - if content: - assistant_msg = _assistant_message_from_result( - llm_result, - model_name=llm_params.get("model"), - ) - session.context_manager.add_message(assistant_msg, token_count) - session.context_manager.add_message( - Message(role="user", content=f"[SYSTEM: {truncation_hint}]") - ) - if session.stream: + # ── Stream the LLM response ────────────────────────── + response = await acompletion( + model=session.config.model_name, + messages=messages, + tools=tools, + tool_choice="auto", + stream=True, + stream_options={"include_usage": True}, + api_key=_INFERENCE_API_KEY + if _INFERENCE_API_KEY + and session.config.model_name.startswith("huggingface/") + else None, + ) + + full_content = "" + tool_calls_acc: dict[int, dict] = {} + token_count = 0 + + async for chunk in response: + choice = chunk.choices[0] if chunk.choices else None + if not choice: + # Last chunk may carry only usage info + if hasattr(chunk, "usage") and chunk.usage: + token_count = chunk.usage.total_tokens + continue + + delta = choice.delta + + # Stream text deltas to the frontend + if delta.content: + full_content += delta.content await session.send_event( - Event(event_type="assistant_stream_end", data={}) - ) - await session.send_event( - Event( - event_type="tool_log", - data={ - "tool": "system", - "log": f"Output truncated β€” retrying with smaller content ({dropped_names})", - }, + Event( + event_type="assistant_chunk", + data={"content": delta.content}, + ) ) - ) - iteration += 1 - continue # retry this iteration + + # Accumulate tool-call deltas (name + args arrive in pieces) + if delta.tool_calls: + for tc_delta in delta.tool_calls: + idx = tc_delta.index + if idx not in tool_calls_acc: + tool_calls_acc[idx] = { + "id": "", + "type": "function", + "function": {"name": "", "arguments": ""}, + } + if tc_delta.id: + tool_calls_acc[idx]["id"] = tc_delta.id + if tc_delta.function: + if tc_delta.function.name: + tool_calls_acc[idx]["function"]["name"] += ( + tc_delta.function.name + ) + if tc_delta.function.arguments: + tool_calls_acc[idx]["function"]["arguments"] += ( + tc_delta.function.arguments + ) + + # Capture usage from the final chunk + if hasattr(chunk, "usage") and chunk.usage: + token_count = chunk.usage.total_tokens + + # ── Stream finished β€” reconstruct full message ─────── + content = full_content or None # Build tool_calls list from accumulated deltas tool_calls: list[ToolCall] = [] @@ -1294,155 +231,63 @@ class Handlers: ) # Signal end of streaming to the frontend - if session.stream: - await session.send_event( - Event(event_type="assistant_stream_end", data={}) - ) + await session.send_event( + Event(event_type="assistant_stream_end", data={}) + ) # If no tool calls, add assistant message and we're done if not tool_calls: - logger.debug( - "Agent loop ending: no tool calls. " - "finish_reason=%s, token_count=%d, " - "usage=%d, model_max_tokens=%d, " - "iteration=%d/%d, " - "response_text=%s", - finish_reason, - token_count, - session.context_manager.running_context_usage, - session.context_manager.model_max_tokens, - iteration, - max_iterations, - (content or "")[:500], - ) if content: - assistant_msg = _assistant_message_from_result( - llm_result, - model_name=llm_params.get("model"), - ) + assistant_msg = Message(role="assistant", content=content) session.context_manager.add_message(assistant_msg, token_count) final_response = content break - # Validate tool call args (one json.loads per call, once) - # and split into good vs bad - good_tools: list[tuple[ToolCall, str, dict]] = [] - bad_tools: list[ToolCall] = [] - for tc in tool_calls: - try: - args = json.loads(tc.function.arguments) - good_tools.append((tc, tc.function.name, args)) - except (json.JSONDecodeError, TypeError, ValueError): - logger.warning( - "Malformed arguments for tool_call %s (%s) β€” skipping", - tc.id, - tc.function.name, - ) - tc.function.arguments = "{}" - bad_tools.append(tc) - - # Add assistant message with all tool calls to context - assistant_msg = _assistant_message_from_result( - llm_result, - model_name=llm_params.get("model"), + # Add assistant message with tool calls to history + assistant_msg = Message( + role="assistant", + content=content, tool_calls=tool_calls, ) session.context_manager.add_message(assistant_msg, token_count) - # Add error results for bad tool calls so the LLM - # knows what happened and can retry differently - for tc in bad_tools: - error_msg = ( - f"ERROR: Tool call to '{tc.function.name}' had malformed JSON " - f"arguments and was NOT executed. Retry with smaller content β€” " - f"for 'write', split into multiple smaller writes using 'edit'." - ) - session.context_manager.add_message( - Message( - role="tool", - content=error_msg, - tool_call_id=tc.id, - name=tc.function.name, - ) - ) - await session.send_event( - Event( - event_type="tool_call", - data={ - "tool": tc.function.name, - "arguments": {}, - "tool_call_id": tc.id, - }, - ) - ) - await session.send_event( - Event( - event_type="tool_output", - data={ - "tool": tc.function.name, - "tool_call_id": tc.id, - "output": error_msg, - "success": False, - }, - ) - ) + # Separate tools into those requiring approval and those that don't + approval_required_tools = [] + non_approval_tools = [] - # ── Cancellation check: before tool execution ── - if session.is_cancelled: - break + for tc in tool_calls: + tool_name = tc.function.name + try: + tool_args = json.loads(tc.function.arguments) + except (json.JSONDecodeError, TypeError) as e: + logger.warning(f"Malformed tool arguments for {tool_name}: {e}") + tool_args = {} - # Separate good tools into approval-required vs auto-execute. - # Track reserved spend while classifying a batch so two - # auto-approved jobs in one model response cannot jointly - # exceed the remaining session cap. - approval_required_tools: list[ - tuple[ToolCall, str, dict, ApprovalDecision] - ] = [] - non_approval_tools: list[ - tuple[ToolCall, str, dict, ApprovalDecision] - ] = [] - reserved_auto_spend_usd = 0.0 - for tc, tool_name, tool_args in good_tools: - decision = await _approval_decision( - tool_name, - tool_args, - session, - reserved_spend_usd=reserved_auto_spend_usd, - ) - if decision.requires_approval: - approval_required_tools.append( - (tc, tool_name, tool_args, decision) - ) + if _needs_approval(tool_name, tool_args, session.config): + approval_required_tools.append(tc) else: - non_approval_tools.append((tc, tool_name, tool_args, decision)) - if ( - decision.auto_approved - and decision.billable - and decision.estimated_cost_usd is not None - ): - reserved_auto_spend_usd += decision.estimated_cost_usd + non_approval_tools.append(tc) # Execute non-approval tools (in parallel when possible) if non_approval_tools: - # 1. Validate args upfront + # 1. Parse args and validate upfront parsed_tools: list[ - tuple[ToolCall, str, dict, ApprovalDecision, bool, str] + tuple[ChatCompletionMessageToolCall, str, dict, bool, str] ] = [] - for tc, tool_name, tool_args, decision in non_approval_tools: + for tc in non_approval_tools: + tool_name = tc.function.name + try: + tool_args = json.loads(tc.function.arguments) + except (json.JSONDecodeError, TypeError): + tool_args = {} + args_valid, error_msg = _validate_tool_args(tool_args) parsed_tools.append( - (tc, tool_name, tool_args, decision, args_valid, error_msg) + (tc, tool_name, tool_args, args_valid, error_msg) ) # 2. Send all tool_call events upfront (so frontend shows them all) - for ( - tc, - tool_name, - tool_args, - _decision, - args_valid, - _, - ) in parsed_tools: + for tc, tool_name, tool_args, args_valid, _ in parsed_tools: if args_valid: await session.send_event( Event( @@ -1455,64 +300,28 @@ class Handlers: ) ) - # 3. Execute all valid tools in parallel, cancellable + # 3. Execute all valid tools in parallel async def _exec_tool( - tc: ToolCall, + tc: ChatCompletionMessageToolCall, name: str, args: dict, - decision: ApprovalDecision, valid: bool, err: str, - ) -> tuple[ToolCall, str, dict, str, bool]: + ) -> tuple[ChatCompletionMessageToolCall, str, dict, str, bool]: if not valid: return (tc, name, args, err, False) - if decision.billable: - _record_estimated_spend(session, decision) out, ok = await session.tool_router.call_tool( - name, args, session=session, tool_call_id=tc.id + name, args, session=session ) return (tc, name, args, out, ok) - gather_task = asyncio.ensure_future( - asyncio.gather( - *[ - _exec_tool(tc, name, args, decision, valid, err) - for tc, name, args, decision, valid, err in parsed_tools - ] - ) - ) - cancel_task = asyncio.ensure_future(session._cancelled.wait()) - - done, _ = await asyncio.wait( - [gather_task, cancel_task], - return_when=asyncio.FIRST_COMPLETED, + results = await asyncio.gather( + *[ + _exec_tool(tc, name, args, valid, err) + for tc, name, args, valid, err in parsed_tools + ] ) - if cancel_task in done: - gather_task.cancel() - try: - await gather_task - except asyncio.CancelledError: - pass - # Notify frontend that in-flight tools were cancelled - for tc, name, _args, _decision, valid, _ in parsed_tools: - if valid: - await session.send_event( - Event( - event_type="tool_state_change", - data={ - "tool_call_id": tc.id, - "tool": name, - "state": "cancelled", - }, - ) - ) - await _cleanup_on_cancel(session) - break - - cancel_task.cancel() - results = gather_task.result() - # 4. Record results and send outputs (order preserved) for tc, tool_name, tool_args, output, success in results: tool_msg = Message( @@ -1539,60 +348,33 @@ class Handlers: if approval_required_tools: # Prepare batch approval data tools_data = [] - blocked_payloads = [] - for tc, tool_name, tool_args, decision in approval_required_tools: - # Resolve sandbox file paths for hf_jobs scripts so the - # frontend can display & edit the actual file content. - if tool_name == "hf_jobs" and isinstance( - tool_args.get("script"), str - ): - from agent.tools.sandbox_tool import resolve_sandbox_script - - sandbox = getattr(session, "sandbox", None) - resolved, _ = await resolve_sandbox_script( - sandbox, tool_args["script"] - ) - if resolved: - tool_args = {**tool_args, "script": resolved} - - tool_payload = { - "tool": tool_name, - "arguments": tool_args, - "tool_call_id": tc.id, - } - if decision.auto_approval_blocked: - tool_payload.update( - { - "auto_approval_blocked": True, - "block_reason": decision.block_reason, - "estimated_cost_usd": decision.estimated_cost_usd, - "remaining_cap_usd": decision.remaining_cap_usd, - } - ) - blocked_payloads.append(tool_payload) - tools_data.append(tool_payload) - - event_data = {"tools": tools_data, "count": len(tools_data)} - if blocked_payloads: - first = blocked_payloads[0] - event_data.update( + for tc in approval_required_tools: + tool_name = tc.function.name + try: + tool_args = json.loads(tc.function.arguments) + except (json.JSONDecodeError, TypeError): + tool_args = {} + tools_data.append( { - "auto_approval_blocked": True, - "block_reason": first.get("block_reason"), - "estimated_cost_usd": first.get("estimated_cost_usd"), - "remaining_cap_usd": first.get("remaining_cap_usd"), + "tool": tool_name, + "arguments": tool_args, + "tool_call_id": tc.id, } ) + await session.send_event( Event( event_type="approval_required", - data=event_data, + data={ + "tools": tools_data, # Batch of tools + "count": len(tools_data), + }, ) ) - # Store all approval-requiring tools (ToolCall objects for execution) + # Store all approval-requiring tools session.pending_approval = { - "tool_calls": [tc for tc, _, _, _ in approval_required_tools], + "tool_calls": approval_required_tools, } # Return early - wait for EXEC_APPROVAL operation @@ -1600,59 +382,36 @@ class Handlers: iteration += 1 - except ContextWindowExceededError: - # Force compact and retry this iteration. - cm = session.context_manager - logger.warning( - "ContextWindowExceededError at iteration %d β€” forcing compaction " - "(usage=%d, model_max_tokens=%d, messages=%d)", - iteration, - cm.running_context_usage, - cm.model_max_tokens, - len(cm.items), - ) - cm.running_context_usage = cm.model_max_tokens + 1 - await _compact_and_notify(session) - # Same guard as the top of the loop: if compaction couldn't - # bring us under threshold, _compact_and_notify has already - # emitted session_terminated and set is_running=False. Continue - # would just re-call the LLM with the same too-big context. - if not session.is_running: - break - continue - except Exception as e: import traceback - error_msg = _friendly_error_message(e) - if error_msg is None: - error_msg = str(e) + "\n" + traceback.format_exc() - await session.send_event( Event( event_type="error", - data={"error": error_msg}, + data={"error": str(e) + "\n" + traceback.format_exc()}, ) ) - errored = True break - if session.is_cancelled: - await _cleanup_on_cancel(session) - await session.send_event(Event(event_type="interrupted")) - elif not errored: + old_length = session.context_manager.context_length + await session.context_manager.compact(model_name=session.config.model_name) + new_length = session.context_manager.context_length + + if new_length != old_length: await session.send_event( Event( - event_type="turn_complete", - data={ - "history_size": len(session.context_manager.items), - "final_response": final_response - if isinstance(final_response, str) - else None, - }, + event_type="compacted", + data={"old_tokens": old_length, "new_tokens": new_length}, ) ) + await session.send_event( + Event( + event_type="turn_complete", + data={"history_size": len(session.context_manager.items)}, + ) + ) + # Increment turn counter and check for auto-save session.increment_turn() await session.auto_save_if_needed() @@ -1660,26 +419,50 @@ class Handlers: return final_response @staticmethod - async def undo(session: Session) -> None: - """Remove the last complete turn and notify the frontend.""" - removed = session.context_manager.undo_last_turn() - if not removed: - logger.warning("Undo: no user message found to remove") - await session.send_event(Event(event_type="undo_complete")) + async def interrupt(session: Session) -> None: + """Handle interrupt (like interrupt in codex.rs:1266)""" + session.interrupt() + await session.send_event(Event(event_type="interrupted")) @staticmethod - async def resume(session: Session, path: str) -> None: - """Reload context from a saved session log into the active session.""" - from agent.core.session_resume import restore_session_from_log + async def compact(session: Session) -> None: + """Handle compact (like compact in codex.rs:1317)""" + old_length = session.context_manager.context_length + await session.context_manager.compact(model_name=session.config.model_name) + new_length = session.context_manager.context_length - try: - result = restore_session_from_log(session, Path(path)) - except Exception as e: - await session.send_event( - Event(event_type="error", data={"error": f"Resume failed: {e}"}) + await session.send_event( + Event( + event_type="compacted", + data={"removed": old_length, "remaining": new_length}, ) + ) + + @staticmethod + async def undo(session: Session) -> None: + """Remove the last complete turn (user msg + all assistant/tool msgs that follow). + + Anthropic requires every tool_use to have a matching tool_result, + so we can't just pop 2 items β€” we must pop everything back to + (and including) the last user message to keep the history valid. + """ + items = session.context_manager.items + if not items: + await session.send_event(Event(event_type="undo_complete")) return - await session.send_event(Event(event_type="resume_complete", data=result)) + + # Pop from the end until we've removed the last user message + removed_user = False + while items: + msg = items.pop() + if getattr(msg, "role", None) == "user": + removed_user = True + break + + if not removed_user: + logger.warning("Undo: no user message found to remove") + + await session.send_event(Event(event_type="undo_complete")) @staticmethod async def exec_approval(session: Session, approvals: list[dict]) -> None: @@ -1705,11 +488,6 @@ class Handlers: # Create a map of tool_call_id -> approval decision approval_map = {a["tool_call_id"]: a for a in approvals} - for a in approvals: - if a.get("edited_script"): - logger.info( - f"Received edited script for tool_call {a['tool_call_id']} ({len(a['edited_script'])} chars)" - ) # Separate approved and rejected tool calls approved_tasks = [] @@ -1717,146 +495,43 @@ class Handlers: for tc in tool_calls: tool_name = tc.function.name - try: - tool_args = json.loads(tc.function.arguments) - except (json.JSONDecodeError, TypeError) as e: - # Malformed arguments β€” treat as failed, notify agent - logger.warning(f"Malformed tool arguments for {tool_name}: {e}") - tool_msg = Message( - role="tool", - content=f"Malformed arguments: {e}", - tool_call_id=tc.id, - name=tool_name, - ) - session.context_manager.add_message(tool_msg) - await session.send_event( - Event( - event_type="tool_output", - data={ - "tool": tool_name, - "tool_call_id": tc.id, - "output": f"Malformed arguments: {e}", - "success": False, - }, - ) - ) - continue - + tool_args = json.loads(tc.function.arguments) approval_decision = approval_map.get(tc.id, {"approved": False}) if approval_decision.get("approved", False): - edited_script = approval_decision.get("edited_script") - was_edited = False - if edited_script and "script" in tool_args: - tool_args["script"] = edited_script - was_edited = True - logger.info(f"Using user-edited script for {tool_name} ({tc.id})") - selected_namespace = approval_decision.get("namespace") - if selected_namespace and tool_name == "hf_jobs": - tool_args["namespace"] = selected_namespace - approved_tasks.append((tc, tool_name, tool_args, was_edited)) + approved_tasks.append((tc, tool_name, tool_args)) else: rejected_tasks.append((tc, tool_name, approval_decision)) - # Clear pending approval immediately so a page refresh during - # execution won't re-show the approval dialog. - session.pending_approval = None - - # Notify frontend of approval decisions immediately (before execution) - for tc, tool_name, tool_args, _was_edited in approved_tasks: - await session.send_event( - Event( - event_type="tool_state_change", - data={ - "tool_call_id": tc.id, - "tool": tool_name, - "state": "approved", - }, - ) - ) - for tc, tool_name, approval_decision in rejected_tasks: - await session.send_event( - Event( - event_type="tool_state_change", - data={ - "tool_call_id": tc.id, - "tool": tool_name, - "state": "rejected", - }, - ) - ) - # Execute all approved tools concurrently - async def execute_tool(tc, tool_name, tool_args, was_edited): - """Execute a single tool and return its result. - - The TraceLog already exists on the frontend (created by - approval_required), so we send tool_state_change instead of - tool_call to avoid creating a duplicate. - """ + async def execute_tool(tc, tool_name, tool_args): + """Execute a single tool and return its result""" await session.send_event( Event( - event_type="tool_state_change", + event_type="tool_call", data={ - "tool_call_id": tc.id, "tool": tool_name, - "state": "running", + "arguments": tool_args, + "tool_call_id": tc.id, }, ) ) - await _record_manual_approved_spend_if_needed(session, tool_name, tool_args) - output, success = await session.tool_router.call_tool( - tool_name, tool_args, session=session, tool_call_id=tc.id + tool_name, tool_args, session=session ) - return (tc, tool_name, output, success, was_edited) + return (tc, tool_name, output, success) - # Execute all approved tools concurrently (cancellable) + # Execute all approved tools concurrently and wait for ALL to complete if approved_tasks: - gather_task = asyncio.ensure_future( - asyncio.gather( - *[ - execute_tool(tc, tool_name, tool_args, was_edited) - for tc, tool_name, tool_args, was_edited in approved_tasks - ], - return_exceptions=True, - ) + results = await asyncio.gather( + *[ + execute_tool(tc, tool_name, tool_args) + for tc, tool_name, tool_args in approved_tasks + ], + return_exceptions=True, ) - cancel_task = asyncio.ensure_future(session._cancelled.wait()) - - done, _ = await asyncio.wait( - [gather_task, cancel_task], - return_when=asyncio.FIRST_COMPLETED, - ) - - if cancel_task in done: - gather_task.cancel() - try: - await gather_task - except asyncio.CancelledError: - pass - # Notify frontend that approved tools were cancelled - for tc, tool_name, _args, _was_edited in approved_tasks: - await session.send_event( - Event( - event_type="tool_state_change", - data={ - "tool_call_id": tc.id, - "tool": tool_name, - "state": "cancelled", - }, - ) - ) - await _cleanup_on_cancel(session) - await session.send_event(Event(event_type="interrupted")) - session.increment_turn() - await session.auto_save_if_needed() - return - - cancel_task.cancel() - results = gather_task.result() # Process results and add to context for result in results: @@ -1865,10 +540,7 @@ class Handlers: logger.error(f"Tool execution error: {result}") continue - tc, tool_name, output, success, was_edited = result - - if was_edited: - output = f"[Note: The user edited the script before execution. The output below reflects the user-modified version, not your original script.]\n\n{output}" + tc, tool_name, output, success = result # Add tool result to context tool_msg = Message( @@ -1896,16 +568,7 @@ class Handlers: rejection_msg = "Job execution cancelled by user" user_feedback = approval_decision.get("feedback") if user_feedback: - # Ensure feedback is a string and sanitize any problematic characters - feedback_str = str(user_feedback).strip() - # Remove any control characters that might break JSON parsing - feedback_str = "".join( - char for char in feedback_str if ord(char) >= 32 or char in "\n\t" - ) - rejection_msg += f". User feedback: {feedback_str}" - - # Ensure rejection_msg is a clean string - rejection_msg = str(rejection_msg).strip() + rejection_msg += f". User feedback: {user_feedback}" tool_msg = Message( role="tool", @@ -1927,6 +590,9 @@ class Handlers: ) ) + # Clear pending approval + session.pending_approval = None + # Continue agent loop with empty input to process the tool results await Handlers.run_agent(session, "") @@ -1959,24 +625,18 @@ async def process_submission(session: Session, submission) -> bool: await Handlers.run_agent(session, text) return True + if op.op_type == OpType.INTERRUPT: + await Handlers.interrupt(session) + return True + if op.op_type == OpType.COMPACT: - await _compact_and_notify(session) + await Handlers.compact(session) return True if op.op_type == OpType.UNDO: await Handlers.undo(session) return True - if op.op_type == OpType.RESUME: - path = op.data.get("path") if op.data else None - if path: - await Handlers.resume(session, path) - else: - await session.send_event( - Event(event_type="error", data={"error": "Resume requires a path"}) - ) - return True - if op.op_type == OpType.EXEC_APPROVAL: approvals = op.data.get("approvals", []) if op.data else [] await Handlers.exec_approval(session, approvals) @@ -1989,19 +649,12 @@ async def process_submission(session: Session, submission) -> bool: return True +@observe(name="submission_loop") async def submission_loop( submission_queue: asyncio.Queue, event_queue: asyncio.Queue, - config: Config, + config: Config | None = None, tool_router: ToolRouter | None = None, - session_holder: list | None = None, - hf_token: str | None = None, - user_id: str | None = None, - local_mode: bool = False, - stream: bool = True, - notification_gateway: NotificationGateway | None = None, - notification_destinations: list[str] | None = None, - defer_turn_complete_notification: bool = False, ) -> None: """ Main agent loop - processes submissions and dispatches to handlers. @@ -2009,30 +662,13 @@ async def submission_loop( """ # Create session with tool router - session = Session( - event_queue, - config=config, - tool_router=tool_router, - hf_token=hf_token, - user_id=user_id, - local_mode=local_mode, - stream=stream, - notification_gateway=notification_gateway, - notification_destinations=notification_destinations, - defer_turn_complete_notification=defer_turn_complete_notification, - ) - if session_holder is not None: - session_holder[0] = session + session = Session(event_queue, config=config, tool_router=tool_router) logger.info("Agent loop started") - # Retry any failed uploads from previous sessions (fire-and-forget). - # Includes the personal trace repo when enabled so a session that failed - # to publish to the user's HF dataset gets a fresh attempt on next run. + # Retry any failed uploads from previous sessions (fire-and-forget) if config and config.save_sessions: Session.retry_failed_uploads_detached( - directory=str(DEFAULT_SESSION_LOG_DIR), - repo_id=config.session_dataset_repo, - personal_repo_id=session._personal_trace_repo_id(), + directory="session_logs", repo_id=config.session_dataset_repo ) try: @@ -2040,13 +676,7 @@ async def submission_loop( async with tool_router: # Emit ready event after initialization await session.send_event( - Event( - event_type="ready", - data={ - "message": "Agent initialized", - "tool_count": len(tool_router.tools), - }, - ) + Event(event_type="ready", data={"message": "Agent initialized"}) ) while session.is_running: diff --git a/agent/core/approval_policy.py b/agent/core/approval_policy.py deleted file mode 100644 index 73098ca61dffca66929984bd5b5c34e532106f18..0000000000000000000000000000000000000000 --- a/agent/core/approval_policy.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Shared predicates for approval-gated tool operations.""" - -from typing import Any - - -def normalize_tool_operation(operation: Any) -> str: - return str(operation or "").strip().lower() - - -def is_scheduled_operation(operation: Any) -> bool: - return normalize_tool_operation(operation).startswith("scheduled ") diff --git a/agent/core/cost_estimation.py b/agent/core/cost_estimation.py deleted file mode 100644 index a41ad196efec7495c7ca9141d2f7f3a4f38e6dbd..0000000000000000000000000000000000000000 --- a/agent/core/cost_estimation.py +++ /dev/null @@ -1,282 +0,0 @@ -"""Conservative cost estimates for auto-approved infrastructure actions.""" - -import os -import re -import time -from dataclasses import dataclass -from typing import Any - -import httpx - -OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co") -JOBS_HARDWARE_URL = f"{OPENID_PROVIDER_URL}/api/jobs/hardware" -JOBS_PRICE_CACHE_TTL_S = 6 * 60 * 60 - -DEFAULT_JOB_TIMEOUT_HOURS = 0.5 -DEFAULT_SANDBOX_RESERVATION_HOURS = 1.0 - -# Static fallback prices are intentionally conservative enough for a budget -# guard. The live /api/jobs/hardware catalog wins whenever it is reachable. -HF_JOBS_PRICE_USD_PER_HOUR: dict[str, float] = { - "cpu-basic": 0.05, - "cpu-upgrade": 0.25, - "cpu-performance": 0.50, - "cpu-xl": 1.00, - "t4-small": 0.60, - "t4-medium": 0.90, - "l4x1": 1.00, - "l4x4": 4.00, - "l40sx1": 2.00, - "l40sx4": 8.00, - "l40sx8": 16.00, - "a10g-small": 1.00, - "a10g-large": 2.00, - "a10g-largex2": 4.00, - "a10g-largex4": 8.00, - "a100-large": 4.00, - "a100x4": 16.00, - "a100x8": 32.00, - "h200": 10.00, - "h200x2": 20.00, - "h200x4": 40.00, - "h200x8": 80.00, - "inf2x6": 6.00, -} - -SPACE_PRICE_USD_PER_HOUR: dict[str, float] = { - "cpu-basic": 0.0, - "cpu-upgrade": 0.05, - "cpu-performance": 0.50, - "cpu-xl": 1.00, - "t4-small": 0.60, - "t4-medium": 0.90, - "l4x1": 1.00, - "l4x4": 4.00, - "l40sx1": 2.00, - "l40sx4": 8.00, - "l40sx8": 16.00, - "a10g-small": 1.00, - "a10g-large": 2.00, - "a10g-largex2": 4.00, - "a10g-largex4": 8.00, - "a100-large": 4.00, - "a100x4": 16.00, - "a100x8": 32.00, - "h200": 10.00, - "h200x2": 20.00, - "h200x4": 40.00, - "h200x8": 80.00, - "inf2x6": 6.00, -} - -_DURATION_RE = re.compile(r"^\s*(\d+(?:\.\d+)?)\s*([smhd]?)\s*$", re.IGNORECASE) -_PRICE_RE = re.compile(r"(\d+(?:\.\d+)?)") -_jobs_price_cache: tuple[float, dict[str, float]] | None = None - - -@dataclass(frozen=True) -class CostEstimate: - """Estimated cost for a tool call. - - ``estimated_cost_usd=None`` means the call may be billable but we could not - estimate it safely, so auto-approval should fall back to a human decision. - """ - - estimated_cost_usd: float | None - billable: bool - block_reason: str | None = None - label: str | None = None - - -def parse_timeout_hours( - value: Any, *, default_hours: float = DEFAULT_JOB_TIMEOUT_HOURS -) -> float | None: - """Parse HF timeout values into hours. - - Strings accept ``s``, ``m``, ``h``, or ``d`` suffixes. Numeric values are - treated as seconds, matching the Hub client's typed timeout parameter. - """ - if value is None or value == "": - return default_hours - if isinstance(value, bool): - return None - if isinstance(value, int | float): - seconds = float(value) - return seconds / 3600 if seconds > 0 else None - if not isinstance(value, str): - return None - - match = _DURATION_RE.match(value) - if not match: - return None - amount = float(match.group(1)) - unit = match.group(2).lower() or "s" - if amount <= 0: - return None - if unit == "s": - return amount / 3600 - if unit == "m": - return amount / 60 - if unit == "h": - return amount - if unit == "d": - return amount * 24 - return None - - -def _extract_flavor(item: dict[str, Any]) -> str | None: - for key in ("flavor", "name", "id", "value", "hardware", "hardware_flavor"): - value = item.get(key) - if isinstance(value, str) and value: - return value - return None - - -def _coerce_price(value: Any) -> float | None: - if isinstance(value, bool) or value is None: - return None - if isinstance(value, int | float): - return float(value) if value >= 0 else None - if isinstance(value, str): - match = _PRICE_RE.search(value.replace(",", "")) - if match: - return float(match.group(1)) - return None - - -def _extract_hourly_price(item: dict[str, Any]) -> float | None: - for key in ( - "price", - "price_usd", - "priceUsd", - "price_per_hour", - "pricePerHour", - "hourly_price", - "hourlyPrice", - "usd_per_hour", - "usdPerHour", - ): - price = _coerce_price(item.get(key)) - if price is not None: - return price - for key in ("pricing", "billing", "cost"): - nested = item.get(key) - if isinstance(nested, dict): - price = _extract_hourly_price(nested) - if price is not None: - return price - return None - - -def _iter_hardware_items(payload: Any): - if isinstance(payload, list): - for item in payload: - yield from _iter_hardware_items(item) - elif isinstance(payload, dict): - if _extract_flavor(payload): - yield payload - for key in ("hardware", "flavors", "items", "data", "jobs"): - child = payload.get(key) - if child is not None: - yield from _iter_hardware_items(child) - - -def _parse_jobs_price_catalog(payload: Any) -> dict[str, float]: - prices: dict[str, float] = {} - for item in _iter_hardware_items(payload): - flavor = _extract_flavor(item) - price = _extract_hourly_price(item) - if flavor and price is not None: - prices[flavor] = price - return prices - - -async def hf_jobs_price_catalog() -> dict[str, float]: - """Return live HF Jobs hourly prices, falling back to static prices.""" - global _jobs_price_cache - now = time.monotonic() - if _jobs_price_cache and now - _jobs_price_cache[0] < JOBS_PRICE_CACHE_TTL_S: - return dict(_jobs_price_cache[1]) - - prices: dict[str, float] = {} - try: - async with httpx.AsyncClient(timeout=3.0) as client: - response = await client.get(JOBS_HARDWARE_URL) - if response.status_code == 200: - prices = _parse_jobs_price_catalog(response.json()) - except (httpx.HTTPError, ValueError): - prices = {} - - if not prices: - prices = dict(HF_JOBS_PRICE_USD_PER_HOUR) - else: - prices = {**HF_JOBS_PRICE_USD_PER_HOUR, **prices} - - _jobs_price_cache = (now, prices) - return dict(prices) - - -async def estimate_hf_job_cost(args: dict[str, Any]) -> CostEstimate: - flavor = str( - args.get("hardware_flavor") - or args.get("flavor") - or args.get("hardware") - or "cpu-basic" - ) - timeout_hours = parse_timeout_hours(args.get("timeout")) - if timeout_hours is None: - return CostEstimate( - estimated_cost_usd=None, - billable=True, - block_reason=f"Could not parse HF job timeout: {args.get('timeout')!r}.", - label=flavor, - ) - - prices = await hf_jobs_price_catalog() - price = prices.get(flavor) - if price is None: - return CostEstimate( - estimated_cost_usd=None, - billable=True, - block_reason=f"No price is available for HF job hardware '{flavor}'.", - label=flavor, - ) - - return CostEstimate( - estimated_cost_usd=round(price * timeout_hours, 4), - billable=price > 0, - label=flavor, - ) - - -async def estimate_sandbox_cost( - args: dict[str, Any], *, session: Any = None -) -> CostEstimate: - if session is not None and getattr(session, "sandbox", None): - return CostEstimate(estimated_cost_usd=0.0, billable=False, label="existing") - - hardware = str(args.get("hardware") or "cpu-basic") - price = SPACE_PRICE_USD_PER_HOUR.get(hardware) - if price is None: - return CostEstimate( - estimated_cost_usd=None, - billable=True, - block_reason=f"No price is available for sandbox hardware '{hardware}'.", - label=hardware, - ) - - return CostEstimate( - estimated_cost_usd=round(price * DEFAULT_SANDBOX_RESERVATION_HOURS, 4), - billable=price > 0, - label=hardware, - ) - - -async def estimate_tool_cost( - tool_name: str, args: dict[str, Any], *, session: Any = None -) -> CostEstimate: - if tool_name == "sandbox_create": - return await estimate_sandbox_cost(args, session=session) - if tool_name == "hf_jobs": - return await estimate_hf_job_cost(args) - return CostEstimate(estimated_cost_usd=0.0, billable=False) diff --git a/agent/core/doom_loop.py b/agent/core/doom_loop.py deleted file mode 100644 index 3b57fe2cc3cffd07b466db9ac98cc0d0b665de79..0000000000000000000000000000000000000000 --- a/agent/core/doom_loop.py +++ /dev/null @@ -1,190 +0,0 @@ -""" -Doom-loop detection for repeated tool call patterns. - -Detects when the agent is stuck calling the same tools repeatedly -and injects a corrective prompt to break the cycle. -""" - -import hashlib -import json -import logging -from dataclasses import dataclass - -from litellm import Message - -logger = logging.getLogger(__name__) - - -@dataclass(frozen=True) -class ToolCallSignature: - """Hashable signature for a single tool call plus its observed result.""" - - name: str - args_hash: str - result_hash: str | None = None - - -def _normalize_args(args_str: str) -> str: - """Canonicalise a tool-call arguments string before hashing. - - LLMs can emit semantically-identical JSON for the same call with different - key orderings (``{"a": 1, "b": 2}`` vs ``{"b": 2, "a": 1}``) or whitespace - (``{"a":1}`` vs ``{"a": 1}``). Hashing the raw bytes makes the doom-loop - detector miss those repeats. We parse-and-redump with ``sort_keys=True`` - plus the most compact separators so trivially-different spellings collapse - to the same canonical form. - - Falls back to the original string if the input isn't valid JSON (e.g. a - handful of providers occasionally pass a bare string for ``arguments``); - that path keeps the legacy behaviour and never raises. - """ - if not args_str: - return "" - try: - return json.dumps(json.loads(args_str), sort_keys=True, separators=(",", ":")) - except (json.JSONDecodeError, TypeError, ValueError): - return args_str - - -def _hash_args(args_str: str) -> str: - """Return a short hash of the JSON arguments string. - - The input is normalised via :func:`_normalize_args` first so that - semantically-identical tool calls produce the same hash regardless of key - order or whitespace. - """ - return hashlib.md5(_normalize_args(args_str).encode()).hexdigest()[:12] - - -def extract_recent_tool_signatures( - messages: list[Message], lookback: int = 30 -) -> list[ToolCallSignature]: - """Extract tool call signatures from recent assistant messages. - - Includes the immediate tool result hash when present. This prevents - legitimate polling from being classified as a doom loop when the poll - arguments stay constant but the observed result keeps changing. - """ - signatures: list[ToolCallSignature] = [] - recent = messages[-lookback:] if len(messages) > lookback else messages - - for idx, msg in enumerate(recent): - if getattr(msg, "role", None) != "assistant": - continue - tool_calls = getattr(msg, "tool_calls", None) - if not tool_calls: - continue - for tc in tool_calls: - fn = getattr(tc, "function", None) - if not fn: - continue - name = getattr(fn, "name", "") or "" - args_str = getattr(fn, "arguments", "") or "" - result_hash = None - for follow in recent[idx + 1 :]: - role = getattr(follow, "role", None) - if role == "tool" and getattr(follow, "tool_call_id", None) == getattr( - tc, "id", None - ): - result_hash = _hash_args(str(getattr(follow, "content", "") or "")) - break - if role in {"assistant", "user"}: - break - signatures.append( - ToolCallSignature( - name=name, - args_hash=_hash_args(args_str), - result_hash=result_hash, - ) - ) - - return signatures - - -def detect_identical_consecutive( - signatures: list[ToolCallSignature], threshold: int = 3 -) -> str | None: - """Return the tool name if threshold+ identical consecutive calls are found.""" - if len(signatures) < threshold: - return None - - count = 1 - for i in range(1, len(signatures)): - if signatures[i] == signatures[i - 1]: - count += 1 - if count >= threshold: - return signatures[i].name - else: - count = 1 - - return None - - -def detect_repeating_sequence( - signatures: list[ToolCallSignature], -) -> list[ToolCallSignature] | None: - """Detect repeating patterns like [A,B,A,B] for sequences of length 2-5 with 2+ reps.""" - n = len(signatures) - for seq_len in range(2, 6): - min_required = seq_len * 2 - if n < min_required: - continue - - # Check the tail of the signatures list - tail = signatures[-min_required:] - pattern = tail[:seq_len] - - # Count how many full repetitions from the end - reps = 0 - for start in range(n - seq_len, -1, -seq_len): - chunk = signatures[start : start + seq_len] - if chunk == pattern: - reps += 1 - else: - break - - if reps >= 2: - return pattern - - return None - - -def check_for_doom_loop(messages: list[Message]) -> str | None: - """Check for doom loop patterns. Returns a corrective prompt or None.""" - signatures = extract_recent_tool_signatures(messages, lookback=30) - if len(signatures) < 3: - return None - - # Check for identical consecutive calls - tool_name = detect_identical_consecutive(signatures, threshold=3) - if tool_name: - logger.warning( - "Repetition guard activated: %d+ identical consecutive calls to '%s'", - 3, - tool_name, - ) - return ( - f"[SYSTEM: REPETITION GUARD] You have called '{tool_name}' with the same " - f"arguments multiple times in a row, getting the same result each time. " - f"STOP repeating this approach β€” it is not working. " - f"Step back and try a fundamentally different strategy. " - f"Consider: using a different tool, changing your arguments significantly, " - f"or explaining to the user what you're stuck on and asking for guidance." - ) - - # Check for repeating sequences - pattern = detect_repeating_sequence(signatures) - if pattern: - pattern_desc = " β†’ ".join(s.name for s in pattern) - logger.warning( - "Repetition guard activated: repeating sequence [%s]", pattern_desc - ) - return ( - f"[SYSTEM: REPETITION GUARD] You are stuck in a repeating cycle of tool calls: " - f"[{pattern_desc}]. This pattern has repeated multiple times without progress. " - f"STOP this cycle and try a fundamentally different approach. " - f"Consider: breaking down the problem differently, using alternative tools, " - f"or explaining to the user what you're stuck on and asking for guidance." - ) - - return None diff --git a/agent/core/effort_probe.py b/agent/core/effort_probe.py deleted file mode 100644 index dbad4c3da95e939ec9d2dae5c6c7408bcc6ea156..0000000000000000000000000000000000000000 --- a/agent/core/effort_probe.py +++ /dev/null @@ -1,284 +0,0 @@ -"""Probe-and-cascade for reasoning effort on /model switch. - -We don't maintain a per-model capability table. Instead, the first time a -user picks a model we fire a 1-token ping with the same params we'd use -for real and walk down a cascade (``max`` β†’ ``xhigh`` β†’ ``high`` β†’ …) -until the provider stops rejecting us. The result is cached per-model on -the session, so real messages don't pay the probe cost again. - -Three outcomes, classified from the 400 error text: - -* success β†’ cache the effort that worked -* ``"thinking ... not supported"`` β†’ model doesn't do thinking at all; - cache ``None`` so we stop sending thinking params -* ``"effort ... invalid"`` / synonyms β†’ cascade walks down and retries - -Transient errors (5xx, timeout, connection reset) bubble out as -``ProbeInconclusive`` so the caller can complete the switch with a -warning instead of blocking on a flaky provider. -""" - -from __future__ import annotations - -import asyncio -import logging -import time -from dataclasses import dataclass -from typing import Any - -from litellm import acompletion - -from agent.core.llm_params import UnsupportedEffortError, _resolve_llm_params - -logger = logging.getLogger(__name__) - - -# Cascade: for each user-stated preference, the ordered list of levels to -# try. First success wins. ``max`` is Anthropic-only; ``xhigh`` is also -# supported on current OpenAI GPT-5 models. Providers that don't accept a -# requested level raise ``UnsupportedEffortError`` synchronously (no wasted -# network round-trip) and we advance to the next level. -_EFFORT_CASCADE: dict[str, list[str]] = { - "max": ["max", "xhigh", "high", "medium", "low"], - "xhigh": ["xhigh", "high", "medium", "low"], - "high": ["high", "medium", "low"], - "medium": ["medium", "low"], - "minimal": ["minimal", "low"], - "low": ["low"], -} - -_PROBE_TIMEOUT = 15.0 -# Keep the probe cheap, but high enough that frontier reasoning models can -# finish a trivial reply instead of tripping a false "output limit reached" -# error during capability detection. -_PROBE_MAX_TOKENS = 64 - - -class ProbeInconclusive(Exception): - """The probe couldn't reach a verdict (transient network / provider error). - - Caller should complete the switch with a warning β€” the next real call - will re-surface the error if it's persistent. - """ - - -@dataclass -class ProbeOutcome: - """What the probe learned. ``effective_effort`` semantics match the cache: - - * str β†’ send this level - * None β†’ model doesn't support thinking; strip it - """ - - effective_effort: str | None - attempts: int - elapsed_ms: int - note: str | None = None # e.g. "max not supported, falling back" - - -def _is_thinking_unsupported(e: Exception) -> bool: - """Model rejected any thinking config. - - Matches Anthropic's 'thinking.type.enabled is not supported for this - model' as well as the adaptive variant. Substring-match because the - exact wording shifts across API versions. - """ - s = str(e).lower() - return "thinking" in s and "not supported" in s - - -def _is_invalid_effort(e: Exception) -> bool: - """The requested effort level isn't accepted for this model. - - Covers both API responses (Anthropic/OpenAI 400 with "invalid", "must - be one of", etc.) and LiteLLM's local validation that fires *before* - the request (e.g. "effort='max' is only supported by Claude Opus 4.6" - β€” LiteLLM knows max is Opus-4.6-only and raises synchronously). The - cascade walks down on either. - - Explicitly returns False when the message is really about thinking - itself (e.g. Anthropic's 4.7 error mentions ``output_config.effort`` - in its fix hint, but the actual failure is ``thinking.type.enabled`` - being unsupported). That case is caught by ``_is_thinking_unsupported``. - """ - if _is_thinking_unsupported(e): - return False - s = str(e).lower() - if "effort" not in s and "output_config" not in s: - return False - return any( - phrase in s - for phrase in ( - "invalid", - "not supported", - "must be one of", - "not a valid", - "unrecognized", - "unknown", - # LiteLLM's own pre-flight validation phrasing. - "only supported by", - "is only supported", - ) - ) - - -def _is_transient(e: Exception) -> bool: - """Network / provider-side flake. Keep in sync with agent_loop's list. - - Also matches by type for ``asyncio.TimeoutError`` β€” its ``str(e)`` is - empty, so substring matching alone misses it. - """ - if isinstance(e, (asyncio.TimeoutError, TimeoutError)): - return True - s = str(e).lower() - return any( - p in s - for p in ( - "timeout", - "timed out", - "429", - "rate limit", - "503", - "service unavailable", - "502", - "bad gateway", - "500", - "internal server error", - "overloaded", - "capacity", - "connection reset", - "connection refused", - "connection error", - "eof", - "broken pipe", - ) - ) - - -async def probe_effort( - model_name: str, - preference: str | None, - hf_token: str | None, - session: Any = None, -) -> ProbeOutcome: - """Walk the cascade for ``preference`` on ``model_name``. - - Returns the first effort the provider accepts, or ``None`` if it - rejects thinking altogether. Raises ``ProbeInconclusive`` only for - transient errors (5xx, timeout) β€” persistent 4xx that aren't thinking/ - effort related bubble as the original exception so callers can surface - them (auth, model-not-found, quota, etc.). - - ``session`` is optional; when provided, each successful probe attempt - is recorded via ``telemetry.record_llm_call(kind="effort_probe")`` so - the cost shows up in the session's ``total_cost_usd``. Failed probes - (rejected by the provider) typically aren't billed, so we only record - on success. - """ - loop = asyncio.get_event_loop() - start = loop.time() - attempts = 0 - - if not preference: - # User explicitly turned effort off β€” nothing to probe. A bare - # ping with no thinking params is pointless; just report "off". - return ProbeOutcome(effective_effort=None, attempts=0, elapsed_ms=0) - - cascade = _EFFORT_CASCADE.get(preference, [preference]) - skipped: list[str] = [] # levels the provider rejected synchronously - - last_error: Exception | None = None - for effort in cascade: - try: - params = _resolve_llm_params( - model_name, - hf_token, - reasoning_effort=effort, - strict=True, - ) - except UnsupportedEffortError: - # Provider can't even accept this effort name (e.g. "max" on - # HF router). Skip without a network call. - skipped.append(effort) - continue - - attempts += 1 - try: - _t0 = time.monotonic() - response = await asyncio.wait_for( - acompletion( - messages=[{"role": "user", "content": "ping"}], - max_tokens=_PROBE_MAX_TOKENS, - stream=False, - **params, - ), - timeout=_PROBE_TIMEOUT, - ) - if session is not None: - # Best-effort telemetry β€” never let a logging blip propagate - # out of the probe and break model switching. - try: - from agent.core import telemetry - - await telemetry.record_llm_call( - session, - model=model_name, - response=response, - latency_ms=int((time.monotonic() - _t0) * 1000), - finish_reason=response.choices[0].finish_reason - if response.choices - else None, - kind="effort_probe", - ) - except Exception as _telem_err: - logger.debug("effort_probe telemetry failed: %s", _telem_err) - except Exception as e: - last_error = e - if _is_thinking_unsupported(e): - elapsed = int((loop.time() - start) * 1000) - return ProbeOutcome( - effective_effort=None, - attempts=attempts, - elapsed_ms=elapsed, - note="model doesn't support reasoning, dropped", - ) - if _is_invalid_effort(e): - logger.debug( - "probe: %s rejected effort=%s, trying next", model_name, effort - ) - continue - if _is_transient(e): - raise ProbeInconclusive(str(e)) from e - # Persistent non-thinking 4xx (auth, quota, model-not-found) β€” - # let the caller classify & surface. - raise - else: - elapsed = int((loop.time() - start) * 1000) - note = None - if effort != preference: - note = f"{preference} not supported, using {effort}" - return ProbeOutcome( - effective_effort=effort, - attempts=attempts, - elapsed_ms=elapsed, - note=note, - ) - - # Cascade exhausted without a success. This only happens when every - # level was either rejected synchronously (``UnsupportedEffortError``, - # e.g. preference=max on HF and we also somehow filtered all others) - # or the provider 400'd ``invalid effort`` on every level. - elapsed = int((loop.time() - start) * 1000) - if last_error is not None and not _is_invalid_effort(last_error): - raise last_error - note = ( - "no effort level accepted β€” proceeding without thinking" - if not skipped - else f"provider rejected all efforts ({', '.join(skipped)})" - ) - return ProbeOutcome( - effective_effort=None, - attempts=attempts, - elapsed_ms=elapsed, - note=note, - ) diff --git a/agent/core/hf_access.py b/agent/core/hf_access.py deleted file mode 100644 index 254a9c73161df9be2866f7cd2574dc7701934f08..0000000000000000000000000000000000000000 --- a/agent/core/hf_access.py +++ /dev/null @@ -1,172 +0,0 @@ -"""Helpers for Hugging Face account / org access decisions. - -HF Jobs are gated by *credits*, not by HF Pro subscriptions. Any user who -has credits β€” on their personal account or on an org they belong to β€” can -launch jobs under that namespace. The picker UI lets the caller choose -which wallet to bill. -""" - -from __future__ import annotations - -import asyncio -import os -import re -from dataclasses import dataclass -from typing import Any - -import httpx - -OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co") - - -@dataclass(frozen=True) -class JobsAccess: - """Namespaces the caller may bill HF Jobs to.""" - - username: str | None - org_names: list[str] - eligible_namespaces: list[str] - default_namespace: str | None - access_known: bool = True - - -class JobsAccessError(Exception): - """Structured jobs-namespace error. - - ``namespace_required`` fires when the caller belongs to more than one - eligible namespace and the UI must prompt them to pick one. There is no - longer an ``upgrade_required`` state β€” Pro is irrelevant; HF Jobs are - gated on per-wallet credits, surfaced separately when the API returns - a billing error at job-creation time. - """ - - def __init__( - self, - message: str, - *, - access: JobsAccess | None = None, - namespace_required: bool = False, - ) -> None: - super().__init__(message) - self.access = access - self.namespace_required = namespace_required - - -def _extract_username(whoami: dict[str, Any]) -> str | None: - for key in ("name", "user", "preferred_username"): - value = whoami.get(key) - if isinstance(value, str) and value: - return value - return None - - -def _org_names(whoami: dict[str, Any]) -> list[str]: - """All orgs the caller belongs to. - - Plan/tier is ignored β€” credits live on the namespace itself, so any - org the user belongs to can host a job as long as it has credits. - """ - names: list[str] = [] - orgs = whoami.get("orgs") or [] - if not isinstance(orgs, list): - return names - for org in orgs: - if not isinstance(org, dict): - continue - name = org.get("name") - if isinstance(name, str) and name: - names.append(name) - return sorted(set(names)) - - -def jobs_access_from_whoami(whoami: dict[str, Any]) -> JobsAccess: - username = _extract_username(whoami) - org_names = _org_names(whoami) - eligible: list[str] = [] - if username: - eligible.append(username) - eligible.extend(org_names) - default = username if username else (org_names[0] if org_names else None) - return JobsAccess( - username=username, - org_names=org_names, - eligible_namespaces=eligible, - default_namespace=default, - ) - - -async def fetch_whoami_v2(token: str, timeout: float = 5.0) -> dict[str, Any] | None: - if not token: - return None - async with httpx.AsyncClient(timeout=timeout) as client: - try: - response = await client.get( - f"{OPENID_PROVIDER_URL}/api/whoami-v2", - headers={"Authorization": f"Bearer {token}"}, - ) - if response.status_code != 200: - return None - payload = response.json() - return payload if isinstance(payload, dict) else None - except (httpx.HTTPError, ValueError): - return None - - -async def get_jobs_access(token: str) -> JobsAccess | None: - whoami = await fetch_whoami_v2(token) - if whoami is None: - return None - return jobs_access_from_whoami(whoami) - - -async def resolve_jobs_namespace( - token: str, - requested_namespace: str | None = None, -) -> tuple[str, JobsAccess | None]: - """Return the namespace to use for jobs. - - If whoami-v2 is unavailable, fall back to the token owner's username. - """ - access = await get_jobs_access(token) - if access: - if requested_namespace: - if requested_namespace in access.eligible_namespaces: - return requested_namespace, access - raise JobsAccessError( - f"You can only run jobs under your own account or an org you belong to. " - f"Allowed namespaces: {', '.join(access.eligible_namespaces) or '(none)'}", - access=access, - ) - if access.default_namespace: - return access.default_namespace, access - raise JobsAccessError( - "Couldn't resolve a Hugging Face namespace for this token.", - access=access, - ) - - # Fallback: whoami-v2 unavailable. Don't block the call pre-emptively. - from huggingface_hub import HfApi - - username = None - if token: - whoami = await asyncio.to_thread(HfApi(token=token).whoami) - username = whoami.get("name") - if not username: - raise JobsAccessError("No HF token available to resolve a jobs namespace.") - return requested_namespace or username, None - - -_BILLING_PATTERNS = re.compile( - r"\b(insufficient[_\s-]?credits?|out\s+of\s+credits?|payment\s+required|" - r"billing|no\s+credits?|add\s+credits?|requires?\s+credits?)\b", - re.IGNORECASE, -) - - -def is_billing_error(message: str) -> bool: - """True if an HF API error message looks like an out-of-credits / billing error.""" - if not message: - return False - if "402" in message: - return True - return bool(_BILLING_PATTERNS.search(message)) diff --git a/agent/core/hf_router_catalog.py b/agent/core/hf_router_catalog.py deleted file mode 100644 index 625ccf4fb85498e229fe63dc0faac56628d0be39..0000000000000000000000000000000000000000 --- a/agent/core/hf_router_catalog.py +++ /dev/null @@ -1,131 +0,0 @@ -"""Fetch and cache the HF Inference Router model catalog. - -The router exposes an OpenAI-compatible listing at -``https://router.huggingface.co/v1/models`` with per-provider availability, -pricing, context length, and tool-use support. We use it to: - - β€’ Validate ``/model`` switches with live data instead of a hard-coded allowlist. - β€’ Show the user which providers serve a model, at what price, and whether they - support tool calls. - β€’ Derive a reasonable context-window limit for any routed model. - -The listing is cached in-memory for a few minutes so repeated lookups during a -session are free. On fetch failure we return stale data if we have it, or an -empty catalog otherwise. -""" - -import logging -import time -from dataclasses import dataclass -from difflib import get_close_matches -from typing import Optional - -import httpx - -logger = logging.getLogger(__name__) - -_CATALOG_URL = "https://router.huggingface.co/v1/models" -_CACHE_TTL_SECONDS = 300 -_HTTP_TIMEOUT_SECONDS = 5.0 - -_cache: Optional[dict] = None -_cache_time: float = 0.0 - - -@dataclass -class ProviderInfo: - provider: str - status: str - context_length: Optional[int] - input_price: Optional[float] - output_price: Optional[float] - supports_tools: bool - supports_structured_output: bool - - -@dataclass -class ModelInfo: - id: str - providers: list[ProviderInfo] - - @property - def live_providers(self) -> list[ProviderInfo]: - return [p for p in self.providers if p.status == "live"] - - @property - def max_context_length(self) -> Optional[int]: - lengths = [p.context_length for p in self.live_providers if p.context_length] - return max(lengths) if lengths else None - - @property - def any_supports_tools(self) -> bool: - return any(p.supports_tools for p in self.live_providers) - - -def _fetch_catalog(force: bool = False) -> dict: - global _cache, _cache_time - now = time.time() - if not force and _cache is not None and now - _cache_time < _CACHE_TTL_SECONDS: - return _cache - try: - resp = httpx.get(_CATALOG_URL, timeout=_HTTP_TIMEOUT_SECONDS) - resp.raise_for_status() - _cache = resp.json() - _cache_time = now - except Exception as e: - logger.warning("Failed to fetch HF router catalog: %s", e) - if _cache is None: - _cache = {"data": []} - _cache_time = now - return _cache - - -def _parse_entry(entry: dict) -> ModelInfo: - providers = [] - for p in entry.get("providers", []) or []: - pricing = p.get("pricing") or {} - providers.append( - ProviderInfo( - provider=p.get("provider", ""), - status=p.get("status", ""), - context_length=p.get("context_length"), - input_price=pricing.get("input"), - output_price=pricing.get("output"), - supports_tools=bool(p.get("supports_tools", False)), - supports_structured_output=bool( - p.get("supports_structured_output", False) - ), - ) - ) - return ModelInfo(id=entry.get("id", ""), providers=providers) - - -def lookup(model_id: str) -> Optional[ModelInfo]: - """Find a model in the router catalog. - - Accepts ``/`` or ``/:`` β€” the tag is stripped - for lookup. Returns ``None`` if the model isn't listed. - """ - bare = model_id.split(":", 1)[0] - catalog = _fetch_catalog() - for entry in catalog.get("data", []): - if entry.get("id") == bare: - return _parse_entry(entry) - return None - - -def fuzzy_suggest(model_id: str, limit: int = 3) -> list[str]: - """Return the closest model ids from the catalog.""" - bare = model_id.split(":", 1)[0] - catalog = _fetch_catalog() - ids = [e.get("id", "") for e in catalog.get("data", []) if e.get("id")] - return get_close_matches(bare, ids, n=limit, cutoff=0.4) - - -def prewarm() -> None: - """Fetch the catalog so subsequent lookups are instant. Safe to call from - a background task β€” swallows failures.""" - try: - _fetch_catalog(force=False) - except Exception: - pass diff --git a/agent/core/hf_tokens.py b/agent/core/hf_tokens.py deleted file mode 100644 index 3e72ccc128a9d9aaecb661c4c2ba3850a10b5dc0..0000000000000000000000000000000000000000 --- a/agent/core/hf_tokens.py +++ /dev/null @@ -1,85 +0,0 @@ -"""Hugging Face token resolution helpers.""" - -from __future__ import annotations - -import os -from typing import Any - - -def clean_hf_token(token: str | None) -> str | None: - """Normalize token strings the same way huggingface_hub does.""" - if token is None: - return None - return token.replace("\r", "").replace("\n", "").strip() or None - - -def get_cached_hf_token() -> str | None: - """Return the token from huggingface_hub's normal env/cache lookup.""" - try: - from huggingface_hub import get_token - - return get_token() - except Exception: - return None - - -def resolve_hf_token( - *candidates: str | None, - include_cached: bool = True, -) -> str | None: - """Return the first non-empty explicit token, then optionally HF cache.""" - for token in candidates: - cleaned = clean_hf_token(token) - if cleaned: - return cleaned - if include_cached: - return get_cached_hf_token() - return None - - -def resolve_hf_router_token(session_hf_token: str | None = None) -> str | None: - """Resolve the token used for Hugging Face Router LLM calls. - - App-specific precedence: - 1. INFERENCE_TOKEN: shared hosted-Space inference token. - 2. session_hf_token: the active user/session token. - 3. huggingface_hub.get_token(): HF_TOKEN/HUGGING_FACE_HUB_TOKEN or - local ``hf auth login`` cache. - """ - return resolve_hf_token(os.environ.get("INFERENCE_TOKEN"), session_hf_token) - - -def get_hf_bill_to() -> str | None: - """Return X-HF-Bill-To only when a shared inference token is active.""" - if clean_hf_token(os.environ.get("INFERENCE_TOKEN")): - return os.environ.get("HF_BILL_TO", "smolagents") - return None - - -def bearer_token_from_header(auth_header: str | None) -> str | None: - """Extract a cleaned bearer token from an Authorization header.""" - if not auth_header or not auth_header.startswith("Bearer "): - return None - return clean_hf_token(auth_header[7:]) - - -def resolve_hf_request_token( - request: Any, - *, - include_env_fallback: bool = True, -) -> str | None: - """Resolve a user token from a FastAPI request. - - This intentionally does not use the local ``hf auth login`` cache. Backend - request paths should act as the browser user from Authorization/cookie, or - fall back only to an explicit server ``HF_TOKEN`` in dev/server contexts. - """ - token = bearer_token_from_header(request.headers.get("Authorization", "")) - if token: - return token - token = clean_hf_token(request.cookies.get("hf_access_token")) - if token: - return token - if include_env_fallback: - return clean_hf_token(os.environ.get("HF_TOKEN")) - return None diff --git a/agent/core/hub_artifacts.py b/agent/core/hub_artifacts.py deleted file mode 100644 index 8a0b1b5b11ae64ba2cfbb9c2ff7b2dfa0d3714d3..0000000000000000000000000000000000000000 --- a/agent/core/hub_artifacts.py +++ /dev/null @@ -1,758 +0,0 @@ -"""Best-effort Hub metadata for artifacts generated by ML Intern sessions.""" - -import base64 -import logging -import re -import shlex -import tempfile -import textwrap -from datetime import datetime -from pathlib import Path -from typing import Any - -from huggingface_hub import hf_hub_download -from huggingface_hub.repocard import metadata_load, metadata_save -from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError - -logger = logging.getLogger(__name__) - -ML_INTERN_TAG = "ml-intern" -SUPPORTED_REPO_TYPES = {"model", "dataset", "space"} -PROVENANCE_MARKER = "" -_COLLECTION_TITLE_PREFIX = "ml-intern-artifacts" -_COLLECTION_TITLE_MAX_LENGTH = 59 -_UUID_SESSION_ID_RE = re.compile( - r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-" - r"[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$" -) -_KNOWN_ARTIFACTS_ATTR = "_ml_intern_known_hub_artifacts" -_REGISTERED_ARTIFACTS_ATTR = "_ml_intern_registered_hub_artifacts" -_COLLECTION_SLUG_ATTR = "_ml_intern_artifact_collection_slug" -_SESSION_ARTIFACT_SET_FALLBACK: dict[tuple[int, str], set[str]] = {} -_USAGE_HEADING_RE = re.compile( - r"^#{2,6}\s+(usage|how to use|using this (model|dataset)|use this (model|dataset))\b", - re.IGNORECASE | re.MULTILINE, -) -_FRONT_MATTER_RE = re.compile(r"\A---\s*\n.*?\n---\s*\n?", re.DOTALL) - - -def _safe_session_id(session: Any) -> str: - raw = str(getattr(session, "session_id", "") or "unknown-session") - safe = re.sub(r"[^A-Za-z0-9._-]+", "-", raw).strip("-") - return safe or "unknown-session" - - -def session_artifact_date(session: Any) -> str: - """Return the YYYY-MM-DD partition date for a session.""" - raw = getattr(session, "session_start_time", None) - if raw: - try: - return datetime.fromisoformat(str(raw).replace("Z", "+00:00")).strftime( - "%Y-%m-%d" - ) - except ValueError: - logger.debug("Could not parse session_start_time=%r", raw) - return datetime.utcnow().strftime("%Y-%m-%d") - - -def _collection_session_id_fragment(session: Any) -> str: - safe_id = _safe_session_id(session) - if _UUID_SESSION_ID_RE.match(safe_id): - return safe_id[:8] - stem = f"{_COLLECTION_TITLE_PREFIX}-{session_artifact_date(session)}-" - max_id_length = max(1, _COLLECTION_TITLE_MAX_LENGTH - len(stem)) - if len(safe_id) <= max_id_length: - return safe_id - return safe_id[:max_id_length].rstrip("-._") or safe_id[:max_id_length] - - -def artifact_collection_title(session: Any) -> str: - return ( - f"{_COLLECTION_TITLE_PREFIX}-{session_artifact_date(session)}-" - f"{_collection_session_id_fragment(session)}" - ) - - -def _artifact_key(repo_id: str, repo_type: str | None) -> str: - return f"{repo_type or 'model'}:{repo_id}" - - -def _sandbox_space_name_pattern() -> str: - from agent.tools.sandbox_tool import SANDBOX_SPACE_NAME_RE - - return SANDBOX_SPACE_NAME_RE.pattern - - -def is_sandbox_hub_repo(repo_id: str | None, repo_type: str | None) -> bool: - """Return True for ML Intern's ephemeral sandbox Space repos.""" - if (repo_type or "model") != "space" or not repo_id: - return False - repo_name = str(repo_id).rsplit("/", 1)[-1] - return bool(re.fullmatch(_sandbox_space_name_pattern(), repo_name)) - - -def _session_artifact_set(session: Any, attr: str) -> set[str]: - current = getattr(session, attr, None) - if isinstance(current, set): - return current - current = set() - try: - setattr(session, attr, current) - except Exception: - logger.warning( - "Could not attach %s to session; using process-local fallback state", - attr, - ) - return _SESSION_ARTIFACT_SET_FALLBACK.setdefault((id(session), attr), set()) - return current - - -def remember_hub_artifact(session: Any, repo_id: str, repo_type: str | None) -> None: - if session is None or not repo_id: - return - _session_artifact_set(session, _KNOWN_ARTIFACTS_ATTR).add( - _artifact_key(repo_id, repo_type) - ) - - -def is_known_hub_artifact(session: Any, repo_id: str, repo_type: str | None) -> bool: - if session is None or not repo_id: - return False - return _artifact_key(repo_id, repo_type) in _session_artifact_set( - session, _KNOWN_ARTIFACTS_ATTR - ) - - -def _merge_tags(metadata: dict[str, Any], tag: str = ML_INTERN_TAG) -> dict[str, Any]: - merged = dict(metadata) - raw_tags = merged.get("tags") - if raw_tags is None: - tags: list[str] = [] - elif isinstance(raw_tags, str): - tags = [raw_tags] - elif isinstance(raw_tags, list): - tags = [str(item) for item in raw_tags] - else: - tags = [str(raw_tags)] - - if tag not in tags: - tags.append(tag) - merged["tags"] = tags - return merged - - -def _metadata_from_content(content: str) -> dict[str, Any]: - with tempfile.TemporaryDirectory() as tmp_dir: - path = Path(tmp_dir) / "README.md" - path.write_text(content, encoding="utf-8") - return metadata_load(path) or {} - - -def _content_with_metadata(content: str, metadata: dict[str, Any]) -> str: - with tempfile.TemporaryDirectory() as tmp_dir: - path = Path(tmp_dir) / "README.md" - path.write_text(content, encoding="utf-8") - metadata_save(path, metadata) - return path.read_text(encoding="utf-8") - - -def _body_without_metadata(content: str) -> str: - return _FRONT_MATTER_RE.sub("", content, count=1).strip() - - -def _append_section(content: str, section: str) -> str: - base = content.rstrip() - if base: - return f"{base}\n\n{section.strip()}\n" - return f"{section.strip()}\n" - - -def _provenance_section(repo_type: str) -> str: - label = {"model": "model", "dataset": "dataset"}.get(repo_type, "Hub") - return f"""{PROVENANCE_MARKER} -## Generated by ML Intern - -This {label} repository was generated by [ML Intern](https://github.com/huggingface/ml-intern), an agent for machine learning research and development on the Hugging Face Hub. - -- Try ML Intern: https://smolagents-ml-intern.hf.space -- Source code: https://github.com/huggingface/ml-intern -""" - - -def _usage_section(repo_id: str, repo_type: str) -> str: - if repo_type == "dataset": - return f"""## Usage - -```python -from datasets import load_dataset - -dataset = load_dataset("{repo_id}") -``` -""" - - return f"""## Usage - -```python -from transformers import AutoModelForCausalLM, AutoTokenizer - -model_id = "{repo_id}" -tokenizer = AutoTokenizer.from_pretrained(model_id) -model = AutoModelForCausalLM.from_pretrained(model_id) -``` - -For non-causal architectures, replace `AutoModelForCausalLM` with the appropriate `AutoModel` class. -""" - - -def augment_repo_card_content( - content: str | None, - repo_id: str, - repo_type: str = "model", - *, - extra_metadata: dict[str, Any] | None = None, -) -> str: - """Return README content with ML Intern metadata and provenance added.""" - repo_type = repo_type or "model" - content = content or "" - metadata = _metadata_from_content(content) - if extra_metadata: - metadata = {**extra_metadata, **metadata} - metadata = _merge_tags(metadata) - updated = _content_with_metadata(content, metadata) - - if not _body_without_metadata(updated): - updated = _append_section(updated, f"# {repo_id}") - - if repo_type in {"model", "dataset"} and PROVENANCE_MARKER not in updated: - updated = _append_section(updated, _provenance_section(repo_type)) - if not _USAGE_HEADING_RE.search(content): - updated = _append_section(updated, _usage_section(repo_id, repo_type)) - - return updated - - -def _read_remote_readme( - api: Any, - repo_id: str, - repo_type: str, - *, - token: str | bool | None = None, -) -> str: - token_value = token if token is not None else getattr(api, "token", None) - try: - readme_path = hf_hub_download( - repo_id=repo_id, - filename="README.md", - repo_type=repo_type, - token=token_value, - ) - except (EntryNotFoundError, RepositoryNotFoundError): - return "" - return Path(readme_path).read_text(encoding="utf-8") - - -def _update_repo_card( - api: Any, - repo_id: str, - repo_type: str, - *, - token: str | bool | None = None, - extra_metadata: dict[str, Any] | None = None, -) -> None: - current = _read_remote_readme(api, repo_id, repo_type, token=token) - updated = augment_repo_card_content( - current, - repo_id, - repo_type, - extra_metadata=extra_metadata, - ) - if updated == current: - return - api.upload_file( - path_or_fileobj=updated.encode("utf-8"), - path_in_repo="README.md", - repo_id=repo_id, - repo_type=repo_type, - token=token, - commit_message="Update ML Intern artifact metadata", - ) - - -def _ensure_collection_slug( - api: Any, - session: Any, - *, - token: str | bool | None = None, -) -> str | None: - slug = getattr(session, _COLLECTION_SLUG_ATTR, None) - if slug: - return slug - - title = artifact_collection_title(session) - collection = api.create_collection( - title=title, - description=( - f"Artifacts generated by ML Intern session {_safe_session_id(session)} " - f"on {session_artifact_date(session)}." - ), - private=True, - exists_ok=True, - token=token, - ) - slug = getattr(collection, "slug", None) - if slug: - setattr(session, _COLLECTION_SLUG_ATTR, slug) - return slug - - -def _add_to_collection( - api: Any, - session: Any, - repo_id: str, - repo_type: str, - *, - token: str | bool | None = None, -) -> bool: - slug = _ensure_collection_slug(api, session, token=token) - if not slug: - return False - api.add_collection_item( - collection_slug=slug, - item_id=repo_id, - item_type=repo_type, - note=( - f"Generated by ML Intern session {_safe_session_id(session)} " - f"on {session_artifact_date(session)}." - ), - exists_ok=True, - token=token, - ) - return True - - -def register_hub_artifact( - api: Any, - repo_id: str, - repo_type: str = "model", - *, - session: Any = None, - token: str | bool | None = None, - extra_metadata: dict[str, Any] | None = None, - force: bool = False, -) -> bool: - """Tag, card, and collection-register a Hub artifact without raising.""" - if session is None or not repo_id: - return False - repo_type = repo_type or "model" - if repo_type not in SUPPORTED_REPO_TYPES: - return False - if is_sandbox_hub_repo(repo_id, repo_type): - return False - - key = _artifact_key(repo_id, repo_type) - remember_hub_artifact(session, repo_id, repo_type) - registered = _session_artifact_set(session, _REGISTERED_ARTIFACTS_ATTR) - if key in registered and not force: - return True - - token_value = token if token is not None else getattr(api, "token", None) - card_updated = False - collection_updated = False - try: - _update_repo_card( - api, - repo_id, - repo_type, - token=token_value, - extra_metadata=extra_metadata, - ) - card_updated = True - except Exception as e: - logger.debug("ML Intern repo-card update failed for %s: %s", repo_id, e) - - try: - collection_updated = _add_to_collection( - api, - session, - repo_id, - repo_type, - token=token_value, - ) - except Exception as e: - logger.debug("ML Intern collection update failed for %s: %s", repo_id, e) - - if card_updated and collection_updated: - registered.add(key) - return True - return False - - -def build_hub_artifact_sitecustomize(session: Any) -> str: - """Build standalone sitecustomize.py code for HF Jobs Python processes.""" - if session is None or not getattr(session, "session_id", None): - return "" - - session_id = _safe_session_id(session) - session_date = session_artifact_date(session) - collection_title = artifact_collection_title(session) - collection_slug = getattr(session, _COLLECTION_SLUG_ATTR, None) - - return ( - textwrap.dedent( - f""" - # Auto-generated by ML Intern. Best-effort Hub artifact metadata only. - def _install_ml_intern_artifact_hooks(): - import os - import re - import tempfile - from pathlib import Path - - try: - import huggingface_hub as _hub - from huggingface_hub import HfApi, hf_hub_download - from huggingface_hub.repocard import metadata_load, metadata_save - from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError - except Exception: - return - - session_id = {session_id!r} - session_date = {session_date!r} - collection_title = {collection_title!r} - tag = {ML_INTERN_TAG!r} - marker = {PROVENANCE_MARKER!r} - supported = {sorted(SUPPORTED_REPO_TYPES)!r} - sandbox_space_re = re.compile({_sandbox_space_name_pattern()!r}) - registering = False - collection_slug = {collection_slug!r} - registered = set() - usage_re = re.compile( - r"^#{{2,6}}\\s+(usage|how to use|using this (model|dataset)|use this (model|dataset))\\b", - re.IGNORECASE | re.MULTILINE, - ) - front_matter_re = re.compile(r"\\A---\\s*\\n.*?\\n---\\s*\\n?", re.DOTALL) - collection_cache_path = ( - os.environ.get("ML_INTERN_ARTIFACT_COLLECTION_CACHE") - or str( - Path(tempfile.gettempdir()) - / f"ml-intern-artifacts-{{session_id}}.collection" - ) - ) - - def _token(value=None, api=None): - if isinstance(value, str) and value: - return value - api_token = getattr(api, "token", None) - if isinstance(api_token, str) and api_token: - return api_token - return ( - os.environ.get("HF_TOKEN") - or os.environ.get("HUGGINGFACE_HUB_TOKEN") - or None - ) - - def _merge_tags(metadata): - metadata = dict(metadata or {{}}) - raw_tags = metadata.get("tags") - if raw_tags is None: - tags = [] - elif isinstance(raw_tags, str): - tags = [raw_tags] - elif isinstance(raw_tags, list): - tags = [str(item) for item in raw_tags] - else: - tags = [str(raw_tags)] - if tag not in tags: - tags.append(tag) - metadata["tags"] = tags - return metadata - - def _metadata_from_content(content): - with tempfile.TemporaryDirectory() as tmp_dir: - path = Path(tmp_dir) / "README.md" - path.write_text(content or "", encoding="utf-8") - return metadata_load(path) or {{}} - - def _content_with_metadata(content, metadata): - with tempfile.TemporaryDirectory() as tmp_dir: - path = Path(tmp_dir) / "README.md" - path.write_text(content or "", encoding="utf-8") - metadata_save(path, metadata) - return path.read_text(encoding="utf-8") - - def _body_without_metadata(content): - return front_matter_re.sub("", content or "", count=1).strip() - - def _append_section(content, section): - base = (content or "").rstrip() - if base: - return base + "\\n\\n" + section.strip() + "\\n" - return section.strip() + "\\n" - - def _provenance(repo_type): - label = {{"model": "model", "dataset": "dataset"}}.get( - repo_type, "Hub" - ) - return ( - marker - + "\\n## Generated by ML Intern\\n\\n" - + f"This {{label}} repository was generated by [ML Intern](https://github.com/huggingface/ml-intern), an agent for machine learning research and development on the Hugging Face Hub.\\n\\n" - + "- Try ML Intern: https://smolagents-ml-intern.hf.space\\n" - + "- Source code: https://github.com/huggingface/ml-intern\\n" - ) - - def _usage(repo_id, repo_type): - if repo_type == "dataset": - return ( - "## Usage\\n\\n" - "```python\\n" - "from datasets import load_dataset\\n\\n" - f"dataset = load_dataset({{repo_id!r}})\\n" - "```\\n" - ) - return ( - "## Usage\\n\\n" - "```python\\n" - "from transformers import AutoModelForCausalLM, AutoTokenizer\\n\\n" - f"model_id = {{repo_id!r}}\\n" - "tokenizer = AutoTokenizer.from_pretrained(model_id)\\n" - "model = AutoModelForCausalLM.from_pretrained(model_id)\\n" - "```\\n\\n" - "For non-causal architectures, replace `AutoModelForCausalLM` with the appropriate `AutoModel` class.\\n" - ) - - def _augment(content, repo_id, repo_type, extra_metadata=None): - metadata = _metadata_from_content(content or "") - if extra_metadata: - metadata = {{**extra_metadata, **metadata}} - updated = _content_with_metadata(content or "", _merge_tags(metadata)) - if not _body_without_metadata(updated): - updated = _append_section(updated, f"# {{repo_id}}") - if repo_type in {{"model", "dataset"}} and marker not in updated: - updated = _append_section(updated, _provenance(repo_type)) - if not usage_re.search(content or ""): - updated = _append_section(updated, _usage(repo_id, repo_type)) - return updated - - def _readme(api, repo_id, repo_type, token_value): - try: - path = hf_hub_download( - repo_id=repo_id, - filename="README.md", - repo_type=repo_type, - token=token_value, - ) - except (EntryNotFoundError, RepositoryNotFoundError): - return "" - return Path(path).read_text(encoding="utf-8") - - def _ensure_collection(api, token_value): - nonlocal collection_slug - if collection_slug: - return collection_slug - try: - cached_slug = Path(collection_cache_path).read_text( - encoding="utf-8" - ).strip() - if cached_slug: - collection_slug = cached_slug - return collection_slug - except Exception: - pass - collection = api.create_collection( - title=collection_title, - description=( - f"Artifacts generated by ML Intern session {{session_id}} " - f"on {{session_date}}." - ), - private=True, - exists_ok=True, - token=token_value, - ) - collection_slug = getattr(collection, "slug", None) - if collection_slug: - try: - cache_path = Path(collection_cache_path) - cache_path.parent.mkdir(parents=True, exist_ok=True) - cache_path.write_text(collection_slug, encoding="utf-8") - except Exception: - pass - return collection_slug - - def _register( - repo_id, - repo_type="model", - token_value=None, - extra_metadata=None, - force=False, - ): - nonlocal registering - if registering or not repo_id: - return - repo_type = repo_type or "model" - if repo_type not in supported: - return - if _is_sandbox_repo(repo_id, repo_type): - return - key = f"{{repo_type}}:{{repo_id}}" - if key in registered and not force: - return - registering = True - try: - token_value = _token(token_value) - api = HfApi(token=token_value) - card_updated = False - try: - current = _readme(api, repo_id, repo_type, token_value) - updated = _augment( - current, repo_id, repo_type, extra_metadata=extra_metadata - ) - if updated != current: - _original_upload_file( - api, - path_or_fileobj=updated.encode("utf-8"), - path_in_repo="README.md", - repo_id=repo_id, - repo_type=repo_type, - token=token_value, - commit_message="Update ML Intern artifact metadata", - ) - card_updated = True - except Exception: - pass - collection_updated = False - try: - slug = _ensure_collection(api, token_value) - if slug: - api.add_collection_item( - collection_slug=slug, - item_id=repo_id, - item_type=repo_type, - note=( - f"Generated by ML Intern session {{session_id}} " - f"on {{session_date}}." - ), - exists_ok=True, - token=token_value, - ) - collection_updated = True - except Exception: - pass - if card_updated and collection_updated: - registered.add(key) - finally: - registering = False - - _original_create_repo = HfApi.create_repo - _original_upload_file = HfApi.upload_file - _original_upload_folder = getattr(HfApi, "upload_folder", None) - _original_create_commit = getattr(HfApi, "create_commit", None) - - def _repo_id(args, kwargs): - return kwargs.get("repo_id") or (args[0] if args else None) - - def _repo_type(kwargs): - return kwargs.get("repo_type") or "model" - - def _is_sandbox_repo(repo_id, repo_type): - if (repo_type or "model") != "space" or not repo_id: - return False - repo_name = str(repo_id).rsplit("/", 1)[-1] - return bool(sandbox_space_re.fullmatch(repo_name)) - - def _patched_create_repo(self, *args, **kwargs): - result = _original_create_repo(self, *args, **kwargs) - repo_id = _repo_id(args, kwargs) - repo_type = _repo_type(kwargs) - extra = None - if repo_type == "space" and kwargs.get("space_sdk"): - extra = {{"sdk": kwargs.get("space_sdk")}} - _register(repo_id, repo_type, _token(kwargs.get("token"), self), extra) - return result - - def _patched_upload_file(self, *args, **kwargs): - result = _original_upload_file(self, *args, **kwargs) - if not kwargs.get("create_pr"): - force = kwargs.get("path_in_repo") == "README.md" - _register( - kwargs.get("repo_id"), - _repo_type(kwargs), - _token(kwargs.get("token"), self), - force=force, - ) - return result - - def _patched_upload_folder(self, *args, **kwargs): - result = _original_upload_folder(self, *args, **kwargs) - if not kwargs.get("create_pr"): - _register( - kwargs.get("repo_id"), - _repo_type(kwargs), - _token(kwargs.get("token"), self), - force=True, - ) - return result - - def _patched_create_commit(self, *args, **kwargs): - result = _original_create_commit(self, *args, **kwargs) - if not kwargs.get("create_pr"): - _register( - _repo_id(args, kwargs), - _repo_type(kwargs), - _token(kwargs.get("token"), self), - force=True, - ) - return result - - HfApi.create_repo = _patched_create_repo - HfApi.upload_file = _patched_upload_file - if _original_upload_folder is not None: - HfApi.upload_folder = _patched_upload_folder - if _original_create_commit is not None: - HfApi.create_commit = _patched_create_commit - - def _patch_module_func(name, method_name): - original = getattr(_hub, name, None) - if original is None: - return - method = getattr(HfApi, method_name) - - def _patched(*args, **kwargs): - api = HfApi(token=_token(kwargs.get("token"))) - return method(api, *args, **kwargs) - - setattr(_hub, name, _patched) - - _patch_module_func("create_repo", "create_repo") - _patch_module_func("upload_file", "upload_file") - if _original_upload_folder is not None: - _patch_module_func("upload_folder", "upload_folder") - if _original_create_commit is not None: - _patch_module_func("create_commit", "create_commit") - - try: - _install_ml_intern_artifact_hooks() - except Exception: - pass - """ - ).strip() - + "\n" - ) - - -def wrap_shell_command_with_hub_artifact_bootstrap( - command: str, - session: Any, -) -> str: - """Prefix a shell command so child Python processes load Hub hooks.""" - sitecustomize = build_hub_artifact_sitecustomize(session) - if not sitecustomize or not command: - return command - - encoded = base64.b64encode(sitecustomize.encode("utf-8")).decode("ascii") - bootstrap = ( - '_ml_intern_artifacts_dir="$(mktemp -d 2>/dev/null)" ' - f"&& printf %s {shlex.quote(encoded)} | base64 -d " - '> "$_ml_intern_artifacts_dir/sitecustomize.py" ' - '&& export PYTHONPATH="$_ml_intern_artifacts_dir${PYTHONPATH:+:$PYTHONPATH}"' - ) - return f"{bootstrap}; {command}" diff --git a/agent/core/llm_params.py b/agent/core/llm_params.py deleted file mode 100644 index f95695fb88ff2d6664f3a5be357c97f8b83131d8..0000000000000000000000000000000000000000 --- a/agent/core/llm_params.py +++ /dev/null @@ -1,270 +0,0 @@ -"""LiteLLM kwargs resolution for the model ids this agent accepts. - -Kept separate from ``agent_loop`` so tools (research, context compaction, etc.) -can import it without pulling in the whole agent loop / tool router and -creating circular imports. -""" - -import os - -from agent.core.hf_tokens import get_hf_bill_to, resolve_hf_router_token -from agent.core.local_models import ( - LOCAL_MODEL_API_KEY_DEFAULT, - LOCAL_MODEL_API_KEY_ENV, - LOCAL_MODEL_BASE_URL_ENV, - is_reserved_local_model_id, - local_model_name, - local_model_provider, -) - - -def _resolve_hf_router_token(session_hf_token: str | None = None) -> str | None: - """Backward-compatible private wrapper used by tests and older imports.""" - return resolve_hf_router_token(session_hf_token) - - -def _patch_litellm_effort_validation() -> None: - """Neuter LiteLLM 1.83's hardcoded effort-level validation. - - Context: at ``litellm/llms/anthropic/chat/transformation.py:~1443`` the - Anthropic adapter validates ``output_config.effort ∈ {high, medium, - low, max}`` and gates ``max`` behind an ``_is_opus_4_6_model`` check - that only matches the substring ``opus-4-6`` / ``opus_4_6``. Result: - - * ``xhigh`` β€” valid on Anthropic's real API for Claude 4.7 β€” is - rejected pre-flight with "Invalid effort value: xhigh". - * ``max`` on Opus 4.7 is rejected with "effort='max' is only supported - by Claude Opus 4.6", even though Opus 4.7 accepts it in practice. - - We don't want to maintain a parallel model table, so we let the - Anthropic API itself be the validator: widen ``_is_opus_4_6_model`` - to also match ``opus-4-7``+ families, and drop the valid-effort-set - check entirely. If Anthropic rejects an effort level, we see a 400 - and the cascade walks down β€” exactly the behavior we want for any - future model family. - - Removable once litellm ships 1.83.8-stable (which merges PR #25867, - "Litellm day 0 opus 4.7 support") β€” see commit 0868a82 on their main - branch. Until then, this one-time patch is the escape hatch. - """ - try: - from litellm.llms.anthropic.chat import transformation as _t - except Exception: - return - - cfg = getattr(_t, "AnthropicConfig", None) - if cfg is None: - return - - original = getattr(cfg, "_is_opus_4_6_model", None) - if original is None or getattr(original, "_hf_agent_patched", False): - return - - def _widened(model: str) -> bool: - m = model.lower() - # Original 4.6 match plus any future Opus >= 4.6. We only need this - # to return True for families where "max" / "xhigh" are acceptable - # at the API; the cascade handles the case when they're not. - return any( - v in m - for v in ( - "opus-4-6", - "opus_4_6", - "opus-4.6", - "opus_4.6", - "opus-4-7", - "opus_4_7", - "opus-4.7", - "opus_4.7", - ) - ) - - _widened._hf_agent_patched = True # type: ignore[attr-defined] - cfg._is_opus_4_6_model = staticmethod(_widened) - - -_patch_litellm_effort_validation() - - -# Effort levels accepted on the wire. -# Anthropic (4.6+): low | medium | high | xhigh | max (output_config.effort) -# OpenAI direct: minimal | low | medium | high | xhigh (reasoning_effort top-level) -# HF router: low | medium | high (extra_body.reasoning_effort) -# -# We validate *shape* here and let the probe cascade walk down on rejection; -# we deliberately do NOT maintain a per-model capability table. -_ANTHROPIC_EFFORTS = {"low", "medium", "high", "xhigh", "max"} -_OPENAI_EFFORTS = {"minimal", "low", "medium", "high", "xhigh"} -_HF_EFFORTS = {"low", "medium", "high"} - - -class UnsupportedEffortError(ValueError): - """The requested effort isn't valid for this provider's API surface. - - Raised synchronously before any network call so the probe cascade can - skip levels the provider can't accept (e.g. ``max`` on HF router). - """ - - -def _local_api_base(base_url: str) -> str: - base = base_url.strip().rstrip("/") - if base.endswith("/v1"): - return base - return f"{base}/v1" - - -def _resolve_local_model_params( - model_name: str, - reasoning_effort: str | None = None, - strict: bool = False, -) -> dict: - if reasoning_effort and strict: - raise UnsupportedEffortError( - "Local OpenAI-compatible endpoints don't accept reasoning_effort" - ) - - local_name = local_model_name(model_name) - if local_name is None: - raise ValueError(f"Unsupported local model id: {model_name}") - - provider = local_model_provider(model_name) - assert provider is not None - raw_base = ( - os.environ.get(provider["base_url_env"]) - or os.environ.get(LOCAL_MODEL_BASE_URL_ENV) - or provider["base_url_default"] - ) - api_key = ( - os.environ.get(provider["api_key_env"]) - or os.environ.get(LOCAL_MODEL_API_KEY_ENV) - or LOCAL_MODEL_API_KEY_DEFAULT - ) - return { - "model": f"openai/{local_name}", - "api_base": _local_api_base(raw_base), - "api_key": api_key, - } - - -def _resolve_llm_params( - model_name: str, - session_hf_token: str | None = None, - reasoning_effort: str | None = None, - strict: bool = False, -) -> dict: - """ - Build LiteLLM kwargs for a given model id. - - β€’ ``anthropic/`` β€” native thinking config. We bypass LiteLLM's - ``reasoning_effort`` β†’ ``thinking`` mapping (which lags new Claude - releases like 4.7 and sends the wrong API shape). Instead we pass - both ``thinking={"type": "adaptive"}`` and ``output_config= - {"effort": }`` as top-level kwargs β€” LiteLLM's Anthropic - adapter forwards unknown top-level kwargs into the request body - verbatim (confirmed by live probe; ``extra_body`` does NOT work - here because Anthropic's API rejects it as "Extra inputs are not - permitted"). This is the stable API for 4.6 and 4.7. Older - extended-thinking models that only accept ``thinking.type.enabled`` - will reject this; the probe's cascade catches that and falls back - to no thinking. - - β€’ ``openai/`` β€” ``reasoning_effort`` forwarded as a top-level - kwarg (GPT-5 / o-series). LiteLLM uses the user's ``OPENAI_API_KEY``. - - β€’ ``ollama/``, ``vllm/``, ``lm_studio/``, and - ``llamacpp/`` β€” local OpenAI-compatible endpoints. The id prefix - selects a configurable localhost base URL, and the model suffix is sent - to LiteLLM as ``openai/``. These endpoints don't receive - ``reasoning_effort``. - - β€’ Anything else is treated as a HuggingFace router id. We hit the - auto-routing OpenAI-compatible endpoint at - ``https://router.huggingface.co/v1``. The id can be bare or carry an - HF routing suffix (``:fastest`` / ``:cheapest`` / ``:``). - A leading ``huggingface/`` is stripped. ``reasoning_effort`` is - forwarded via ``extra_body`` (LiteLLM's OpenAI adapter refuses it as - a top-level kwarg for non-OpenAI models). "minimal" normalizes to - "low". - - ``strict=True`` raises ``UnsupportedEffortError`` when the requested - effort isn't in the provider's accepted set, instead of silently - dropping it. The probe cascade uses strict mode so it can walk down - (``max`` β†’ ``xhigh`` β†’ ``high`` …) without making an API call. Regular - runtime callers leave ``strict=False``, so a stale cached effort - can't crash a turn β€” it just doesn't get sent. - - Token precedence (first non-empty wins): - 1. INFERENCE_TOKEN env β€” shared key on the hosted Space (inference is - free for users, billed to the Space owner via ``X-HF-Bill-To``). - 2. session.hf_token β€” the user's own token (CLI / OAuth / cache file). - 3. huggingface_hub cache β€” ``HF_TOKEN`` / ``HUGGING_FACE_HUB_TOKEN`` / - local ``hf auth login`` cache. - """ - if model_name.startswith("anthropic/"): - params: dict = {"model": model_name} - if reasoning_effort: - level = reasoning_effort - if level == "minimal": - level = "low" - if level not in _ANTHROPIC_EFFORTS: - if strict: - raise UnsupportedEffortError( - f"Anthropic doesn't accept effort={level!r}" - ) - else: - # Adaptive thinking + output_config.effort is the stable - # Anthropic API for Claude 4.6 / 4.7. Both kwargs are - # passed top-level: LiteLLM forwards unknown params into - # the request body for Anthropic, so ``output_config`` - # reaches the API. ``extra_body`` does NOT work here β€” - # Anthropic rejects it as "Extra inputs are not - # permitted". - params["thinking"] = {"type": "adaptive"} - params["output_config"] = {"effort": level} - return params - - if model_name.startswith("bedrock/"): - # LiteLLM routes ``bedrock/...`` through the Converse adapter, which - # picks up AWS credentials from the standard env vars - # (``AWS_ACCESS_KEY_ID`` / ``AWS_SECRET_ACCESS_KEY`` / ``AWS_REGION``). - # The Anthropic thinking/effort shape is not forwarded through Converse - # the same way, so we leave it off for now. - return {"model": model_name} - - if model_name.startswith("openai/"): - params = {"model": model_name} - if reasoning_effort: - if reasoning_effort not in _OPENAI_EFFORTS: - if strict: - raise UnsupportedEffortError( - f"OpenAI doesn't accept effort={reasoning_effort!r}" - ) - else: - params["reasoning_effort"] = reasoning_effort - return params - - if is_reserved_local_model_id(model_name): - raise ValueError(f"Unsupported local model id: {model_name}") - - if local_model_provider(model_name) is not None: - return _resolve_local_model_params(model_name, reasoning_effort, strict) - - hf_model = model_name.removeprefix("huggingface/") - api_key = _resolve_hf_router_token(session_hf_token) - params = { - "model": f"openai/{hf_model}", - "api_base": "https://router.huggingface.co/v1", - "api_key": api_key, - } - if bill_to := get_hf_bill_to(): - params["extra_headers"] = {"X-HF-Bill-To": bill_to} - if reasoning_effort: - hf_level = "low" if reasoning_effort == "minimal" else reasoning_effort - if hf_level not in _HF_EFFORTS: - if strict: - raise UnsupportedEffortError( - f"HF router doesn't accept effort={hf_level!r}" - ) - else: - params["extra_body"] = {"reasoning_effort": hf_level} - return params diff --git a/agent/core/local_models.py b/agent/core/local_models.py deleted file mode 100644 index 9f8a9491d635dd3892388ebfdd0f8384ac78144f..0000000000000000000000000000000000000000 --- a/agent/core/local_models.py +++ /dev/null @@ -1,59 +0,0 @@ -"""Helpers for CLI local OpenAI-compatible model ids.""" - -LOCAL_MODEL_PROVIDERS: dict[str, dict[str, str]] = { - "ollama/": { - "base_url_env": "OLLAMA_BASE_URL", - "base_url_default": "http://localhost:11434", - "api_key_env": "OLLAMA_API_KEY", - }, - "vllm/": { - "base_url_env": "VLLM_BASE_URL", - "base_url_default": "http://localhost:8000", - "api_key_env": "VLLM_API_KEY", - }, - "lm_studio/": { - "base_url_env": "LMSTUDIO_BASE_URL", - "base_url_default": "http://127.0.0.1:1234", - "api_key_env": "LMSTUDIO_API_KEY", - }, - "llamacpp/": { - "base_url_env": "LLAMACPP_BASE_URL", - "base_url_default": "http://localhost:8080", - "api_key_env": "LLAMACPP_API_KEY", - }, -} - -LOCAL_MODEL_PREFIXES = tuple(LOCAL_MODEL_PROVIDERS) -RESERVED_LOCAL_MODEL_PREFIXES = ("openai-compat/",) -LOCAL_MODEL_BASE_URL_ENV = "LOCAL_LLM_BASE_URL" -LOCAL_MODEL_API_KEY_ENV = "LOCAL_LLM_API_KEY" -LOCAL_MODEL_API_KEY_DEFAULT = "sk-local-no-key-required" - - -def local_model_provider(model_id: str) -> dict[str, str] | None: - """Return provider config for a local model id, if it uses a local prefix.""" - for prefix, config in LOCAL_MODEL_PROVIDERS.items(): - if model_id.startswith(prefix): - return config - return None - - -def local_model_name(model_id: str) -> str | None: - """Return the backend model name with the local provider prefix removed.""" - for prefix in LOCAL_MODEL_PREFIXES: - if model_id.startswith(prefix): - name = model_id[len(prefix) :] - return name or None - return None - - -def is_local_model_id(model_id: str) -> bool: - """Return True for non-empty, whitespace-free local model ids.""" - if not model_id or any(char.isspace() for char in model_id): - return False - return local_model_name(model_id) is not None - - -def is_reserved_local_model_id(model_id: str) -> bool: - """Return True for local-style prefixes intentionally not supported.""" - return model_id.startswith(RESERVED_LOCAL_MODEL_PREFIXES) diff --git a/agent/core/model_switcher.py b/agent/core/model_switcher.py deleted file mode 100644 index 34eaccdd1f127253bec68b4ccdd1159c7a3c4a0a..0000000000000000000000000000000000000000 --- a/agent/core/model_switcher.py +++ /dev/null @@ -1,292 +0,0 @@ -"""Model-switching logic for the interactive CLI's ``/model`` command. - -Split out of ``agent.main`` so the REPL dispatcher stays focused on input -parsing. Exposes: - -* ``SUGGESTED_MODELS`` β€” the short list shown by ``/model`` with no arg. -* ``is_valid_model_id`` β€” loose format check on user input. -* ``probe_and_switch_model`` β€” async: checks routing, fires a 1-token - probe to resolve the effort cascade, then commits the switch (or - rejects it on hard error). - -The probe's cascade lives in ``agent.core.effort_probe``; this module -glues it to CLI output + session state. -""" - -from __future__ import annotations - -import asyncio - -from litellm import acompletion - -from agent.core.effort_probe import ProbeInconclusive, probe_effort -from agent.core.llm_params import _resolve_llm_params -from agent.core.local_models import ( - LOCAL_MODEL_PREFIXES, - is_local_model_id, - is_reserved_local_model_id, -) - - -# Suggested models shown by `/model` (not a gate). Users can paste any HF -# model id (e.g. "MiniMaxAI/MiniMax-M2.7") or an `anthropic/` / `openai/` -# prefix for direct API access. For HF ids, append ":fastest" / -# ":cheapest" / ":preferred" / ":" to override the default -# routing policy (auto = fastest with failover). -SUGGESTED_MODELS = [ - {"id": "openai/gpt-5.5", "label": "GPT-5.5"}, - {"id": "openai/gpt-5.4", "label": "GPT-5.4"}, - {"id": "anthropic/claude-opus-4-7", "label": "Claude Opus 4.7"}, - {"id": "anthropic/claude-opus-4-6", "label": "Claude Opus 4.6"}, - { - "id": "bedrock/us.anthropic.claude-opus-4-6-v1", - "label": "Claude Opus 4.6 via Bedrock", - }, - {"id": "MiniMaxAI/MiniMax-M2.7", "label": "MiniMax M2.7"}, - {"id": "moonshotai/Kimi-K2.6", "label": "Kimi K2.6"}, - {"id": "zai-org/GLM-5.1", "label": "GLM 5.1"}, - {"id": "deepseek-ai/DeepSeek-V4-Pro:deepinfra", "label": "DeepSeek V4 Pro"}, -] - - -_ROUTING_POLICIES = {"fastest", "cheapest", "preferred"} -_DIRECT_PREFIXES = ("anthropic/", "openai/", *LOCAL_MODEL_PREFIXES) -_LOCAL_PROBE_TIMEOUT = 15.0 - - -def is_valid_model_id(model_id: str) -> bool: - """Loose format check β€” lets users pick any model id. - - Accepts: - β€’ anthropic/ - β€’ openai/ - β€’ ollama/, vllm/, lm_studio/, llamacpp/ - β€’ /[:] (HF router; tag = provider or policy) - β€’ huggingface//[:] (same, accepts legacy prefix) - - Actual availability is verified against the HF router catalog on - switch, and by the provider on the probe's ping call. - """ - if not model_id: - return False - if is_local_model_id(model_id): - return True - if is_reserved_local_model_id(model_id): - return False - if any(model_id.startswith(prefix) for prefix in LOCAL_MODEL_PREFIXES): - return False - if "/" not in model_id: - return False - head = model_id.split(":", 1)[0] - parts = head.split("/") - return len(parts) >= 2 and all(parts) - - -def _print_hf_routing_info(model_id: str, console) -> bool: - """Show HF router catalog info (providers, price, context, tool support) - for an HF-router model id. Returns ``True`` to signal the caller can - proceed with the switch, ``False`` to indicate a hard problem the user - should notice before we fire the effort probe. - - Anthropic / OpenAI ids return ``True`` without printing anything β€” - the probe below covers "does this model exist". - """ - if model_id.startswith(_DIRECT_PREFIXES): - return True - - from agent.core import hf_router_catalog as cat - - bare, _, tag = model_id.partition(":") - info = cat.lookup(bare) - if info is None: - console.print( - f"[bold red]Warning:[/bold red] '{bare}' isn't in the HF router " - "catalog. Checking anyway β€” first call may fail." - ) - suggestions = cat.fuzzy_suggest(bare) - if suggestions: - console.print(f"[dim]Did you mean: {', '.join(suggestions)}[/dim]") - return True - - live = info.live_providers - if not live: - console.print( - f"[bold red]Warning:[/bold red] '{bare}' has no live providers " - "right now. First call will likely fail." - ) - return True - - if tag and tag not in _ROUTING_POLICIES: - matched = [p for p in live if p.provider == tag] - if not matched: - names = ", ".join(p.provider for p in live) - console.print( - f"[bold red]Warning:[/bold red] provider '{tag}' doesn't serve " - f"'{bare}'. Live providers: {names}. Checking anyway." - ) - - if not info.any_supports_tools: - console.print( - f"[bold red]Warning:[/bold red] no provider for '{bare}' advertises " - "tool-call support. This agent relies on tool calls β€” expect errors." - ) - - if tag in _ROUTING_POLICIES: - policy = tag - elif tag: - policy = f"pinned to {tag}" - else: - policy = "auto (fastest)" - console.print(f" [dim]routing: {policy}[/dim]") - for p in live: - price = ( - f"${p.input_price:g}/${p.output_price:g} per M tok" - if p.input_price is not None and p.output_price is not None - else "price n/a" - ) - ctx = f"{p.context_length:,} ctx" if p.context_length else "ctx n/a" - tools = "tools" if p.supports_tools else "no tools" - console.print(f" [dim]{p.provider}: {price}, {ctx}, {tools}[/dim]") - return True - - -def print_model_listing(config, console) -> None: - """Render the default ``/model`` (no-arg) view: current + suggested.""" - current = config.model_name if config else "" - console.print("[bold]Current model:[/bold]") - console.print(f" {current}") - console.print("\n[bold]Suggested:[/bold]") - for m in SUGGESTED_MODELS: - marker = " [dim]<-- current[/dim]" if m["id"] == current else "" - console.print(f" {m['id']} [dim]({m['label']})[/dim]{marker}") - console.print( - "\n[dim]Paste any HF model id (e.g. 'MiniMaxAI/MiniMax-M2.7').\n" - "Add ':fastest', ':cheapest', ':preferred', or ':' to override routing.\n" - "Use 'anthropic/' or 'openai/' for direct API access.\n" - "Use 'ollama/', 'vllm/', 'lm_studio/', or " - "'llamacpp/' for local OpenAI-compatible endpoints.[/dim]" - ) - - -def print_invalid_id(arg: str, console) -> None: - console.print(f"[bold red]Invalid model id format:[/bold red] {arg}") - console.print( - "[dim]Expected:\n" - " β€’ /[:tag] (HF router β€” paste from huggingface.co)\n" - " β€’ anthropic/\n" - " β€’ openai/\n" - " β€’ ollama/ | vllm/ | lm_studio/ | llamacpp/[/dim]" - ) - - -async def _probe_local_model(model_id: str) -> None: - params = _resolve_llm_params(model_id) - await asyncio.wait_for( - acompletion( - messages=[{"role": "user", "content": "ping"}], - max_tokens=1, - stream=False, - **params, - ), - timeout=_LOCAL_PROBE_TIMEOUT, - ) - - -async def probe_and_switch_model( - model_id: str, - config, - session, - console, - hf_token: str | None, -) -> None: - """Validate model+effort with a 1-token ping, cache the effective effort, - then commit the switch. - - Three visible outcomes: - - * βœ“ ``effort: `` β€” model accepted the preferred effort (or a - fallback from the cascade; the note explains if so) - * βœ“ ``effort: off`` β€” model doesn't support thinking; we'll strip it - * βœ— hard error (auth, model-not-found, quota) β€” we reject the switch - and keep the current model so the user isn't stranded - - For non-local models, transient errors (5xx, timeout) complete the switch - with a yellow warning; the next real call re-surfaces the error if it's - persistent. Local models reject every probe error, including timeouts, and - keep the current model. - """ - if is_local_model_id(model_id): - console.print(f"[dim]checking local model {model_id}...[/dim]") - try: - await _probe_local_model(model_id) - except Exception as e: - console.print(f"[bold red]Switch failed:[/bold red] {e}") - console.print(f"[dim]Keeping current model: {config.model_name}[/dim]") - return - - _commit_switch(model_id, config, session, effective=None, cache=True) - console.print( - f"[green]Model switched to {model_id}[/green] [dim](effort: off)[/dim]" - ) - return - - preference = config.reasoning_effort - if not _print_hf_routing_info(model_id, console): - return - - if not preference: - # Nothing to validate with a ping that we couldn't validate on the - # first real call just as cheaply. Skip the probe entirely. - _commit_switch(model_id, config, session, effective=None, cache=False) - console.print( - f"[green]Model switched to {model_id}[/green] [dim](effort: off)[/dim]" - ) - return - - console.print(f"[dim]checking {model_id} (effort: {preference})...[/dim]") - try: - outcome = await probe_effort(model_id, preference, hf_token, session=session) - except ProbeInconclusive as e: - _commit_switch(model_id, config, session, effective=None, cache=False) - console.print( - f"[yellow]Model switched to {model_id}[/yellow] " - f"[dim](couldn't validate: {e}; will verify on first message)[/dim]" - ) - return - except Exception as e: - # Hard persistent error β€” auth, unknown model, quota. Don't switch. - console.print(f"[bold red]Switch failed:[/bold red] {e}") - console.print(f"[dim]Keeping current model: {config.model_name}[/dim]") - return - - _commit_switch( - model_id, - config, - session, - effective=outcome.effective_effort, - cache=True, - ) - effort_label = outcome.effective_effort or "off" - suffix = f" β€” {outcome.note}" if outcome.note else "" - console.print( - f"[green]Model switched to {model_id}[/green] " - f"[dim](effort: {effort_label}{suffix}, {outcome.elapsed_ms}ms)[/dim]" - ) - - -def _commit_switch(model_id, config, session, effective, cache: bool) -> None: - """Apply the switch to the session (or bare config if no session yet). - - ``effective`` is the probe's resolved effort; ``cache=True`` stores it - in the session's per-model cache so real calls use the resolved level - instead of re-probing. ``cache=False`` (inconclusive probe / effort - off) leaves the cache untouched β€” next call falls back to preference. - """ - if session is not None: - session.update_model(model_id) - if cache: - session.model_effective_effort[model_id] = effective - else: - session.model_effective_effort.pop(model_id, None) - else: - config.model_name = model_id diff --git a/agent/core/prompt_caching.py b/agent/core/prompt_caching.py deleted file mode 100644 index b30edd9fc4845738c08e972fdab712bf2ae3988d..0000000000000000000000000000000000000000 --- a/agent/core/prompt_caching.py +++ /dev/null @@ -1,65 +0,0 @@ -"""Anthropic prompt caching breakpoints for outgoing LLM requests. - -Caching is GA on Anthropic's API and natively supported by litellm >=1.83 -via ``cache_control`` blocks. We apply two breakpoints (out of 4 allowed): - - 1. The tool block β€” caches all tool definitions as a single prefix. - 2. The system message β€” caches the rendered system prompt. - -Together these cover the ~4-5K static tokens that were being re-billed on -every turn. Subsequent turns within the 5-minute TTL hit cache_read pricing -(~10% of input cost) instead of full input. - -Non-Anthropic models (HF router, OpenAI) are passed through unchanged. -""" - -from typing import Any - - -def with_prompt_caching( - messages: list[Any], - tools: list[dict] | None, - model_name: str | None, -) -> tuple[list[Any], list[dict] | None]: - """Return (messages, tools) with cache_control breakpoints for Anthropic. - - No-op for non-Anthropic models. Original objects are not mutated; a fresh - list with replaced first message and last tool is returned, so callers - that share the underlying ``ContextManager.items`` list don't see their - persisted history rewritten. - """ - if not model_name or "anthropic" not in model_name: - return messages, tools - - if tools: - new_tools = list(tools) - last = dict(new_tools[-1]) - last["cache_control"] = {"type": "ephemeral"} - new_tools[-1] = last - tools = new_tools - - if messages: - first = messages[0] - role = ( - first.get("role") - if isinstance(first, dict) - else getattr(first, "role", None) - ) - if role == "system": - content = ( - first.get("content") - if isinstance(first, dict) - else getattr(first, "content", None) - ) - if isinstance(content, str) and content: - cached_block = [ - { - "type": "text", - "text": content, - "cache_control": {"type": "ephemeral"}, - } - ] - new_first = {"role": "system", "content": cached_block} - messages = [new_first] + list(messages[1:]) - - return messages, tools diff --git a/agent/core/redact.py b/agent/core/redact.py deleted file mode 100644 index 8978942c8a027b56e51acdd0f6485d4e9e0fbbf2..0000000000000000000000000000000000000000 --- a/agent/core/redact.py +++ /dev/null @@ -1,68 +0,0 @@ -"""Secret scrubbing for session trajectories before upload. - -Users frequently paste HF / API / GitHub tokens into the chat, or scripts echo -them via env dumps. This module applies regex-based redaction to any string -value found recursively in a trajectory payload. The goal is best-effort β€” -strict formats are matched; we won't catch free-form leaks like "my password -is hunter2". -""" - -from __future__ import annotations - -import re -from typing import Any - -# Each entry: (compiled regex, replacement placeholder). -# Patterns are conservative: they only match tokens with the canonical prefix -# and a minimum body length so we don't paint over normal text. -_PATTERNS: list[tuple[re.Pattern, str]] = [ - # Hugging Face tokens: hf_[A-Za-z0-9]{30,} - (re.compile(r"hf_[A-Za-z0-9]{30,}"), "[REDACTED_HF_TOKEN]"), - # Anthropic: sk-ant-[A-Za-z0-9_\-]{20,} - (re.compile(r"sk-ant-[A-Za-z0-9_\-]{20,}"), "[REDACTED_ANTHROPIC_KEY]"), - # OpenAI: sk-[A-Za-z0-9]{40,} (legacy + proj keys) - (re.compile(r"sk-(?!ant-)[A-Za-z0-9_\-]{40,}"), "[REDACTED_OPENAI_KEY]"), - # GitHub classic PATs: ghp_, gho_, ghu_, ghs_, ghr_ followed by 36+ chars - (re.compile(r"gh[pousr]_[A-Za-z0-9]{36,}"), "[REDACTED_GITHUB_TOKEN]"), - # GitHub fine-grained PATs: github_pat_ - (re.compile(r"github_pat_[A-Za-z0-9_]{36,}"), "[REDACTED_GITHUB_TOKEN]"), - # AWS access key IDs: AKIA / ASIA + 16 uppercase alnum - (re.compile(r"\b(?:AKIA|ASIA)[A-Z0-9]{16}\b"), "[REDACTED_AWS_KEY_ID]"), - # Generic 'Bearer ' header values - (re.compile(r"(?i)bearer\s+[A-Za-z0-9_\-\.=]{20,}"), "Bearer [REDACTED]"), -] - -# Env-var-like exports: we scrub the value but keep the name so callers can -# still see which secret was referenced. Covers `KEY=value` and `KEY: value` -# when the key looks secret-y. -_SECRETY_NAMES = re.compile( - r"(?i)\b(HF_TOKEN|HUGGINGFACEHUB_API_TOKEN|ANTHROPIC_API_KEY|OPENAI_API_KEY|" - r"GITHUB_TOKEN|AWS_SECRET_ACCESS_KEY|AWS_ACCESS_KEY_ID|PASSWORD|SECRET|API_KEY)" - r"\s*[:=]\s*([^\s\"']+)" -) - - -def scrub_string(s: str) -> str: - """Apply all redaction patterns to a single string. Safe on non-strings.""" - if not isinstance(s, str) or not s: - return s - out = s - for pat, repl in _PATTERNS: - out = pat.sub(repl, out) - out = _SECRETY_NAMES.sub(lambda m: f"{m.group(1)}=[REDACTED]", out) - return out - - -def scrub(obj: Any) -> Any: - """Recursively scrub every string value in a nested dict/list structure. - - Returns a new object β€” inputs are not mutated.""" - if isinstance(obj, str): - return scrub_string(obj) - if isinstance(obj, dict): - return {k: scrub(v) for k, v in obj.items()} - if isinstance(obj, list): - return [scrub(v) for v in obj] - if isinstance(obj, tuple): - return tuple(scrub(v) for v in obj) - return obj diff --git a/agent/core/session.py b/agent/core/session.py index e98778a3ad1b8f77a98f4a0d7373eb690e689d75..14396d559c2ee5ea1fea60b92a0d64f8bb224d1e 100644 --- a/agent/core/session.py +++ b/agent/core/session.py @@ -1,7 +1,6 @@ import asyncio import json import logging -import os import subprocess import sys import uuid @@ -13,47 +12,45 @@ from typing import Any, Optional from agent.config import Config from agent.context_manager.manager import ContextManager -from agent.messaging.gateway import NotificationGateway -from agent.messaging.models import NotificationRequest logger = logging.getLogger(__name__) +# Local max-token lookup β€” avoids litellm.get_max_tokens() which can hang +# on network calls for certain providers (known litellm issue). +_MAX_TOKENS_MAP: dict[str, int] = { + # Anthropic + "anthropic/claude-opus-4-5-20251101": 200_000, + "anthropic/claude-sonnet-4-5-20250929": 200_000, + "anthropic/claude-sonnet-4-20250514": 200_000, + "anthropic/claude-haiku-3-5-20241022": 200_000, + "anthropic/claude-3-5-sonnet-20241022": 200_000, + "anthropic/claude-3-opus-20240229": 200_000, + "huggingface/novita/MiniMaxAI/MiniMax-M2.1": 196_608, + "huggingface/novita/moonshotai/Kimi-K2.5": 262_144, + "huggingface/novita/zai-org/GLM-5": 200_000, +} _DEFAULT_MAX_TOKENS = 200_000 -_TURN_COMPLETE_NOTIFICATION_CHARS = 39000 - -DEFAULT_SESSION_LOG_DIR = Path("session_logs") def _get_max_tokens_safe(model_name: str) -> int: - """Return the max input-context tokens for a model. - - Primary source: ``litellm.get_model_info(model)['max_input_tokens']`` β€” - LiteLLM maintains an upstream catalog that knows Claude Opus 4.6 is - 1M, GPT-5 is 272k, Sonnet 4.5 is 200k, and so on. Strips any HF routing - suffix / huggingface/ prefix so tagged ids ('moonshotai/Kimi-K2.6:cheapest') - look up the bare model. Falls back to a conservative 200k default for - models not in the catalog (typically HF-router-only models). - """ - from litellm import get_model_info - - candidates = [model_name] - stripped = model_name.removeprefix("huggingface/").split(":", 1)[0] - if stripped != model_name: - candidates.append(stripped) - for candidate in candidates: - try: - info = get_model_info(candidate) - max_input = info.get("max_input_tokens") if info else None - if isinstance(max_input, int) and max_input > 0: - return max_input - except Exception: - continue - logger.info( - "No litellm.get_model_info entry for %s, falling back to %d", - model_name, - _DEFAULT_MAX_TOKENS, - ) - return _DEFAULT_MAX_TOKENS + """Return the max context window for a model without network calls.""" + tokens = _MAX_TOKENS_MAP.get(model_name) + if tokens: + return tokens + # Fallback: try litellm but with a short timeout via threading + try: + from litellm import get_max_tokens + + result = get_max_tokens(model_name) + if result and isinstance(result, int): + return result + logger.warning( + f"get_max_tokens returned {result} for {model_name}, using default" + ) + return _DEFAULT_MAX_TOKENS + except Exception as e: + logger.warning(f"get_max_tokens failed for {model_name}, using default: {e}") + return _DEFAULT_MAX_TOKENS class OpType(Enum): @@ -62,7 +59,6 @@ class OpType(Enum): INTERRUPT = "interrupt" UNDO = "undo" COMPACT = "compact" - RESUME = "resume" SHUTDOWN = "shutdown" @@ -70,7 +66,6 @@ class OpType(Enum): class Event: event_type: str data: Optional[dict[str, Any]] = None - seq: Optional[int] = None class Session: @@ -82,80 +77,39 @@ class Session: def __init__( self, event_queue: asyncio.Queue, - config: Config, + config: Config | None = None, tool_router=None, context_manager: ContextManager | None = None, - hf_token: str | None = None, - local_mode: bool = False, - stream: bool = True, - notification_gateway: NotificationGateway | None = None, - notification_destinations: list[str] | None = None, - defer_turn_complete_notification: bool = False, - session_id: str | None = None, - user_id: str | None = None, - hf_username: str | None = None, - persistence_store: Any | None = None, ): - self.hf_token: Optional[str] = hf_token - self.user_id: Optional[str] = user_id - self.hf_username: Optional[str] = hf_username - self.persistence_store = persistence_store self.tool_router = tool_router - self.stream = stream - if config is None: - raise ValueError("Session requires a Config") tool_specs = tool_router.get_tool_specs_for_llm() if tool_router else [] self.context_manager = context_manager or ContextManager( - model_max_tokens=_get_max_tokens_safe(config.model_name), + max_context=_get_max_tokens_safe(config.model_name), compact_size=0.1, untouched_messages=5, tool_specs=tool_specs, - hf_token=hf_token, - local_mode=local_mode, ) self.event_queue = event_queue - self.session_id = session_id or str(uuid.uuid4()) - self.config = config + self.session_id = str(uuid.uuid4()) + self.config = config or Config( + model_name="anthropic/claude-sonnet-4-5-20250929", + ) self.is_running = True - self._cancelled = asyncio.Event() + self.current_task: asyncio.Task | None = None self.pending_approval: Optional[dict[str, Any]] = None - self.sandbox = None - self.sandbox_hardware: Optional[str] = None - self.sandbox_preload_task: Optional[asyncio.Task] = None - self.sandbox_preload_error: Optional[str] = None - self.sandbox_preload_cancel_event: Any | None = None - self._running_job_ids: set[str] = set() # HF job IDs currently executing - self.notification_gateway = notification_gateway - self.notification_destinations = list(notification_destinations or []) - self.defer_turn_complete_notification = defer_turn_complete_notification - self.auto_approval_enabled: bool = False - self.auto_approval_cost_cap_usd: float | None = None - self.auto_approval_estimated_spend_usd: float = 0.0 + # User's HF OAuth token β€” set by session_manager after construction + self.hf_token: Optional[str] = None # Session trajectory logging self.logged_events: list[dict] = [] self.session_start_time = datetime.now().isoformat() self.turn_count: int = 0 self.last_auto_save_turn: int = 0 - # Stable local save path so heartbeat saves overwrite one file instead - # of spamming session_logs/. ``_last_heartbeat_ts`` is owned by - # ``agent.core.telemetry.HeartbeatSaver`` and lazily initialised there. - self._local_save_path: Optional[str] = None - self._last_heartbeat_ts: Optional[float] = None - - # Per-model probed reasoning-effort cache. Populated by the probe - # on /model switch, read by ``effective_effort_for`` below. Keys are - # raw model ids (including any ``:tag``). Values: - # str β†’ the effort level to send (may be a downgrade from the - # preference, e.g. "high" when user asked for "max") - # None β†’ model rejected all efforts in the cascade; send no - # thinking params at all - # Key absent β†’ not probed yet; fall back to the raw preference. - self.model_effective_effort: dict[str, str | None] = {} - self.context_manager.on_message_added = self._schedule_trace_message async def send_event(self, event: Event) -> None: """Send event back to client and log to trajectory""" + await self.event_queue.put(event) + # Log event to trajectory self.logged_events.append( { @@ -164,211 +118,11 @@ class Session: "data": event.data, } ) - if self.persistence_store is not None: - try: - event.seq = await self.persistence_store.append_event( - self.session_id, event.event_type, event.data - ) - except Exception as e: - logger.debug("Event persistence failed for %s: %s", self.session_id, e) - - await self.event_queue.put(event) - await self._enqueue_auto_notification_requests(event) - - # Mid-turn heartbeat flush (owned by telemetry module). - from agent.core.telemetry import HeartbeatSaver - - HeartbeatSaver.maybe_fire(self) - - def _schedule_trace_message(self, message: Any) -> None: - """Best-effort append-only trace save for SFT/KPI export.""" - if self.persistence_store is None: - return - try: - payload = message.model_dump(mode="json") - except Exception: - return - try: - loop = asyncio.get_running_loop() - except RuntimeError: - return - source = str(payload.get("role") or "message") - loop.create_task( - self.persistence_store.append_trace_message( - self.session_id, payload, source=source - ) - ) - def set_notification_destinations(self, destinations: list[str]) -> None: - """Replace the session's opted-in auto-notification destinations.""" - deduped: list[str] = [] - seen: set[str] = set() - for destination in destinations: - if destination not in seen: - deduped.append(destination) - seen.add(destination) - self.notification_destinations = deduped - - async def send_deferred_turn_complete_notification(self, event: Event) -> None: - if event.event_type != "turn_complete": - return - await self._enqueue_auto_notification_requests( - event, - include_deferred_turn_complete=True, - ) - - async def _enqueue_auto_notification_requests( - self, - event: Event, - include_deferred_turn_complete: bool = False, - ) -> None: - if self.notification_gateway is None: - return - if not self.notification_destinations: - return - auto_events = set(self.config.messaging.auto_event_types) - if event.event_type not in auto_events: - return - if ( - self.defer_turn_complete_notification - and event.event_type == "turn_complete" - and not include_deferred_turn_complete - ): - return - - requests = self._build_auto_notification_requests(event) - for request in requests: - await self.notification_gateway.enqueue(request) - - def _build_auto_notification_requests( - self, event: Event - ) -> list[NotificationRequest]: - metadata = { - "session_id": self.session_id, - "model": self.config.model_name, - "event_type": event.event_type, - } - - title: str | None = None - message: str | None = None - severity = "info" - data = event.data or {} - if event.event_type == "approval_required": - tools = data.get("tools", []) - tool_names = [] - for tool in tools if isinstance(tools, list) else []: - if isinstance(tool, dict): - tool_name = str(tool.get("tool") or "").strip() - if tool_name and tool_name not in tool_names: - tool_names.append(tool_name) - count = len(tools) if isinstance(tools, list) else 0 - title = "Agent approval required" - message = ( - f"Session {self.session_id} is waiting for approval " - f"for {count} tool call(s)." - ) - if tool_names: - message += " Tools: " + ", ".join(tool_names) - severity = "warning" - elif event.event_type == "error": - title = "Agent error" - error = str(data.get("error") or "Unknown error") - message = f"Session {self.session_id} hit an error.\n{error[:500]}" - severity = "error" - elif event.event_type == "turn_complete": - title = "Agent task complete" - summary = str(data.get("final_response") or "").strip() - if summary: - summary = summary[:_TURN_COMPLETE_NOTIFICATION_CHARS] - message = ( - f"Session {self.session_id} completed successfully.\n{summary}" - ) - else: - message = f"Session {self.session_id} completed successfully." - severity = "success" - - if message is None: - return [] - - requests: list[NotificationRequest] = [] - for destination in self.notification_destinations: - if not self.config.messaging.can_auto_send(destination): - continue - requests.append( - NotificationRequest( - destination=destination, - title=title, - message=message, - severity=severity, - metadata=metadata, - event_type=event.event_type, - ) - ) - return requests - - def cancel(self) -> None: - """Signal cancellation to the running agent loop.""" - self._cancelled.set() - - def reset_cancel(self) -> None: - """Clear the cancellation flag before a new run.""" - self._cancelled.clear() - - @property - def is_cancelled(self) -> bool: - return self._cancelled.is_set() - - def update_model(self, model_name: str) -> None: - """Switch the active model and update the context window limit.""" - self.config.model_name = model_name - self.context_manager.model_max_tokens = _get_max_tokens_safe(model_name) - - def set_auto_approval_policy( - self, *, enabled: bool, cost_cap_usd: float | None - ) -> None: - self.auto_approval_enabled = bool(enabled) - self.auto_approval_cost_cap_usd = cost_cap_usd - - def add_auto_approval_estimated_spend(self, amount_usd: float | None) -> None: - if amount_usd is None or amount_usd <= 0: - return - self.auto_approval_estimated_spend_usd = round( - self.auto_approval_estimated_spend_usd + float(amount_usd), 4 - ) - - @property - def auto_approval_remaining_usd(self) -> float | None: - if self.auto_approval_cost_cap_usd is None: - return None - return round( - max( - 0.0, - self.auto_approval_cost_cap_usd - - self.auto_approval_estimated_spend_usd, - ), - 4, - ) - - def auto_approval_policy_summary(self) -> dict[str, Any]: - return { - "enabled": self.auto_approval_enabled, - "cost_cap_usd": self.auto_approval_cost_cap_usd, - "estimated_spend_usd": round(self.auto_approval_estimated_spend_usd, 4), - "remaining_usd": self.auto_approval_remaining_usd, - } - - def effective_effort_for(self, model_name: str) -> str | None: - """Resolve the effort level to actually send for ``model_name``. - - Returns the probed result when we have one (may be ``None`` meaning - "model doesn't do thinking, strip it"), else the raw preference. - Unknown-model case falls back to the preference so a stale cache - from a prior ``/model`` can't poison research sub-calls that use a - different model id. - """ - if model_name in self.model_effective_effort: - return self.model_effective_effort[model_name] - return self.config.reasoning_effort + def interrupt(self) -> None: + """Interrupt current running task""" + if self.current_task and not self.current_task.done(): + self.current_task.cancel() def increment_turn(self) -> None: """Increment turn counter (called after each user interaction)""" @@ -392,36 +146,18 @@ class Session: def get_trajectory(self) -> dict: """Serialize complete session trajectory for logging""" - tools: list = [] - if self.tool_router is not None: - try: - tools = self.tool_router.get_tool_specs_for_llm() or [] - except Exception: - tools = [] - # Sum per-call cost from llm_call events so analyzers don't have to - # walk the events array themselves. Each `llm_call` event already - # carries cost_usd from `agent.core.telemetry.record_llm_call`. - total_cost_usd = sum( - float((e.get("data") or {}).get("cost_usd") or 0.0) - for e in self.logged_events - if e.get("event_type") == "llm_call" - ) return { "session_id": self.session_id, - "user_id": self.user_id, - "hf_username": self.hf_username, "session_start_time": self.session_start_time, "session_end_time": datetime.now().isoformat(), "model_name": self.config.model_name, - "total_cost_usd": total_cost_usd, "messages": [msg.model_dump() for msg in self.context_manager.items], "events": self.logged_events, - "tools": tools, } def save_trajectory_local( self, - directory: str = str(DEFAULT_SESSION_LOG_DIR), + directory: str = "session_logs", upload_status: str = "pending", dataset_url: Optional[str] = None, ) -> Optional[str]: @@ -442,237 +178,78 @@ class Session: trajectory = self.get_trajectory() - # Scrub secrets at save time so session_logs/ never holds raw - # tokens on disk β€” a log aggregator, crash dump, or filesystem - # snapshot between heartbeats would otherwise leak them. - try: - from agent.core.redact import scrub - - for key in ("messages", "events", "tools"): - if key in trajectory: - trajectory[key] = scrub(trajectory[key]) - except Exception as _e: - logger.debug("Redact-on-save failed (non-fatal): %s", _e) - # Add upload metadata trajectory["upload_status"] = upload_status trajectory["upload_url"] = dataset_url trajectory["last_save_time"] = datetime.now().isoformat() - # Reuse one stable path per session so heartbeat saves overwrite - # the same file instead of creating a new timestamped file every - # minute. The timestamp in the filename is kept for first-save - # ordering; subsequent saves just rewrite that file. - if self._local_save_path and Path(self._local_save_path).parent == log_dir: - filepath = Path(self._local_save_path) - else: - filename = ( - f"session_{self.session_id}_" - f"{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" - ) - filepath = log_dir / filename - self._local_save_path = str(filepath) - - # Atomic-ish write: stage to .tmp then rename so a crash mid-write - # doesn't leave a truncated JSON that breaks the retry scanner. - tmp_path = filepath.with_suffix(filepath.suffix + ".tmp") - with open(tmp_path, "w") as f: + filename = f"session_{self.session_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" + filepath = log_dir / filename + + with open(filepath, "w") as f: json.dump(trajectory, f, indent=2) - tmp_path.replace(filepath) return str(filepath) except Exception as e: logger.error(f"Failed to save session locally: {e}") return None - def update_local_save_status( - self, filepath: str, upload_status: str, dataset_url: Optional[str] = None - ) -> bool: - """Update the upload status of an existing local save file""" - try: - with open(filepath, "r") as f: - data = json.load(f) - - data["upload_status"] = upload_status - data["upload_url"] = dataset_url - data["last_save_time"] = datetime.now().isoformat() - - with open(filepath, "w") as f: - json.dump(data, f, indent=2) - - return True - except Exception as e: - logger.error(f"Failed to update local save status: {e}") - return False + def save_and_upload_detached(self, repo_id: str) -> Optional[str]: + """ + Save session locally and spawn detached subprocess for upload (fire-and-forget) - def _personal_trace_repo_id(self) -> Optional[str]: - """Resolve the per-user trace repo id from config + HF username. + Args: + repo_id: HuggingFace dataset repo ID - Returns ``None`` when sharing is disabled, the user is anonymous, - or the template is missing β€” caller skips the personal upload in - those cases. + Returns: + Path to local save file """ - if not getattr(self.config, "share_traces", False): - return None - hf_user = self.hf_username or self.user_id - if not hf_user: - return None - template = getattr(self.config, "personal_trace_repo_template", None) - if not template: - return None - try: - return template.format(hf_user=hf_user) - except (KeyError, IndexError): - logger.debug("personal_trace_repo_template format failed: %r", template) + # Save locally first (fast, synchronous) + local_path = self.save_trajectory_local(upload_status="pending") + if not local_path: return None - def _spawn_uploader( - self, - action: str, - target: str, - repo_id: str, - *, - format: str, - token_env: Optional[str], - private: bool, - token_value: Optional[str] = None, - ) -> None: - """Fire-and-forget spawn of ``session_uploader.py`` with the given args.""" + # Spawn detached subprocess for upload (fire-and-forget) try: uploader_script = Path(__file__).parent / "session_uploader.py" - cmd = [ - sys.executable, - str(uploader_script), - action, - target, - repo_id, - "--format", - format, - "--private", - "true" if private else "false", - ] - if token_env: - cmd.extend(["--token-env", token_env]) - - env = os.environ.copy() - if token_value: - env["_ML_INTERN_PERSONAL_TOKEN"] = token_value + # Use Popen with detached process subprocess.Popen( - cmd, + [sys.executable, str(uploader_script), "upload", local_path, repo_id], stdin=subprocess.DEVNULL, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, - env=env, start_new_session=True, # Detach from parent ) except Exception as e: logger.warning(f"Failed to spawn upload subprocess: {e}") - def save_and_upload_detached(self, repo_id: str) -> Optional[str]: - """ - Save session locally and spawn detached subprocess(es) for upload - (fire-and-forget). - - Always uploads to the shared org dataset (``repo_id``) in the - single-row format used by the KPI scheduler. When - ``config.share_traces`` is enabled and a username is known, also - uploads to the user's personal private dataset in Claude Code JSONL - format so the HF Agent Trace Viewer auto-renders it. - - Args: - repo_id: HuggingFace dataset repo ID for the org/KPI upload. - - Returns: - Path to local save file - """ - local_path = self.save_trajectory_local(upload_status="pending") - if not local_path: - return None - - self._spawn_uploader( - "upload", - local_path, - repo_id, - format="row", - token_env=None, # default org token chain - private=False, - ) - - personal_repo = self._personal_trace_repo_id() - if personal_repo: - # User's own HF_TOKEN write-scoped to their namespace. - self._spawn_uploader( - "upload", - local_path, - personal_repo, - format="claude_code", - token_env="HF_TOKEN", - token_value=self.hf_token, - private=True, - ) - return local_path @staticmethod def retry_failed_uploads_detached( - directory: str = str(DEFAULT_SESSION_LOG_DIR), - repo_id: Optional[str] = None, - *, - personal_repo_id: Optional[str] = None, + directory: str = "session_logs", repo_id: Optional[str] = None ) -> None: """ - Spawn detached subprocess(es) to retry failed/pending uploads - (fire-and-forget). + Spawn detached subprocess to retry failed/pending uploads (fire-and-forget) Args: directory: Directory containing session logs - repo_id: Target dataset repo ID for the shared org/KPI upload. - personal_repo_id: Per-user dataset for Claude-Code-format - retries. ``None`` skips the personal retry pass. + repo_id: Target dataset repo ID """ - if not repo_id and not personal_repo_id: + if not repo_id: return try: uploader_script = Path(__file__).parent / "session_uploader.py" - if repo_id: - subprocess.Popen( - [ - sys.executable, - str(uploader_script), - "retry", - directory, - repo_id, - "--format", - "row", - ], - stdin=subprocess.DEVNULL, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - start_new_session=True, - ) - - if personal_repo_id: - subprocess.Popen( - [ - sys.executable, - str(uploader_script), - "retry", - directory, - personal_repo_id, - "--format", - "claude_code", - "--token-env", - "HF_TOKEN", - "--private", - "true", - ], - stdin=subprocess.DEVNULL, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - start_new_session=True, - ) + # Spawn detached subprocess for retry + subprocess.Popen( + [sys.executable, str(uploader_script), "retry", directory, repo_id], + stdin=subprocess.DEVNULL, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + start_new_session=True, # Detach from parent + ) except Exception as e: logger.warning(f"Failed to spawn retry subprocess: {e}") diff --git a/agent/core/session_persistence.py b/agent/core/session_persistence.py deleted file mode 100644 index e12467211b16fe12ec75fbf5b60edb5ee54f4072..0000000000000000000000000000000000000000 --- a/agent/core/session_persistence.py +++ /dev/null @@ -1,509 +0,0 @@ -"""Optional durable session persistence for the hosted backend. - -The public CLI must keep working without MongoDB. This module therefore -exposes one small async store interface and returns a no-op implementation -unless ``MONGODB_URI`` is configured and reachable. -""" - -from __future__ import annotations - -import logging -import os -from datetime import UTC, datetime -from typing import Any - -from bson import BSON -from pymongo import AsyncMongoClient, DeleteMany, ReturnDocument, UpdateOne -from pymongo.errors import DuplicateKeyError, InvalidDocument, PyMongoError - -logger = logging.getLogger(__name__) - -SCHEMA_VERSION = 1 -MAX_BSON_BYTES = 15 * 1024 * 1024 - - -def _now() -> datetime: - return datetime.now(UTC) - - -def _doc_id(session_id: str, idx: int) -> str: - return f"{session_id}:{idx}" - - -def _safe_message_doc(message: dict[str, Any]) -> dict[str, Any]: - """Return a Mongo-safe message document payload. - - Mongo's hard document limit is 16 MB. We stay below that and store an - explicit marker rather than failing the whole snapshot for one huge tool log. - """ - try: - if len(BSON.encode({"message": message})) <= MAX_BSON_BYTES: - return message - except (InvalidDocument, OverflowError): - pass - return { - "role": "tool", - "content": ( - "[SYSTEM: A single persisted message exceeded MongoDB's document " - "size/encoding limit and was replaced by this marker.]" - ), - "ml_intern_persistence_error": "message_too_large_or_invalid", - } - - -class NoopSessionStore: - """Async no-op store used when Mongo is not configured.""" - - enabled = False - - async def init(self) -> None: - return None - - async def close(self) -> None: - return None - - async def upsert_session(self, **_: Any) -> None: - return None - - async def save_snapshot(self, **_: Any) -> None: - return None - - async def load_session(self, *_: Any, **__: Any) -> dict[str, Any] | None: - return None - - async def list_sessions(self, *_: Any, **__: Any) -> list[dict[str, Any]]: - return [] - - async def soft_delete_session(self, *_: Any, **__: Any) -> None: - return None - - async def update_session_fields(self, *_: Any, **__: Any) -> None: - return None - - async def append_event(self, *_: Any, **__: Any) -> int | None: - return None - - async def load_events_after(self, *_: Any, **__: Any) -> list[dict[str, Any]]: - return [] - - async def append_trace_message(self, *_: Any, **__: Any) -> int | None: - return None - - async def get_quota(self, *_: Any, **__: Any) -> int | None: - return None - - async def try_increment_quota(self, *_: Any, **__: Any) -> int | None: - return None - - async def refund_quota(self, *_: Any, **__: Any) -> None: - return None - - async def mark_pro_seen(self, *_: Any, **__: Any) -> dict[str, Any] | None: - return None - - -class MongoSessionStore(NoopSessionStore): - """MongoDB-backed session store.""" - - enabled = True - - def __init__(self, uri: str, db_name: str) -> None: - self.uri = uri - self.db_name = db_name - self.enabled = False - self.client: AsyncMongoClient | None = None - self.db = None - - async def init(self) -> None: - try: - self.client = AsyncMongoClient(self.uri, serverSelectionTimeoutMS=3000) - self.db = self.client[self.db_name] - await self.client.admin.command("ping") - await self._create_indexes() - self.enabled = True - logger.info("Mongo session persistence enabled (db=%s)", self.db_name) - except Exception as e: - logger.warning("Mongo session persistence disabled: %s", e) - self.enabled = False - if self.client is not None: - await self.client.close() - self.client = None - self.db = None - - async def close(self) -> None: - if self.client is not None: - await self.client.close() - self.client = None - self.db = None - - async def _create_indexes(self) -> None: - if self.db is None: - return - await self.db.sessions.create_index( - [("user_id", 1), ("visibility", 1), ("updated_at", -1)] - ) - await self.db.sessions.create_index( - [("visibility", 1), ("status", 1), ("last_active_at", -1)] - ) - await self.db.session_messages.create_index( - [("session_id", 1), ("idx", 1)], unique=True - ) - await self.db.session_events.create_index( - [("session_id", 1), ("seq", 1)], unique=True - ) - await self.db.session_trace_messages.create_index( - [("session_id", 1), ("seq", 1)], unique=True - ) - await self.db.session_trace_messages.create_index([("created_at", -1)]) - await self.db.pro_users.create_index([("first_seen_pro_at", -1)]) - - def _ready(self) -> bool: - return bool(self.enabled and self.db is not None) - - async def upsert_session( - self, - *, - session_id: str, - user_id: str, - model: str, - title: str | None = None, - surface: str = "frontend", - created_at: datetime | None = None, - runtime_state: str = "idle", - status: str = "active", - message_count: int = 0, - turn_count: int = 0, - pending_approval: list[dict[str, Any]] | None = None, - claude_counted: bool = False, - notification_destinations: list[str] | None = None, - auto_approval_enabled: bool = False, - auto_approval_cost_cap_usd: float | None = None, - auto_approval_estimated_spend_usd: float = 0.0, - ) -> None: - if not self._ready(): - return - now = _now() - await self.db.sessions.update_one( - {"_id": session_id}, - { - "$setOnInsert": { - "_id": session_id, - "session_id": session_id, - "user_id": user_id, - "surface": surface, - "created_at": created_at or now, - "schema_version": SCHEMA_VERSION, - "visibility": "live", - }, - "$set": { - "title": title, - "model": model, - "status": status, - "runtime_state": runtime_state, - "updated_at": now, - "last_active_at": now, - "message_count": message_count, - "turn_count": turn_count, - "pending_approval": pending_approval or [], - "claude_counted": claude_counted, - "notification_destinations": notification_destinations or [], - "auto_approval_enabled": auto_approval_enabled, - "auto_approval_cost_cap_usd": auto_approval_cost_cap_usd, - "auto_approval_estimated_spend_usd": auto_approval_estimated_spend_usd, - }, - }, - upsert=True, - ) - - async def save_snapshot( - self, - *, - session_id: str, - user_id: str, - model: str, - messages: list[dict[str, Any]], - title: str | None = None, - runtime_state: str = "idle", - status: str = "active", - turn_count: int = 0, - pending_approval: list[dict[str, Any]] | None = None, - claude_counted: bool = False, - created_at: datetime | None = None, - notification_destinations: list[str] | None = None, - auto_approval_enabled: bool = False, - auto_approval_cost_cap_usd: float | None = None, - auto_approval_estimated_spend_usd: float = 0.0, - ) -> None: - if not self._ready(): - return - now = _now() - await self.upsert_session( - session_id=session_id, - user_id=user_id, - model=model, - title=title, - created_at=created_at, - runtime_state=runtime_state, - status=status, - message_count=len(messages), - turn_count=turn_count, - pending_approval=pending_approval, - claude_counted=claude_counted, - notification_destinations=notification_destinations, - auto_approval_enabled=auto_approval_enabled, - auto_approval_cost_cap_usd=auto_approval_cost_cap_usd, - auto_approval_estimated_spend_usd=auto_approval_estimated_spend_usd, - ) - ops: list[Any] = [] - for idx, raw in enumerate(messages): - ops.append( - UpdateOne( - {"_id": _doc_id(session_id, idx)}, - { - "$set": { - "session_id": session_id, - "idx": idx, - "message": _safe_message_doc(raw), - "updated_at": now, - }, - "$setOnInsert": {"created_at": now}, - }, - upsert=True, - ) - ) - ops.append( - DeleteMany({"session_id": session_id, "idx": {"$gte": len(messages)}}) - ) - try: - if ops: - await self.db.session_messages.bulk_write(ops, ordered=False) - except PyMongoError as e: - logger.warning("Failed to persist session %s snapshot: %s", session_id, e) - - async def load_session( - self, session_id: str, *, include_deleted: bool = False - ) -> dict[str, Any] | None: - if not self._ready(): - return None - meta = await self.db.sessions.find_one({"_id": session_id}) - if not meta: - return None - if meta.get("visibility") == "deleted" and not include_deleted: - return None - cursor = self.db.session_messages.find({"session_id": session_id}).sort( - "idx", 1 - ) - messages = [row.get("message") async for row in cursor] - return {"metadata": meta, "messages": messages} - - async def list_sessions( - self, user_id: str, *, include_deleted: bool = False - ) -> list[dict[str, Any]]: - if not self._ready(): - return [] - query: dict[str, Any] = {"user_id": user_id} - if user_id == "dev": - query = {} - if not include_deleted: - query["visibility"] = {"$ne": "deleted"} - cursor = self.db.sessions.find(query).sort("updated_at", -1) - return [row async for row in cursor] - - async def soft_delete_session(self, session_id: str) -> None: - if not self._ready(): - return - await self.db.sessions.update_one( - {"_id": session_id}, - { - "$set": { - "visibility": "deleted", - "runtime_state": "idle", - "updated_at": _now(), - } - }, - ) - - async def update_session_fields(self, session_id: str, **fields: Any) -> None: - if not self._ready() or not fields: - return - fields["updated_at"] = _now() - await self.db.sessions.update_one({"_id": session_id}, {"$set": fields}) - - async def _next_seq(self, counter_id: str) -> int: - doc = await self.db.counters.find_one_and_update( - {"_id": counter_id}, - {"$inc": {"seq": 1}}, - upsert=True, - return_document=ReturnDocument.AFTER, - ) - return int(doc["seq"]) - - async def append_event( - self, session_id: str, event_type: str, data: dict[str, Any] | None - ) -> int | None: - if not self._ready(): - return None - try: - seq = await self._next_seq(f"event:{session_id}") - await self.db.session_events.insert_one( - { - "_id": _doc_id(session_id, seq), - "session_id": session_id, - "seq": seq, - "event_type": event_type, - "data": data or {}, - "created_at": _now(), - } - ) - return seq - except PyMongoError as e: - logger.debug("Failed to append event for %s: %s", session_id, e) - return None - - async def load_events_after( - self, session_id: str, after_seq: int = 0 - ) -> list[dict[str, Any]]: - if not self._ready(): - return [] - cursor = self.db.session_events.find( - {"session_id": session_id, "seq": {"$gt": int(after_seq or 0)}} - ).sort("seq", 1) - return [row async for row in cursor] - - async def append_trace_message( - self, session_id: str, message: dict[str, Any], source: str = "message" - ) -> int | None: - if not self._ready(): - return None - try: - seq = await self._next_seq(f"trace:{session_id}") - await self.db.session_trace_messages.insert_one( - { - "_id": _doc_id(session_id, seq), - "session_id": session_id, - "seq": seq, - "role": message.get("role"), - "message": _safe_message_doc(message), - "source": source, - "created_at": _now(), - } - ) - return seq - except PyMongoError as e: - logger.debug("Failed to append trace message for %s: %s", session_id, e) - return None - - async def get_quota(self, user_id: str, day: str) -> int | None: - if not self._ready(): - return None - doc = await self.db.claude_quotas.find_one({"_id": f"{user_id}:{day}"}) - return int(doc.get("count", 0)) if doc else 0 - - async def try_increment_quota(self, user_id: str, day: str, cap: int) -> int | None: - if not self._ready(): - return None - key = f"{user_id}:{day}" - now = _now() - try: - await self.db.claude_quotas.insert_one( - { - "_id": key, - "user_id": user_id, - "day": day, - "count": 1, - "updated_at": now, - } - ) - return 1 - except DuplicateKeyError: - pass - doc = await self.db.claude_quotas.find_one_and_update( - {"_id": key, "count": {"$lt": cap}}, - {"$inc": {"count": 1}, "$set": {"updated_at": now}}, - return_document=ReturnDocument.AFTER, - ) - return int(doc["count"]) if doc else None - - async def refund_quota(self, user_id: str, day: str) -> None: - if not self._ready(): - return - await self.db.claude_quotas.update_one( - {"_id": f"{user_id}:{day}", "count": {"$gt": 0}}, - {"$inc": {"count": -1}, "$set": {"updated_at": _now()}}, - ) - - async def mark_pro_seen( - self, user_id: str, *, is_pro: bool - ) -> dict[str, Any] | None: - """Track per-user Pro state and detect freeβ†’Pro conversions. - - Returns ``{"converted": True, "first_seen_at": ..."}`` exactly once - per user β€” the first time we see them as Pro after having recorded - them as non-Pro at least once. Otherwise returns ``None``. - - Storing ``ever_non_pro`` lets us distinguish "user joined as Pro" - (no conversion) from "user upgraded" (conversion). The atomic - ``find_one_and_update`` on a guarded filter makes the conversion - emit at-most-once even under concurrent requests. - """ - if not self._ready() or not user_id: - return None - now = _now() - set_fields: dict[str, Any] = {"last_seen_at": now, "is_pro": bool(is_pro)} - if not is_pro: - set_fields["ever_non_pro"] = True - try: - await self.db.pro_users.update_one( - {"_id": user_id}, - { - "$setOnInsert": {"_id": user_id, "first_seen_at": now}, - "$set": set_fields, - }, - upsert=True, - ) - except PyMongoError as e: - logger.debug("mark_pro_seen upsert failed for %s: %s", user_id, e) - return None - - if not is_pro: - return None - - try: - doc = await self.db.pro_users.find_one_and_update( - { - "_id": user_id, - "ever_non_pro": True, - "first_seen_pro_at": {"$exists": False}, - }, - {"$set": {"first_seen_pro_at": now}}, - return_document=ReturnDocument.AFTER, - ) - except PyMongoError as e: - logger.debug("mark_pro_seen conversion check failed for %s: %s", user_id, e) - return None - - if not doc: - return None - return { - "converted": True, - "first_seen_at": (doc.get("first_seen_at") or now).isoformat(), - } - - -_store: NoopSessionStore | MongoSessionStore | None = None - - -def get_session_store() -> NoopSessionStore | MongoSessionStore: - global _store - if _store is None: - uri = os.environ.get("MONGODB_URI") - db_name = os.environ.get("MONGODB_DB", "ml-intern") - _store = MongoSessionStore(uri, db_name) if uri else NoopSessionStore() - return _store - - -def _reset_store_for_tests( - store: NoopSessionStore | MongoSessionStore | None = None, -) -> None: - global _store - _store = store diff --git a/agent/core/session_resume.py b/agent/core/session_resume.py deleted file mode 100644 index 941c426b7b216e099de204054843953eb70fd697..0000000000000000000000000000000000000000 --- a/agent/core/session_resume.py +++ /dev/null @@ -1,287 +0,0 @@ -"""Reload a previously saved session log into the active CLI session.""" - -from __future__ import annotations - -import json -import logging -import re -from dataclasses import dataclass -from datetime import datetime -from pathlib import Path -from typing import Any - -from litellm import Message - -from agent.core.model_switcher import is_valid_model_id -from agent.core.session import DEFAULT_SESSION_LOG_DIR - -logger = logging.getLogger(__name__) - -_REDACTED_MARKER = re.compile(r"\[REDACTED_[A-Z_]+\]") - - -@dataclass -class SessionLogEntry: - """Metadata for a locally saved session log.""" - - path: Path - session_id: str - session_start_time: str | None - session_end_time: str | None - model_name: str | None - message_count: int - preview: str - mtime: float - - -def _message_preview(content: Any, max_chars: int = 72) -> str: - """Return a one-line preview for string or OpenAI-style block content.""" - if isinstance(content, str): - text = content - elif isinstance(content, list): - parts: list[str] = [] - for block in content: - if isinstance(block, dict): - value = block.get("text") or block.get("content") - if isinstance(value, str): - parts.append(value) - elif isinstance(block, str): - parts.append(block) - text = " ".join(parts) - else: - text = "" - text = " ".join(text.split()) - if len(text) > max_chars: - return text[: max_chars - 1].rstrip() + "…" - return text - - -def _first_user_preview(messages: list[Any]) -> str: - for raw in messages: - if isinstance(raw, dict) and raw.get("role") == "user": - preview = _message_preview(raw.get("content")) - if preview: - return preview - return "(no user prompt preview)" - - -def list_session_logs( - directory: Path = DEFAULT_SESSION_LOG_DIR, -) -> list[SessionLogEntry]: - """Return readable session logs under ``directory``, newest first.""" - if not directory.exists(): - return [] - - entries: list[SessionLogEntry] = [] - for path in directory.glob("*.json"): - try: - with open(path) as f: - data = json.load(f) - except Exception: - continue - - messages = data.get("messages") or [] - if not isinstance(messages, list): - continue - - session_id = data.get("session_id") - if not isinstance(session_id, str) or not session_id: - session_id = path.stem - - stat = path.stat() - entries.append( - SessionLogEntry( - path=path, - session_id=session_id, - session_start_time=data.get("session_start_time"), - session_end_time=data.get("session_end_time"), - model_name=data.get("model_name"), - message_count=len(messages), - preview=_first_user_preview(messages), - mtime=stat.st_mtime, - ) - ) - - entries.sort(key=lambda item: item.mtime, reverse=True) - return entries - - -def format_session_log_entry(index: int, entry: SessionLogEntry) -> str: - timestamp = entry.session_end_time or entry.session_start_time - label = "unknown time" - if isinstance(timestamp, str) and timestamp: - try: - label = datetime.fromisoformat(timestamp).strftime("%Y-%m-%d %H:%M") - except ValueError: - label = timestamp[:16] - short_id = entry.session_id[:8] - model = entry.model_name or "unknown model" - return ( - f"{index:>2}. {label} {short_id} " - f"{entry.message_count} msgs {model}\n" - f" {entry.preview}" - ) - - -def resolve_session_log_arg( - arg: str, - entries: list[SessionLogEntry], - directory: Path = DEFAULT_SESSION_LOG_DIR, -) -> Path | None: - """Resolve ``/resume `` as index, path, filename, or session id prefix.""" - value = arg.strip() - if not value: - return None - - if value.isdigit(): - idx = int(value) - if 1 <= idx <= len(entries): - return entries[idx - 1].path - - candidate = Path(value).expanduser() - candidates = [candidate] - if not candidate.is_absolute(): - candidates.append(directory / candidate) - if candidate.suffix != ".json": - candidates.append(directory / f"{value}.json") - - for path in candidates: - if path.exists() and path.is_file(): - return path - - matches = [ - entry.path - for entry in entries - if entry.session_id.startswith(value) or entry.path.name.startswith(value) - ] - if len(matches) == 1: - return matches[0] - return None - - -def _turn_count_from_messages(messages: list[Any]) -> int: - return sum( - 1 for raw in messages if isinstance(raw, dict) and raw.get("role") == "user" - ) - - -def _has_redacted_content(messages: list[Any]) -> bool: - """Whether any message body contains a ``[REDACTED_*]`` marker.""" - for raw in messages: - if not isinstance(raw, dict): - continue - content = raw.get("content") - if isinstance(content, str) and _REDACTED_MARKER.search(content): - return True - if isinstance(content, list): - for block in content: - if isinstance(block, dict): - text = block.get("text") or block.get("content") - if isinstance(text, str) and _REDACTED_MARKER.search(text): - return True - return False - - -def restore_session_from_log(session: Any, path: Path) -> dict[str, Any]: - """Replace the active session context with messages from ``path``. - - Continues the saved session (reusing its id and on-disk save path) when - the log's ``user_id`` matches the current session, and forks otherwise: - the caller's session id stays put and future heartbeat saves go to a - fresh file rather than overwriting the source log. - - Returns metadata for the ``resume_complete`` event. - """ - with open(path) as f: - data = json.load(f) - - raw_messages = data.get("messages") - if not isinstance(raw_messages, list): - raise ValueError("Selected log does not contain a messages array") - - restored_messages: list[Message] = [] - dropped_count = 0 - for raw in raw_messages: - if not isinstance(raw, dict) or raw.get("role") == "system": - continue - try: - restored_messages.append(Message.model_validate(raw)) - except Exception as e: - dropped_count += 1 - logger.warning("Dropping malformed message from %s: %s", path, e) - - if not restored_messages: - raise ValueError("Selected log has no restorable non-system messages") - - cm = session.context_manager - system_msg = cm.items[0] if cm.items and cm.items[0].role == "system" else None - cm.items = ([system_msg] if system_msg else []) + restored_messages - - # Validate the saved model id before switching. ``update_model`` doesn't - # check availability; an unrecognised id silently sticks and the next LLM - # call fails with a cryptic routing error. Logs from a different - # deployment, an older catalog, or a removed model land here. - saved_model = data.get("model_name") - invalid_saved_model: str | None = None - if isinstance(saved_model, str) and saved_model: - if is_valid_model_id(saved_model): - session.update_model(saved_model) - else: - invalid_saved_model = saved_model - logger.warning( - "Saved log model %r failed format validation; keeping %r", - saved_model, - session.config.model_name, - ) - - cm._recompute_usage(session.config.model_name) - - saved_session_id = data.get("session_id") - saved_user_id = data.get("user_id") - is_continuation = saved_user_id == session.user_id - - if is_continuation: - if isinstance(saved_session_id, str) and saved_session_id: - session.session_id = saved_session_id - session.session_start_time = ( - data.get("session_start_time") or session.session_start_time - ) - - # Always fork the on-disk save path. The source log is treated as an - # immutable snapshot: ``logged_events`` is reset to a single - # ``resumed_from`` marker below for cost accounting, so reusing the - # source path would let the next heartbeat save destroy the original - # ``llm_call``/event history on disk. The next save will pick a fresh - # filename instead. - session._local_save_path = None - - saved_event_count = ( - len(data.get("events", [])) if isinstance(data.get("events"), list) else 0 - ) - session.logged_events = [ - { - "timestamp": datetime.now().isoformat(), - "event_type": "resumed_from", - "data": { - "path": str(path), - "original_session_id": ( - saved_session_id if isinstance(saved_session_id, str) else None - ), - "original_event_count": saved_event_count, - "forked": not is_continuation, - }, - } - ] - session.turn_count = _turn_count_from_messages(raw_messages) - session.last_auto_save_turn = session.turn_count - session.pending_approval = None - - return { - "path": str(path), - "restored_count": len(restored_messages), - "dropped_count": dropped_count, - "model_name": session.config.model_name, - "invalid_saved_model": invalid_saved_model, - "forked": not is_continuation, - "had_redacted_content": _has_redacted_content(raw_messages), - } diff --git a/agent/core/session_uploader.py b/agent/core/session_uploader.py index 404fd224563cdae3d91c2b93e05e8306ee91fb7e..ef2f9496d87f832489010f9a9529c538d939bedb 100644 --- a/agent/core/session_uploader.py +++ b/agent/core/session_uploader.py @@ -3,454 +3,32 @@ Standalone script for uploading session trajectories to HuggingFace. This runs as a separate process to avoid blocking the main agent. Uses individual file uploads to avoid race conditions. - -Two formats are supported: - -* ``row`` β€” single-line JSONL row used by the existing org telemetry/KPI - pipeline (``smolagents/ml-intern-sessions``). Compatible with - ``backend/kpis_scheduler.py``. -* ``claude_code`` β€” one event per line in the Claude Code JSONL schema, - auto-detected by the HF Agent Trace Viewer - (https://huggingface.co/changelog/agent-trace-viewer). Used for the - per-user private dataset (default ``{hf_user}/ml-intern-sessions``). """ -import argparse -import hashlib import json import os import sys from datetime import datetime from pathlib import Path -from typing import Any from dotenv import load_dotenv load_dotenv() -# Token resolution for the org KPI dataset. Fallback chain (least-privilege -# first) β€” matches backend/kpis_scheduler.py so one write-scoped token on the -# Space covers every telemetry dataset. Never hardcode tokens in source. -_ORG_TOKEN_FALLBACK_CHAIN = ( - "HF_SESSION_UPLOAD_TOKEN", - "HF_TOKEN", - "HF_ADMIN_TOKEN", -) -_PERSONAL_TOKEN_ENV = "_ML_INTERN_PERSONAL_TOKEN" - - -def _resolve_token(token_env: str | None) -> str: - """Resolve an HF token from env. ``token_env`` overrides the fallback chain.""" - if token_env == "HF_TOKEN": - try: - from agent.core.hf_tokens import resolve_hf_token - - return ( - resolve_hf_token( - os.environ.get(_PERSONAL_TOKEN_ENV), - os.environ.get("HF_TOKEN"), - ) - or "" - ) - except Exception: - token = os.environ.get(_PERSONAL_TOKEN_ENV) or os.environ.get("HF_TOKEN") - return token or "" - - if token_env: - return os.environ.get(token_env, "") or "" - for var in _ORG_TOKEN_FALLBACK_CHAIN: - val = os.environ.get(var) - if val: - return val - return "" - - -def _scrub(obj: Any) -> Any: - """Best-effort regex scrub for HF tokens / API keys before upload.""" - try: - from agent.core.redact import scrub # type: ignore - except Exception: - # Fallback for environments where the agent package isn't importable - # (shouldn't happen in our subprocess, but be defensive). - import importlib.util - - _spec = importlib.util.spec_from_file_location( - "_redact", - Path(__file__).parent / "redact.py", - ) - _mod = importlib.util.module_from_spec(_spec) - _spec.loader.exec_module(_mod) # type: ignore - scrub = _mod.scrub - return scrub(obj) - - -def _msg_uuid(session_id: str, role: str, idx: int) -> str: - """Deterministic UUID-shaped id for a Claude Code message. - - Uses sha1 of ``session_id::role::idx`` so re-uploads/heartbeats keep the - parent/child chain stable. Same convention as the example dataset - https://huggingface.co/datasets/clem/hf-coding-tools-traces. - """ - digest = hashlib.sha1(f"{session_id}::{role}::{idx}".encode("utf-8")).hexdigest() - # Format like a UUID for visual familiarity (32 hex chars w/ dashes). - return ( - f"{digest[0:8]}-{digest[8:12]}-{digest[12:16]}-{digest[16:20]}-{digest[20:32]}" - ) - - -def _content_to_text(content: Any) -> str: - """Best-effort flatten of a litellm/openai content field to plain text.""" - if content is None: - return "" - if isinstance(content, str): - return content - if isinstance(content, list): - parts: list[str] = [] - for block in content: - if isinstance(block, dict): - text = block.get("text") - if isinstance(text, str): - parts.append(text) - else: - # Unknown content block β€” keep round-trippable representation. - parts.append(json.dumps(block, default=str)) - else: - parts.append(str(block)) - return "\n".join(parts) - return str(content) - - -def _parse_tool_args(raw: Any) -> Any: - """Tool call arguments arrive as a JSON-encoded string from LLMs.""" - if isinstance(raw, dict): - return raw - if isinstance(raw, str): - try: - return json.loads(raw) - except (json.JSONDecodeError, TypeError): - return {"_raw": raw} - return raw - - -def to_claude_code_jsonl(trajectory: dict) -> list[dict]: - """Convert an internal trajectory dict to Claude Code JSONL events. - - Schema reference (per the HF Agent Trace Viewer auto-detector): - - {"type":"user","message":{"role":"user","content":"..."}, - "uuid":"...","parentUuid":null,"sessionId":"...","timestamp":"..."} - {"type":"assistant", - "message":{"role":"assistant","model":"...", - "content":[{"type":"text","text":"..."}, - {"type":"tool_use","id":"...","name":"...","input":{...}}]}, - "uuid":"...","parentUuid":"","sessionId":"...","timestamp":"..."} - {"type":"user","message":{"role":"user", - "content":[{"type":"tool_result", - "tool_use_id":"...","content":"..."}]}, - "uuid":"...","parentUuid":"","sessionId":"...","timestamp":"..."} - - System messages are skipped (they're not part of the viewer schema and - contain large prompts that pollute the trace viewer UI). - """ - session_id = trajectory["session_id"] - model_name = trajectory.get("model_name") or "" - fallback_timestamp = ( - trajectory.get("session_start_time") or datetime.now().isoformat() - ) - messages: list[dict] = trajectory.get("messages") or [] - - out: list[dict] = [] - parent_uuid: str | None = None - - for idx, msg in enumerate(messages): - if not isinstance(msg, dict): - continue - role = msg.get("role") - if role == "system": - continue - timestamp = msg.get("timestamp") or fallback_timestamp - - if role == "user": - content = _content_to_text(msg.get("content")) - event_uuid = _msg_uuid(session_id, "user", idx) - out.append( - { - "type": "user", - "message": {"role": "user", "content": content}, - "uuid": event_uuid, - "parentUuid": parent_uuid, - "sessionId": session_id, - "timestamp": timestamp, - } - ) - parent_uuid = event_uuid - - elif role == "assistant": - content_text = _content_to_text(msg.get("content")) - content_blocks: list[dict] = [] - if content_text: - content_blocks.append({"type": "text", "text": content_text}) - for tc in msg.get("tool_calls") or []: - if not isinstance(tc, dict): - continue - fn = tc.get("function") or {} - content_blocks.append( - { - "type": "tool_use", - "id": tc.get("id") or "", - "name": fn.get("name") or "", - "input": _parse_tool_args(fn.get("arguments")), - } - ) - if not content_blocks: - # Edge case: empty assistant turn (shouldn't normally happen, - # but skip rather than emit an empty content array which - # confuses the viewer). - continue - event_uuid = _msg_uuid(session_id, "assistant", idx) - out.append( - { - "type": "assistant", - "message": { - "role": "assistant", - "model": model_name, - "content": content_blocks, - }, - "uuid": event_uuid, - "parentUuid": parent_uuid, - "sessionId": session_id, - "timestamp": timestamp, - } - ) - parent_uuid = event_uuid - - elif role == "tool": - tool_call_id = msg.get("tool_call_id") or "" - content_text = _content_to_text(msg.get("content")) - event_uuid = _msg_uuid(session_id, "tool", idx) - out.append( - { - "type": "user", - "message": { - "role": "user", - "content": [ - { - "type": "tool_result", - "tool_use_id": tool_call_id, - "content": content_text, - } - ], - }, - "uuid": event_uuid, - "parentUuid": parent_uuid, - "sessionId": session_id, - "timestamp": timestamp, - } - ) - parent_uuid = event_uuid - - return out - - -def _scrub_session_for_upload(data: dict) -> dict: - """Best-effort scrub of transcript fields before any upload temp file.""" - scrubbed = dict(data) - scrubbed["messages"] = _scrub(data.get("messages") or []) - scrubbed["events"] = _scrub(data.get("events") or []) - scrubbed["tools"] = _scrub(data.get("tools") or []) - return scrubbed - - -def _write_row_payload(data: dict, tmp_path: str) -> None: - """Single-row JSONL (existing format) β€” used by KPI scheduler.""" - scrubbed = _scrub_session_for_upload(data) - session_row = { - "session_id": data["session_id"], - "user_id": data.get("user_id"), - "session_start_time": data["session_start_time"], - "session_end_time": data["session_end_time"], - "model_name": data["model_name"], - "total_cost_usd": data.get("total_cost_usd"), - "messages": json.dumps(scrubbed["messages"]), - "events": json.dumps(scrubbed["events"]), - "tools": json.dumps(scrubbed["tools"]), - } - - with open(tmp_path, "w") as tmp: - json.dump(session_row, tmp) - - -def _write_claude_code_payload(data: dict, tmp_path: str) -> None: - """Multi-line JSONL in Claude Code schema for the HF trace viewer.""" - # Scrub before conversion so secrets never reach the upload temp file. - scrubbed = _scrub_session_for_upload(data) - events = to_claude_code_jsonl(scrubbed) - with open(tmp_path, "w") as tmp: - for event in events: - tmp.write(json.dumps(event)) - tmp.write("\n") - - -def _status_field(format: str) -> str: - """Per-format upload status field on the local trajectory file.""" - return "personal_upload_status" if format == "claude_code" else "upload_status" - - -def _url_field(format: str) -> str: - return "personal_upload_url" if format == "claude_code" else "upload_url" - - -def _read_session_file(session_file: str) -> dict: - """Read a local session file while respecting uploader file locks.""" - import fcntl - - with open(session_file, "r") as f: - fcntl.flock(f, fcntl.LOCK_SH) - try: - return json.load(f) - finally: - fcntl.flock(f, fcntl.LOCK_UN) - - -def _update_upload_status( - session_file: str, - status_key: str, - url_key: str, - status: str, - dataset_url: str | None = None, -) -> None: - """Atomically update only this uploader's status fields. - - The org and personal uploaders run as separate processes against the same - local session JSON file. Re-read under an exclusive lock so one uploader - cannot clobber fields written by the other. - """ - import fcntl - - with open(session_file, "r+") as f: - fcntl.flock(f, fcntl.LOCK_EX) - try: - data = json.load(f) - data[status_key] = status - if dataset_url is not None: - data[url_key] = dataset_url - data["last_save_time"] = datetime.now().isoformat() - f.seek(0) - json.dump(data, f, indent=2) - f.truncate() - f.flush() - os.fsync(f.fileno()) - finally: - fcntl.flock(f, fcntl.LOCK_UN) - - -def dataset_card_readme(repo_id: str) -> str: - """Dataset card for personal ML Intern session trace repos.""" - return """--- -pretty_name: "ML Intern Session Traces" -language: -- en -license: other -task_categories: -- text-generation -tags: -- agent-traces -- coding-agent -- ml-intern -- session-traces -- claude-code -- hf-agent-trace-viewer -configs: -- config_name: default - data_files: - - split: train - path: "sessions/**/*.jsonl" ---- - -# ML Intern session traces - -This dataset contains ML Intern coding agent session traces uploaded from local -ML Intern runs. The traces are stored as JSON Lines files under `sessions/`, -with one file per session. - -## Links - -- ML Intern demo: https://smolagents-ml-intern.hf.space -- ML Intern CLI: https://github.com/huggingface/ml-intern - -## Data description - -Each `*.jsonl` file contains a single ML Intern session converted to a -Claude-Code-style event stream for the Hugging Face Agent Trace Viewer. Entries -can include user messages, assistant messages, tool calls, tool results, model -metadata, and timestamps. - -Session files are written to paths of the form: - -```text -sessions/YYYY-MM-DD/.jsonl -``` - -## Redaction and review - -**WARNING: no comprehensive redaction or human review has been performed for this dataset.** - -ML Intern applies automated best-effort scrubbing for common secret patterns -such as Hugging Face, Anthropic, OpenAI, GitHub, and AWS tokens before upload. -This is not a privacy guarantee. - -These traces may contain sensitive information, including prompts, code, -terminal output, file paths, repository names, private task context, tool -outputs, or other data from the local development environment. Treat every -session as potentially sensitive. - -Do not make this dataset public unless you have manually inspected the uploaded -sessions and are comfortable sharing their full contents. - -## Limitations - -Coding agent transcripts can include private or off-topic content, failed -experiments, credentials accidentally pasted by a user, and outputs copied from -local files or services. Use with appropriate caution, especially before -changing repository visibility. -""" - - -def _upload_dataset_card(api: Any, repo_id: str, token: str, format: str) -> None: - """Create/update a README for personal trace datasets.""" - if format != "claude_code": - return - - api.upload_file( - path_or_fileobj=dataset_card_readme(repo_id).encode("utf-8"), - path_in_repo="README.md", - repo_id=repo_id, - repo_type="dataset", - token=token, - commit_message="Update dataset card", - ) +# Token for session uploads β€” loaded from env var (never hardcode tokens in source) +_SESSION_TOKEN = os.environ.get("HF_SESSION_UPLOAD_TOKEN", "") def upload_session_as_file( - session_file: str, - repo_id: str, - max_retries: int = 3, - format: str = "row", - token_env: str | None = None, - private: bool = False, + session_file: str, repo_id: str, max_retries: int = 3 ) -> bool: - """Upload a single session as an individual JSONL file (no race conditions). + """ + Upload a single session as an individual JSONL file (no race conditions) Args: session_file: Path to local session JSON file repo_id: HuggingFace dataset repo ID max_retries: Number of retry attempts - format: ``row`` (default, KPI-compatible) or ``claude_code`` (HF - Agent Trace Viewer compatible). - token_env: Name of the env var holding the HF token. ``None`` falls - back to the org-token chain (``HF_SESSION_UPLOAD_TOKEN`` β†’ - ``HF_TOKEN`` β†’ ``HF_ADMIN_TOKEN``). - private: When creating the repo for the first time, mark it private. Returns: True if successful, False otherwise @@ -461,60 +39,72 @@ def upload_session_as_file( print("Error: huggingface_hub library not available", file=sys.stderr) return False - status_key = _status_field(format) - url_key = _url_field(format) - try: - data = _read_session_file(session_file) + # Load session data + with open(session_file, "r") as f: + data = json.load(f) - # Skip if already uploaded for this format. - if data.get(status_key) == "success": + # Check if already uploaded + upload_status = data.get("upload_status") + if upload_status == "success": return True - hf_token = _resolve_token(token_env) + # Use dedicated session upload token (write-only access to session dataset) + hf_token = _SESSION_TOKEN if not hf_token: - _update_upload_status(session_file, status_key, url_key, "failed") + # Update status to failed + data["upload_status"] = "failed" + with open(session_file, "w") as f: + json.dump(data, f, indent=2) return False - # Build temp upload payload in the requested format. + # Prepare JSONL content (single line) + # Store messages and events as JSON strings to avoid schema conflicts + session_row = { + "session_id": data["session_id"], + "session_start_time": data["session_start_time"], + "session_end_time": data["session_end_time"], + "model_name": data["model_name"], + "messages": json.dumps(data["messages"]), + "events": json.dumps(data["events"]), + } + + # Create temporary JSONL file import tempfile with tempfile.NamedTemporaryFile( mode="w", suffix=".jsonl", delete=False ) as tmp: + json.dump(session_row, tmp) # Single line JSON tmp_path = tmp.name try: - if format == "claude_code": - _write_claude_code_payload(data, tmp_path) - else: - _write_row_payload(data, tmp_path) - + # Generate unique path in repo: sessions/YYYY-MM-DD/session_id.jsonl session_id = data["session_id"] date_str = datetime.fromisoformat(data["session_start_time"]).strftime( "%Y-%m-%d" ) repo_path = f"sessions/{date_str}/{session_id}.jsonl" + # Upload with retries api = HfApi() for attempt in range(max_retries): try: - # Idempotent create β€” visibility is set on first creation - # only. Existing repos keep whatever the user picked via - # /share-traces. + # Try to create repo if it doesn't exist (idempotent) try: api.create_repo( repo_id=repo_id, repo_type="dataset", - private=private, + private=False, token=hf_token, - exist_ok=True, + exist_ok=True, # Don't fail if already exists ) + except Exception: + # Repo might already exist, continue pass - _upload_dataset_card(api, repo_id, hf_token, format) - + # Upload the session file api.upload_file( path_or_fileobj=tmp_path, path_in_repo=repo_path, @@ -524,13 +114,12 @@ def upload_session_as_file( commit_message=f"Add session {session_id}", ) - _update_upload_status( - session_file, - status_key, - url_key, - "success", - f"https://huggingface.co/datasets/{repo_id}", - ) + # Update local status to success + data["upload_status"] = "success" + data["upload_url"] = f"https://huggingface.co/datasets/{repo_id}" + with open(session_file, "w") as f: + json.dump(data, f, indent=2) + return True except Exception: @@ -540,12 +129,14 @@ def upload_session_as_file( wait_time = 2**attempt time.sleep(wait_time) else: - _update_upload_status( - session_file, status_key, url_key, "failed" - ) + # Final attempt failed + data["upload_status"] = "failed" + with open(session_file, "w") as f: + json.dump(data, f, indent=2) return False finally: + # Clean up temp file try: os.unlink(tmp_path) except Exception: @@ -556,102 +147,56 @@ def upload_session_as_file( return False -def retry_failed_uploads( - directory: str, - repo_id: str, - format: str = "row", - token_env: str | None = None, - private: bool = False, -): - """Retry all failed/pending uploads in a directory for the given format.""" +def retry_failed_uploads(directory: str, repo_id: str): + """Retry all failed/pending uploads in a directory""" log_dir = Path(directory) if not log_dir.exists(): return - status_key = _status_field(format) session_files = list(log_dir.glob("session_*.json")) for filepath in session_files: try: - data = _read_session_file(str(filepath)) - - # Only retry pending or failed uploads. Files predating this - # field don't have it; treat unknown as "not yet attempted" for - # the row format (legacy behavior) and "skip" for claude_code - # so we don't suddenly re-upload pre-existing sessions to a - # newly-introduced personal repo. - status = data.get(status_key, "unknown") - if format == "claude_code" and status_key not in data: - continue - - if status in ("pending", "failed", "unknown"): - upload_session_as_file( - str(filepath), - repo_id, - format=format, - token_env=token_env, - private=private, - ) + with open(filepath, "r") as f: + data = json.load(f) - except Exception: - pass + upload_status = data.get("upload_status", "unknown") + # Only retry pending or failed uploads + if upload_status in ["pending", "failed"]: + upload_session_as_file(str(filepath), repo_id) -def _str2bool(v: str) -> bool: - return str(v).strip().lower() in {"1", "true", "yes", "on"} + except Exception: + pass if __name__ == "__main__": - parser = argparse.ArgumentParser(prog="session_uploader.py") - sub = parser.add_subparsers(dest="command", required=True) - - p_upload = sub.add_parser("upload") - p_upload.add_argument("session_file") - p_upload.add_argument("repo_id") - p_upload.add_argument( - "--format", - choices=["row", "claude_code"], - default="row", - ) - p_upload.add_argument( - "--token-env", - default=None, - help="Env var name holding the HF token (default: org fallback chain).", - ) - p_upload.add_argument("--private", default="false") - - p_retry = sub.add_parser("retry") - p_retry.add_argument("directory") - p_retry.add_argument("repo_id") - p_retry.add_argument( - "--format", - choices=["row", "claude_code"], - default="row", - ) - p_retry.add_argument("--token-env", default=None) - p_retry.add_argument("--private", default="false") - - args = parser.parse_args() - - if args.command == "upload": - ok = upload_session_as_file( - args.session_file, - args.repo_id, - format=args.format, - token_env=args.token_env, - private=_str2bool(args.private), - ) - sys.exit(0 if ok else 1) - - if args.command == "retry": - retry_failed_uploads( - args.directory, - args.repo_id, - format=args.format, - token_env=args.token_env, - private=_str2bool(args.private), - ) + if len(sys.argv) < 3: + print("Usage: session_uploader.py ") + sys.exit(1) + + command = sys.argv[1] + + if command == "upload": + # python session_uploader.py upload + if len(sys.argv) < 4: + print("Usage: session_uploader.py upload ") + sys.exit(1) + session_file = sys.argv[2] + repo_id = sys.argv[3] + success = upload_session_as_file(session_file, repo_id) + sys.exit(0 if success else 1) + + elif command == "retry": + # python session_uploader.py retry + if len(sys.argv) < 4: + print("Usage: session_uploader.py retry ") + sys.exit(1) + directory = sys.argv[2] + repo_id = sys.argv[3] + retry_failed_uploads(directory, repo_id) sys.exit(0) - parser.print_help() - sys.exit(1) + else: + print(f"Unknown command: {command}") + sys.exit(1) diff --git a/agent/core/telemetry.py b/agent/core/telemetry.py deleted file mode 100644 index 38d2bbe761fee99d7c8051d6788fc849df8a8fae..0000000000000000000000000000000000000000 --- a/agent/core/telemetry.py +++ /dev/null @@ -1,422 +0,0 @@ -"""All agent observability in one module. - -Every telemetry signal the agent emits β€” LLM-call usage / cost, hf_jobs -lifecycle, sandbox lifecycle, user feedback, mid-turn heartbeat saves β€” is -defined here so business-logic files stay free of instrumentation noise. - -Callsites are one-liners:: - - await telemetry.record_llm_call(session, model=..., response=r, ...) - await telemetry.record_hf_job_submit(session, job, args, image=..., job_type="Python") - HeartbeatSaver.maybe_fire(session) - -All ``record_*`` functions emit a single ``Event`` via ``session.send_event`` -and never raise β€” telemetry is best-effort and must not break the agent. -""" - -from __future__ import annotations - -import asyncio -import logging -import time -from typing import Any - -logger = logging.getLogger(__name__) - - -# ── usage extraction ──────────────────────────────────────────────────────── - - -def extract_usage(response_or_chunk: Any) -> dict: - """Flat usage dict from a litellm response or final-chunk usage object. - - Normalizes across providers: Anthropic exposes cache tokens as - ``cache_read_input_tokens`` / ``cache_creation_input_tokens``; OpenAI uses - ``prompt_tokens_details.cached_tokens``. Exposed under the stable keys - ``cache_read_tokens`` / ``cache_creation_tokens``. - """ - u = getattr(response_or_chunk, "usage", None) - if u is None and isinstance(response_or_chunk, dict): - u = response_or_chunk.get("usage") - if u is None: - return {} - - def _g(name, default=0): - if isinstance(u, dict): - return u.get(name, default) or default - return getattr(u, name, default) or default - - prompt = _g("prompt_tokens") - completion = _g("completion_tokens") - total = _g("total_tokens") or (prompt + completion) - - cache_read = _g("cache_read_input_tokens") - cache_creation = _g("cache_creation_input_tokens") - - if not cache_read: - details = _g("prompt_tokens_details", None) - if details is not None: - if isinstance(details, dict): - cache_read = details.get("cached_tokens", 0) or 0 - else: - cache_read = getattr(details, "cached_tokens", 0) or 0 - - return { - "prompt_tokens": int(prompt), - "completion_tokens": int(completion), - "total_tokens": int(total), - "cache_read_tokens": int(cache_read), - "cache_creation_tokens": int(cache_creation), - } - - -# ── llm_call ──────────────────────────────────────────────────────────────── - - -async def record_llm_call( - session: Any, - *, - model: str, - response: Any = None, - latency_ms: int, - finish_reason: str | None, - kind: str = "main", -) -> dict: - """Emit an ``llm_call`` event and return the extracted usage dict so - callers can stash it on their result object if they want. - - ``kind`` tags the call site so downstream analytics can break spend - down by category. Values currently emitted by the codebase: - - * ``main`` β€” agent loop turn (user-facing reply or tool follow-up) - * ``research`` β€” research sub-agent inner loop (3 call sites) - * ``compaction`` β€” context-window summary on overflow - * ``effort_probe``β€” effort cascade walk on rejection / model switch - * ``restore`` β€” session re-seed summary after a Space restart - - Pre-2026-04-29 only ``main`` calls were instrumented; observed gap on - Cost Explorer was ~67%, with the other 5 call sites accounting for - the rest. Tagging lets us split the dataset's ``total_cost_usd`` by - category and validate against AWS billing. - - The ``/title`` (HF Router, not Bedrock) and ``/health/llm`` (diagnostic - endpoint, no session context) call sites are intentionally not - instrumented β€” together they're <1% of spend. - """ - usage = extract_usage(response) if response is not None else {} - cost_usd = 0.0 - if response is not None: - try: - from litellm import completion_cost - - cost_usd = float(completion_cost(completion_response=response) or 0.0) - except Exception: - cost_usd = 0.0 - from agent.core.session import Event # local import to avoid cycle - - try: - await session.send_event( - Event( - event_type="llm_call", - data={ - "model": model, - "latency_ms": latency_ms, - "finish_reason": finish_reason, - "cost_usd": cost_usd, - "kind": kind, - **usage, - }, - ) - ) - except Exception as e: - logger.debug("record_llm_call failed (non-fatal): %s", e) - return usage - - -# ── hf_jobs ──────────────────────────────────────────────────────────────── - - -def _infer_push_to_hub(script_or_cmd: Any) -> bool: - if not isinstance(script_or_cmd, str): - return False - return ( - "push_to_hub=True" in script_or_cmd - or "push_to_hub=true" in script_or_cmd - or "hub_model_id" in script_or_cmd - ) - - -async def record_hf_job_submit( - session: Any, - job: Any, - args: dict, - *, - image: str, - job_type: str, -) -> float: - """Emit ``hf_job_submit``. Returns the monotonic start timestamp so the - caller can pass it back into :func:`record_hf_job_complete`.""" - from agent.core.session import Event - - t_start = time.monotonic() - try: - script_text = args.get("script") or args.get("command") or "" - await session.send_event( - Event( - event_type="hf_job_submit", - data={ - "job_id": getattr(job, "id", None), - "job_url": getattr(job, "url", None), - "flavor": args.get("hardware_flavor", "cpu-basic"), - "timeout": args.get("timeout", "30m"), - "job_type": job_type, - "image": image, - "namespace": args.get("namespace"), - "push_to_hub": _infer_push_to_hub(script_text), - }, - ) - ) - except Exception as e: - logger.debug("record_hf_job_submit failed (non-fatal): %s", e) - return t_start - - -async def record_hf_job_complete( - session: Any, - job: Any, - *, - flavor: str, - final_status: str, - submit_ts: float, -) -> None: - from agent.core.session import Event - - try: - wall_time_s = int(time.monotonic() - submit_ts) - await session.send_event( - Event( - event_type="hf_job_complete", - data={ - "job_id": getattr(job, "id", None), - "flavor": flavor, - "final_status": final_status, - "wall_time_s": wall_time_s, - }, - ) - ) - except Exception as e: - logger.debug("record_hf_job_complete failed (non-fatal): %s", e) - - -# ── sandbox ───────────────────────────────────────────────────────────────── - - -async def record_sandbox_create( - session: Any, - sandbox: Any, - *, - hardware: str, - create_latency_s: int, -) -> None: - from agent.core.session import Event - - try: - # Pin created-at on the session so record_sandbox_destroy can diff. - session._sandbox_created_at = time.monotonic() - create_latency_s - await session.send_event( - Event( - event_type="sandbox_create", - data={ - "sandbox_id": getattr(sandbox, "space_id", None), - "hardware": hardware, - "create_latency_s": int(create_latency_s), - }, - ) - ) - except Exception as e: - logger.debug("record_sandbox_create failed (non-fatal): %s", e) - - -async def record_sandbox_destroy(session: Any, sandbox: Any) -> None: - from agent.core.session import Event - - try: - created = getattr(session, "_sandbox_created_at", None) - lifetime_s = int(time.monotonic() - created) if created else None - await session.send_event( - Event( - event_type="sandbox_destroy", - data={ - "sandbox_id": getattr(sandbox, "space_id", None), - "lifetime_s": lifetime_s, - }, - ) - ) - except Exception as e: - logger.debug("record_sandbox_destroy failed (non-fatal): %s", e) - - -# ── feedback ─────────────────────────────────────────────────────────────── - - -async def record_feedback( - session: Any, - *, - rating: str, - turn_index: int | None = None, - message_id: str | None = None, - comment: str | None = None, -) -> None: - from agent.core.session import Event - - try: - await session.send_event( - Event( - event_type="feedback", - data={ - "rating": rating, - "turn_index": turn_index, - "message_id": message_id, - "comment": (comment or "")[:500], - }, - ) - ) - except Exception as e: - logger.debug("record_feedback failed (non-fatal): %s", e) - - -async def record_jobs_access_blocked( - session: Any, - *, - tool_call_ids: list[str], - plan: str, - eligible_namespaces: list[str], -) -> None: - from agent.core.session import Event - - try: - await session.send_event( - Event( - event_type="jobs_access_blocked", - data={ - "tool_call_ids": tool_call_ids, - "plan": plan, - "eligible_namespaces": eligible_namespaces, - }, - ) - ) - except Exception as e: - logger.debug("record_jobs_access_blocked failed (non-fatal): %s", e) - - -async def record_pro_cta_click( - session: Any, - *, - source: str, - target: str = "pro_pricing", -) -> None: - from agent.core.session import Event - - try: - await session.send_event( - Event( - event_type="pro_cta_click", - data={"source": source, "target": target}, - ) - ) - except Exception as e: - logger.debug("record_pro_cta_click failed (non-fatal): %s", e) - - -async def record_pro_conversion( - session: Any, - *, - first_seen_at: str | None = None, -) -> None: - """Emit a ``pro_conversion`` event for a user we've previously observed - as non-Pro and now see as Pro for the first time. Detected upstream in - ``MongoSessionStore.mark_pro_seen``; fired into the user's first Pro - session so the rollup picks it up alongside other event-driven KPIs.""" - from agent.core.session import Event - - try: - await session.send_event( - Event( - event_type="pro_conversion", - data={"first_seen_at": first_seen_at}, - ) - ) - except Exception as e: - logger.debug("record_pro_conversion failed (non-fatal): %s", e) - - -async def record_credits_topped_up( - session: Any, - *, - namespace: str | None = None, -) -> None: - """Emit a ``credits_topped_up`` event when an hf_job submits successfully - in a session that previously hit ``jobs_access_blocked`` β€” i.e. the user - came back from the HF billing top-up flow and unblocked themselves. - Caller is responsible for firing this at most once per session.""" - from agent.core.session import Event - - try: - await session.send_event( - Event( - event_type="credits_topped_up", - data={"namespace": namespace}, - ) - ) - except Exception as e: - logger.debug("record_credits_topped_up failed (non-fatal): %s", e) - - -# ── heartbeat ────────────────────────────────────────────────────────────── - -# Module-level reference set for fire-and-forget heartbeat tasks. asyncio only -# keeps *weak* references to tasks, so the returned Task would otherwise be -# eligible for GC before running β€” the task gets discarded and the upload -# silently never happens. Hold strong refs until the task completes. -_heartbeat_tasks: set[asyncio.Task] = set() - - -class HeartbeatSaver: - """Time-gated mid-turn flush. - - Called from ``Session.send_event`` after every event. Fires - ``save_and_upload_detached`` in a worker thread at most once per - ``heartbeat_interval_s`` (default 60s). Guards against losing trace data - on long-running turns that crash before ``turn_complete``. - """ - - @staticmethod - def maybe_fire(session: Any) -> None: - if not getattr(session.config, "save_sessions", False): - return - interval = getattr(session.config, "heartbeat_interval_s", 0) or 0 - if interval <= 0: - return - now = time.monotonic() - last = getattr(session, "_last_heartbeat_ts", None) - if last is None: - # Initialise on first event; no save yet. - session._last_heartbeat_ts = now - return - if now - last < interval: - return - session._last_heartbeat_ts = now - repo_id = session.config.session_dataset_repo - try: - task = asyncio.get_running_loop().create_task( - asyncio.to_thread(session.save_and_upload_detached, repo_id) - ) - # Hold a strong reference until the task finishes so asyncio can't - # GC it. ``set.discard`` is a no-op on missing keys β†’ safe callback. - _heartbeat_tasks.add(task) - task.add_done_callback(_heartbeat_tasks.discard) - except RuntimeError: - try: - session.save_and_upload_detached(repo_id) - except Exception as e: - logger.debug("Heartbeat save failed (non-fatal): %s", e) diff --git a/agent/core/tools.py b/agent/core/tools.py index 1b750671605143958f1193c38ef7c1ee083a3cdc..b3dc8f3be75419d16d215a39d819f7748357d41c 100644 --- a/agent/core/tools.py +++ b/agent/core/tools.py @@ -8,8 +8,11 @@ import warnings from dataclasses import dataclass from typing import Any, Awaitable, Callable, Optional +logger = logging.getLogger(__name__) + from fastmcp import Client from fastmcp.exceptions import ToolError +from lmnr import observe from mcp.types import EmbeddedResource, ImageContent, TextContent from agent.config import MCPServerConfig @@ -44,12 +47,7 @@ from agent.tools.hf_repo_git_tool import ( hf_repo_git_handler, ) from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, hf_jobs_handler -from agent.tools.notify_tool import NOTIFY_TOOL_SPEC, notify_handler -from agent.tools.papers_tool import HF_PAPERS_TOOL_SPEC, hf_papers_handler from agent.tools.plan_tool import PLAN_TOOL_SPEC, plan_tool_handler -from agent.tools.research_tool import RESEARCH_TOOL_SPEC, research_handler -from agent.tools.sandbox_tool import get_sandbox_tools -from agent.tools.web_search_tool import WEB_SEARCH_TOOL_SPEC, web_search_handler # NOTE: Private HF repo tool disabled - replaced by hf_repo_files and hf_repo_git # from agent.tools.private_hf_repo_tools import ( @@ -62,8 +60,6 @@ warnings.filterwarnings( "ignore", category=DeprecationWarning, module="aiohttp.connector" ) -logger = logging.getLogger(__name__) - NOT_ALLOWED_TOOL_NAMES = ["hf_jobs", "hf_doc_search", "hf_doc_fetch", "hf_whoami"] @@ -131,28 +127,18 @@ class ToolRouter: Based on codex-rs/core/src/tools/router.rs """ - def __init__( - self, - mcp_servers: dict[str, MCPServerConfig], - hf_token: str | None = None, - local_mode: bool = False, - ): + def __init__(self, mcp_servers: dict[str, MCPServerConfig]): self.tools: dict[str, ToolSpec] = {} self.mcp_servers: dict[str, dict[str, Any]] = {} - for tool in create_builtin_tools(local_mode=local_mode): + for tool in create_builtin_tools(): self.register_tool(tool) self.mcp_client: Client | None = None if mcp_servers: mcp_servers_payload = {} for name, server in mcp_servers.items(): - data = server.model_dump() - if hf_token: - data.setdefault("headers", {})["Authorization"] = ( - f"Bearer {hf_token}" - ) - mcp_servers_payload[name] = data + mcp_servers_payload[name] = server.model_dump() self.mcp_client = Client({"mcpServers": mcp_servers_payload}) self._mcp_initialized = False @@ -187,19 +173,17 @@ class ToolRouter: search_openapi_handler, ) - try: - openapi_spec = await _get_api_search_tool_spec() - self.register_tool( - ToolSpec( - name=openapi_spec["name"], - description=openapi_spec["description"], - parameters=openapi_spec["parameters"], - handler=search_openapi_handler, - ) + # Register search_hf_api_endpoints with dynamic spec + openapi_spec = await _get_api_search_tool_spec() + self.register_tool( + ToolSpec( + name=openapi_spec["name"], + description=openapi_spec["description"], + parameters=openapi_spec["parameters"], + handler=search_openapi_handler, ) - logger.info(f"Loaded OpenAPI search tool: {openapi_spec['name']}") - except Exception as e: - logger.warning("Failed to load OpenAPI search tool: %s", e) + ) + logger.info(f"Loaded OpenAPI search tool: {openapi_spec['name']}") def get_tool_specs_for_llm(self) -> list[dict[str, Any]]: """Get tool specifications in OpenAI format""" @@ -219,17 +203,12 @@ class ToolRouter: async def __aenter__(self) -> "ToolRouter": if self.mcp_client is not None: - try: - await self.mcp_client.__aenter__() - await self.mcp_client.initialize() - await self.register_mcp_tools() - self._mcp_initialized = True - except Exception as e: - logger.warning( - "MCP connection failed, continuing without MCP tools: %s", e - ) - self.mcp_client = None + await self.mcp_client.__aenter__() + await self.mcp_client.initialize() + await self.register_mcp_tools() + self._mcp_initialized = True + # Register OpenAPI tool (requires async initialization) await self.register_openapi_tool() total_tools = len(self.tools) @@ -242,12 +221,9 @@ class ToolRouter: await self.mcp_client.__aexit__(exc_type, exc, tb) self._mcp_initialized = False + @observe(name="call_tool") async def call_tool( - self, - tool_name: str, - arguments: dict[str, Any], - session: Any = None, - tool_call_id: str | None = None, + self, tool_name: str, arguments: dict[str, Any], session: Any = None ) -> tuple[str, bool]: """ Call a tool and return (output_string, success_bool). @@ -263,11 +239,6 @@ class ToolRouter: # Check if handler accepts session argument sig = inspect.signature(tool.handler) if "session" in sig.parameters: - # Check if handler also accepts tool_call_id parameter - if "tool_call_id" in sig.parameters: - return await tool.handler( - arguments, session=session, tool_call_id=tool_call_id - ) return await tool.handler(arguments, session=session) return await tool.handler(arguments) @@ -290,17 +261,10 @@ class ToolRouter: # ============================================================================ -def create_builtin_tools(local_mode: bool = False) -> list[ToolSpec]: +def create_builtin_tools() -> list[ToolSpec]: """Create built-in tool specifications""" # in order of importance tools = [ - # Research sub-agent (delegates to read-only tools in independent context) - ToolSpec( - name=RESEARCH_TOOL_SPEC["name"], - description=RESEARCH_TOOL_SPEC["description"], - parameters=RESEARCH_TOOL_SPEC["parameters"], - handler=research_handler, - ), # Documentation search tools ToolSpec( name=EXPLORE_HF_DOCS_TOOL_SPEC["name"], @@ -314,19 +278,6 @@ def create_builtin_tools(local_mode: bool = False) -> list[ToolSpec]: parameters=HF_DOCS_FETCH_TOOL_SPEC["parameters"], handler=hf_docs_fetch_handler, ), - # Paper discovery and reading - ToolSpec( - name=HF_PAPERS_TOOL_SPEC["name"], - description=HF_PAPERS_TOOL_SPEC["description"], - parameters=HF_PAPERS_TOOL_SPEC["parameters"], - handler=hf_papers_handler, - ), - ToolSpec( - name=WEB_SEARCH_TOOL_SPEC["name"], - description=WEB_SEARCH_TOOL_SPEC["description"], - parameters=WEB_SEARCH_TOOL_SPEC["parameters"], - handler=web_search_handler, - ), # Dataset inspection tool (unified) ToolSpec( name=HF_INSPECT_DATASET_TOOL_SPEC["name"], @@ -341,12 +292,6 @@ def create_builtin_tools(local_mode: bool = False) -> list[ToolSpec]: parameters=PLAN_TOOL_SPEC["parameters"], handler=plan_tool_handler, ), - ToolSpec( - name=NOTIFY_TOOL_SPEC["name"], - description=NOTIFY_TOOL_SPEC["description"], - parameters=NOTIFY_TOOL_SPEC["parameters"], - handler=notify_handler, - ), ToolSpec( name=HF_JOBS_TOOL_SPEC["name"], description=HF_JOBS_TOOL_SPEC["description"], @@ -386,14 +331,6 @@ def create_builtin_tools(local_mode: bool = False) -> list[ToolSpec]: ), ] - # Sandbox or local tools (highest priority) - if local_mode: - from agent.tools.local_tools import get_local_tools - - tools = get_local_tools() + tools - else: - tools = get_sandbox_tools() + tools - tool_names = ", ".join([t.name for t in tools]) logger.info(f"Loaded {len(tools)} built-in tools: {tool_names}") diff --git a/agent/main.py b/agent/main.py index 25d0859b31018d37ee627219b6a0d38b696e576f..542da05694a0a4241531a3490ab518b30d3abc65 100644 --- a/agent/main.py +++ b/agent/main.py @@ -1,84 +1,35 @@ """ Interactive CLI chat with the agent - -Supports two modes: - Interactive: python -m agent.main - Headless: python -m agent.main "find me bird datasets" """ -import argparse import asyncio import json -import logging import os -import signal -import sys -import time from dataclasses import dataclass from pathlib import Path from typing import Any, Optional import litellm +from lmnr import Laminar, LaminarLiteLLMCallback from prompt_toolkit import PromptSession from agent.config import load_config -from agent.core.approval_policy import is_scheduled_operation from agent.core.agent_loop import submission_loop -from agent.core import model_switcher -from agent.core.hf_tokens import resolve_hf_token -from agent.core.local_models import is_local_model_id from agent.core.session import OpType from agent.core.tools import ToolRouter -from agent.messaging.gateway import NotificationGateway from agent.utils.reliability_checks import check_training_script_save_pattern from agent.utils.terminal_display import ( - get_console, - print_approval_header, - print_approval_item, - print_banner, - print_compacted, - print_error, - print_help, - print_init_done, - print_interrupted, - print_markdown, - print_plan, - print_tool_call, - print_tool_log, - print_tool_output, - print_turn_complete, - print_yolo_approve, + format_error, + format_header, + format_plan_display, + format_separator, + format_success, + format_tool_call, + format_tool_output, + format_turn_complete, ) litellm.drop_params = True -# Suppress the "Give Feedback / Get Help" banner LiteLLM prints to stderr -# on every error β€” users don't need it, and our friendly errors cover the case. -litellm.suppress_debug_info = True - -CLI_CONFIG_PATH = Path(__file__).parent.parent / "configs" / "cli_agent_config.json" -logger = logging.getLogger(__name__) - - -def _is_scheduled_hf_job_tool(tool_info: dict[str, Any]) -> bool: - if tool_info.get("tool") != "hf_jobs": - return False - arguments = tool_info.get("arguments") or {} - if isinstance(arguments, str): - try: - arguments = json.loads(arguments) - except json.JSONDecodeError: - return False - if not isinstance(arguments, dict): - return False - return is_scheduled_operation(arguments.get("operation")) - - -def _configure_runtime_logging() -> None: - """Keep third-party warning spam from punching through the interactive UI.""" - import logging - - logging.getLogger("LiteLLM").setLevel(logging.ERROR) - logging.getLogger("litellm").setLevel(logging.ERROR) def _safe_get_args(arguments: dict) -> dict: @@ -90,60 +41,14 @@ def _safe_get_args(arguments: dict) -> dict: return args if isinstance(args, dict) else {} -def _get_hf_user(token: str | None) -> str | None: - """Resolve the HF username for a token, if available.""" - if not token: - return None +lmnr_api_key = os.environ.get("LMNR_API_KEY") +if lmnr_api_key: try: - from huggingface_hub import HfApi - - return HfApi(token=token).whoami().get("name") - except Exception: - return None - - -async def _prompt_and_save_hf_token(prompt_session: PromptSession) -> str: - """Prompt user for HF token, validate it, save via huggingface_hub.login(). Loops until valid.""" - from prompt_toolkit.formatted_text import HTML - from huggingface_hub import HfApi, login - - print("\nA Hugging Face token is required.") - print("Get one at: https://huggingface.co/settings/tokens\n") - - while True: - try: - token = await prompt_session.prompt_async( - HTML("Paste your HF token: ") - ) - except (EOFError, KeyboardInterrupt): - print("\nToken is required to continue.") - continue - - token = token.strip() - if not token: - print("Token cannot be empty.") - continue - - # Validate token against the API - try: - api = HfApi(token=token) - user_info = api.whoami() - username = user_info.get("name", "unknown") - print(f"Token valid (user: {username})") - except Exception: - print("Invalid token. Please try again.") - continue - - # Save for future sessions - try: - login(token=token, add_to_git_credential=False) - print("Token saved to ~/.cache/huggingface/token") - except Exception as e: - print( - f"Warning: could not persist token ({e}), using for this session only." - ) - - return token + Laminar.initialize(project_api_key=lmnr_api_key) + litellm.callbacks = [LaminarLiteLLMCallback()] + print("Laminar initialized") + except Exception as e: + print(f"Failed to initialize Laminar: {e}") @dataclass @@ -162,132 +67,6 @@ class Submission: operation: Operation -def _create_rich_console(): - """Get the shared rich Console.""" - return get_console() - - -class _ThinkingShimmer: - """Animated shiny/shimmer thinking indicator β€” a bright gradient sweeps across the text.""" - - _BASE = (90, 90, 110) # dim base color - _HIGHLIGHT = (255, 200, 80) # bright shimmer highlight (warm gold) - _WIDTH = 5 # shimmer width in characters - _FPS = 24 - - def __init__(self, console): - self._console = console - self._task = None - self._running = False - - def start(self): - if self._running: - return - self._running = True - self._task = asyncio.ensure_future(self._animate()) - - def stop(self): - if not self._running: - return # no-op when never started (e.g. headless mode) - self._running = False - if self._task: - self._task.cancel() - self._task = None - # Clear the shimmer line - self._console.file.write("\r\033[K") - self._console.file.flush() - - def _render_frame(self, text: str, offset: float) -> str: - """Render one frame: a bright spot sweeps left-to-right across `text`.""" - out = [] - n = len(text) - for i, ch in enumerate(text): - # Distance from the shimmer center (wraps around) - dist = abs(i - offset) - wrap_dist = abs(i - offset + n + self._WIDTH) - dist = min(dist, wrap_dist, abs(i - offset - n - self._WIDTH)) - # Blend factor: 1.0 at center, 0.0 beyond _WIDTH - t = max(0.0, 1.0 - dist / self._WIDTH) - t = t * t * (3 - 2 * t) # smoothstep - r = int(self._BASE[0] + (self._HIGHLIGHT[0] - self._BASE[0]) * t) - g = int(self._BASE[1] + (self._HIGHLIGHT[1] - self._BASE[1]) * t) - b = int(self._BASE[2] + (self._HIGHLIGHT[2] - self._BASE[2]) * t) - out.append(f"\033[38;2;{r};{g};{b}m{ch}") - out.append("\033[0m") - return "".join(out) - - async def _animate(self): - text = "Thinking..." - n = len(text) - speed = 0.45 # characters per frame - pos = 0.0 - try: - while self._running: - frame = self._render_frame(text, pos) - self._console.file.write(f"\r {frame}") - self._console.file.flush() - pos = (pos + speed) % (n + self._WIDTH) - await asyncio.sleep(1.0 / self._FPS) - except asyncio.CancelledError: - pass - - -class _StreamBuffer: - """Accumulates streamed tokens, renders markdown block-by-block as complete - blocks appear. A "block" is everything up to a paragraph break (\\n\\n). - Unclosed code fences (odd count of ```) hold back flushing until closed so - a code block is always rendered as one unit.""" - - def __init__(self, console): - self._console = console - self._buffer = "" - - def add_chunk(self, text: str): - self._buffer += text - - def _pop_block(self) -> str | None: - """Extract the next complete block, or return None if nothing complete.""" - if self._buffer.count("```") % 2 == 1: - return None # inside an open code fence β€” wait for close - idx = self._buffer.find("\n\n") - if idx == -1: - return None - block = self._buffer[:idx] - self._buffer = self._buffer[idx + 2 :] - return block - - async def flush_ready( - self, - cancel_event: "asyncio.Event | None" = None, - instant: bool = False, - ): - """Render any complete blocks that have accumulated; leave the tail.""" - while True: - if cancel_event is not None and cancel_event.is_set(): - return - block = self._pop_block() - if block is None: - return - if block.strip(): - await print_markdown(block, cancel_event=cancel_event, instant=instant) - - async def finish( - self, - cancel_event: "asyncio.Event | None" = None, - instant: bool = False, - ): - """Flush complete blocks, then render whatever incomplete tail remains.""" - await self.flush_ready(cancel_event=cancel_event, instant=instant) - if self._buffer.strip(): - await print_markdown( - self._buffer, cancel_event=cancel_event, instant=instant - ) - self._buffer = "" - - def discard(self): - self._buffer = "" - - async def event_listener( event_queue: asyncio.Queue, submission_queue: asyncio.Queue, @@ -295,162 +74,67 @@ async def event_listener( ready_event: asyncio.Event, prompt_session: PromptSession, config=None, - session_holder=None, ) -> None: """Background task that listens for events and displays them""" - submission_id = [1000] - last_tool_name = [None] - console = _create_rich_console() - shimmer = _ThinkingShimmer(console) - stream_buf = _StreamBuffer(console) - - def _cancel_event(): - """Return the session's cancellation Event so print_markdown can abort - its typewriter loop mid-stream when Ctrl+C fires.""" - s = session_holder[0] if session_holder else None - return s._cancelled if s is not None else None + submission_id = [1000] # Use list to make it mutable in closure + last_tool_name = [None] # Track last tool called while True: try: event = await event_queue.get() + # Display event if event.event_type == "ready": - tool_count = event.data.get("tool_count", 0) if event.data else 0 - print_init_done(tool_count=tool_count) + print(format_success("\U0001f917 Agent ready")) ready_event.set() elif event.event_type == "assistant_message": - shimmer.stop() - content = event.data.get("content", "") if event.data else "" - if content: - await print_markdown(content, cancel_event=_cancel_event()) - elif event.event_type == "assistant_chunk": content = event.data.get("content", "") if event.data else "" if content: - stream_buf.add_chunk(content) - # Flush any complete markdown blocks progressively so the - # user sees paragraphs appear as they're produced, not just - # at the end of the whole response. - shimmer.stop() - await stream_buf.flush_ready(cancel_event=_cancel_event()) - elif event.event_type == "assistant_stream_end": - shimmer.stop() - await stream_buf.finish(cancel_event=_cancel_event()) + print(f"\nAssistant: {content}") elif event.event_type == "tool_call": - shimmer.stop() - stream_buf.discard() tool_name = event.data.get("tool", "") if event.data else "" arguments = event.data.get("arguments", {}) if event.data else {} if tool_name: - last_tool_name[0] = tool_name - # Skip printing research tool_call β€” the tool_log handler shows it - if tool_name != "research": - args_str = json.dumps(arguments)[:80] - print_tool_call(tool_name, args_str) + last_tool_name[0] = tool_name # Store for tool_output event + args_str = json.dumps(arguments)[:100] + "..." + print(format_tool_call(tool_name, args_str)) elif event.event_type == "tool_output": output = event.data.get("output", "") if event.data else "" success = event.data.get("success", False) if event.data else False - # Only show output for plan_tool β€” everything else is noise - if last_tool_name[0] == "plan_tool" and output: - print_tool_output(output, success, truncate=False) - shimmer.start() + if output: + # Don't truncate plan_tool output, truncate everything else + should_truncate = last_tool_name[0] != "plan_tool" + print(format_tool_output(output, success, truncate=should_truncate)) elif event.event_type == "turn_complete": - shimmer.stop() - stream_buf.discard() - print_turn_complete() - print_plan() - session = session_holder[0] if session_holder else None - if session is not None: - await session.send_deferred_turn_complete_notification(event) - turn_complete_event.set() - elif event.event_type == "interrupted": - shimmer.stop() - stream_buf.discard() - print_interrupted() - turn_complete_event.set() - elif event.event_type == "undo_complete": - console.print("[dim]Undone.[/dim]") - turn_complete_event.set() - elif event.event_type == "resume_complete": - data = event.data or {} - path = data.get("path", "?") - count = data.get("restored_count", 0) - dropped = int(data.get("dropped_count", 0) or 0) - model = data.get("model_name", "?") - invalid_model = data.get("invalid_saved_model") - forked = bool(data.get("forked", False)) - redacted = bool(data.get("had_redacted_content", False)) - verb = "Forked from" if forked else "Resumed" - console.print( - f"[green]{verb}[/green] {path} " - f"([cyan]{count}[/cyan] messages, " - f"model [cyan]{model}[/cyan])." - ) - if dropped: - console.print( - f"[yellow]Warning:[/yellow] dropped {dropped} " - "malformed message(s) while restoring β€” surrounding " - "tool-call alignment may be off." - ) - if invalid_model: - console.print( - f"[yellow]Warning:[/yellow] saved model id " - f"[cyan]{invalid_model}[/cyan] failed validation; " - f"kept current model [cyan]{model}[/cyan]." - ) - if forked: - console.print( - "[dim]Saved log belongs to a different user β€” kept " - "current session id; future saves go to a fresh file.[/dim]" - ) - if redacted: - console.print( - "[yellow]Note:[/yellow] tokens/secrets in restored " - "messages were scrubbed at save time. Your live tokens " - "are used for this session; [REDACTED_*] markers in " - "past messages are not re-injected." - ) + print(format_turn_complete()) + # Display plan after turn complete + plan_display = format_plan_display() + if plan_display: + print(plan_display) turn_complete_event.set() - elif event.event_type == "tool_log": - tool = event.data.get("tool", "") if event.data else "" - log = event.data.get("log", "") if event.data else "" - if log: - agent_id = event.data.get("agent_id", "") if event.data else "" - label = event.data.get("label", "") if event.data else "" - print_tool_log(tool, log, agent_id=agent_id, label=label) - elif event.event_type == "tool_state_change": - pass # visual noise β€” approval flow handles this elif event.event_type == "error": - shimmer.stop() - stream_buf.discard() error = ( event.data.get("error", "Unknown error") if event.data else "Unknown error" ) - print_error(error) + print(format_error(error)) turn_complete_event.set() elif event.event_type == "shutdown": - shimmer.stop() - stream_buf.discard() break elif event.event_type == "processing": - shimmer.start() + pass # print("Processing...", flush=True) elif event.event_type == "compacted": old_tokens = event.data.get("old_tokens", 0) if event.data else 0 new_tokens = event.data.get("new_tokens", 0) if event.data else 0 - print_compacted(old_tokens, new_tokens) + print(f"Compacted context: {old_tokens} β†’ {new_tokens} tokens") elif event.event_type == "approval_required": # Handle batch approval format tools_data = event.data.get("tools", []) if event.data else [] count = event.data.get("count", 0) if event.data else 0 - # If yolo mode is active, auto-approve everything except - # scheduled HF jobs, whose recurring cost stays manual. - if ( - config - and config.yolo_mode - and not any(_is_scheduled_hf_job_tool(t) for t in tools_data) - ): + # If yolo mode is active, auto-approve everything + if config and config.yolo_mode: approvals = [ { "tool_call_id": t.get("tool_call_id", ""), @@ -459,7 +143,7 @@ async def event_listener( } for t in tools_data ] - print_yolo_approve(count) + print(f"\n⚑ YOLO MODE: Auto-approving {count} item(s)") submission_id[0] += 1 approval_submission = Submission( id=f"approval_{submission_id[0]}", @@ -471,7 +155,14 @@ async def event_listener( await submission_queue.put(approval_submission) continue - print_approval_header(count) + print("\n" + format_separator()) + print( + format_header( + f"APPROVAL REQUIRED ({count} item{'s' if count != 1 else ''})" + ) + ) + print(format_separator()) + approvals = [] # Ask for approval for each tool @@ -490,7 +181,9 @@ async def event_listener( operation = arguments.get("operation", "") - print_approval_item(i, count, tool_name, operation) + print(f"\n[Item {i}/{count}]") + print(f"Tool: {tool_name}") + print(f"Operation: {operation}") # Handle different tool types if tool_name == "hf_jobs": @@ -683,35 +376,10 @@ async def event_listener( if gated is not None: print(f"Gated: {gated}") - # Get user decision for this item. Ctrl+C / EOF here is - # treated as "reject remaining" (matches Codex's modal - # priority and Forgecode's approval-cancel path). Without - # this, KeyboardInterrupt kills the event listener and - # the main loop deadlocks waiting for turn_complete. - try: - response = await prompt_session.prompt_async( - f"Approve item {i}? (y=yes, yolo=approve all, n=no, or provide feedback): " - ) - except (KeyboardInterrupt, EOFError): - get_console().print( - "[dim]Approval cancelled β€” rejecting remaining items[/dim]" - ) - approvals.append( - { - "tool_call_id": tool_call_id, - "approved": False, - "feedback": "User cancelled approval", - } - ) - for remaining in tools_data[i:]: - approvals.append( - { - "tool_call_id": remaining.get("tool_call_id", ""), - "approved": False, - "feedback": None, - } - ) - break + # Get user decision for this item + response = await prompt_session.prompt_async( + f"Approve item {i}? (y=yes, yolo=approve all, n=no, or provide feedback): " + ) response = response.strip().lower() @@ -719,7 +387,7 @@ async def event_listener( if response == "yolo": config.yolo_mode = True print( - "YOLO MODE ACTIVATED - Auto-approving all future tool calls" + "⚑ YOLO MODE ACTIVATED - Auto-approving all future tool calls" ) # Auto-approve this item and all remaining approvals.append( @@ -760,7 +428,7 @@ async def event_listener( ), ) await submission_queue.put(approval_submission) - console.print() # spacing after approval + print(format_separator() + "\n") # Silently ignore other events except asyncio.CancelledError: @@ -776,334 +444,28 @@ async def get_user_input(prompt_session: PromptSession) -> str: return await prompt_session.prompt_async(HTML("\n> ")) -# ── Slash command helpers ──────────────────────────────────────────────── - -# Slash commands are defined in terminal_display - - -async def _resume_picker( - arg: str, - prompt_session: PromptSession | None, -) -> Path | None: - """Resolve a session log path via ``arg`` or interactive selection. - - Returns ``None`` if the user cancels, no logs exist, or the argument - matches nothing β€” already prints the explanation in those cases. - """ - from agent.core.session_resume import ( - format_session_log_entry, - list_session_logs, - resolve_session_log_arg, - ) - from agent.core.session import DEFAULT_SESSION_LOG_DIR - - console = get_console() - directory = DEFAULT_SESSION_LOG_DIR - entries = list_session_logs(directory) - if not entries: - console.print(f"[yellow]No session logs found in ./{directory}.[/yellow]") - return None - - if arg: - selected = resolve_session_log_arg(arg, entries, directory) - if selected is None: - console.print(f"[bold red]No matching session log:[/bold red] {arg}") - return selected - - console.print() - console.print("[bold]Saved sessions[/bold]") - for index, entry in enumerate(entries, start=1): - console.print(format_session_log_entry(index, entry)) - console.print() - - if prompt_session is None: - console.print("[yellow]Cannot prompt for a selection here.[/yellow]") - return None - - try: - choice = await prompt_session.prompt_async( - "Select session number (blank to cancel): " - ) - except (EOFError, KeyboardInterrupt): - console.print("[dim]Resume cancelled.[/dim]") - return None - choice = choice.strip() - if not choice: - console.print("[dim]Resume cancelled.[/dim]") - return None - selected = resolve_session_log_arg(choice, entries, directory) - if selected is None: - console.print(f"[bold red]Invalid selection:[/bold red] {choice}") - return selected - - -async def _handle_slash_command( - cmd: str, - config, - session_holder: list, - submission_queue: asyncio.Queue, - submission_id: list[int], - prompt_session: PromptSession | None = None, -) -> Submission | None: - """ - Handle a slash command. Returns a Submission to enqueue, or None if - the command was handled locally (caller should set turn_complete_event). - - Async because ``/model`` fires a probe ping to validate the model+effort - combo before committing the switch. - """ - parts = cmd.strip().split(None, 1) - command = parts[0].lower() - arg = parts[1].strip() if len(parts) > 1 else "" - - if command == "/help": - print_help() - return None - - if command == "/undo": - submission_id[0] += 1 - return Submission( - id=f"sub_{submission_id[0]}", - operation=Operation(op_type=OpType.UNDO), - ) - - if command == "/compact": - submission_id[0] += 1 - return Submission( - id=f"sub_{submission_id[0]}", - operation=Operation(op_type=OpType.COMPACT), - ) - - if command == "/resume": - session = session_holder[0] if session_holder else None - if session is None: - get_console().print( - "[bold red]No active session to restore into.[/bold red]" - ) - return None - selected_path = await _resume_picker(arg, prompt_session) - if selected_path is None: - return None - submission_id[0] += 1 - return Submission( - id=f"sub_{submission_id[0]}", - operation=Operation( - op_type=OpType.RESUME, data={"path": str(selected_path)} - ), - ) - - if command == "/model": - console = get_console() - if not arg: - model_switcher.print_model_listing(config, console) - return None - if not model_switcher.is_valid_model_id(arg): - model_switcher.print_invalid_id(arg, console) - return None - normalized = arg.removeprefix("huggingface/") - session = session_holder[0] if session_holder else None - await model_switcher.probe_and_switch_model( - normalized, - config, - session, - console, - resolve_hf_token(), - ) - return None - - if command == "/yolo": - config.yolo_mode = not config.yolo_mode - state = "ON" if config.yolo_mode else "OFF" - print(f"YOLO mode: {state}") - return None - - if command == "/effort": - console = get_console() - valid = {"minimal", "low", "medium", "high", "xhigh", "max", "off"} - session = session_holder[0] if session_holder else None - if not arg: - current = config.reasoning_effort or "off" - console.print(f"[bold]Reasoning effort preference:[/bold] {current}") - if session and session.model_effective_effort: - console.print("[dim]Probed per model:[/dim]") - for m, eff in session.model_effective_effort.items(): - console.print(f" [dim]{m}: {eff or 'off'}[/dim]") - console.print( - "[dim]Set with '/effort minimal|low|medium|high|xhigh|max|off'. " - "'max' is Anthropic-only; 'xhigh' is also supported by current " - "OpenAI GPT-5 models. The cascade falls back to whatever the " - "model actually accepts.[/dim]" - ) - return None - level = arg.lower() - if level not in valid: - console.print(f"[bold red]Invalid level:[/bold red] {arg}") - console.print(f"[dim]Expected one of: {', '.join(sorted(valid))}[/dim]") - return None - config.reasoning_effort = None if level == "off" else level - # Drop the per-model probe cache β€” the new preference may resolve - # differently. Next ``/model`` (or the retry safety net) reprobes. - if session is not None: - session.model_effective_effort.clear() - console.print(f"[green]Reasoning effort: {level}[/green]") - if session is not None: - console.print( - "[dim]run /model to re-probe, or send a message β€” " - "the agent adjusts automatically if the new level isn't supported.[/dim]" - ) - return None - - if command == "/status": - session = session_holder[0] if session_holder else None - print(f"Model: {config.model_name}") - print(f"Reasoning effort: {config.reasoning_effort or 'off'}") - if session: - print(f"Turns: {session.turn_count}") - print(f"Context items: {len(session.context_manager.items)}") - return None - - if command == "/share-traces": - session = session_holder[0] if session_holder else None - await _handle_share_traces_command(arg, config, session) - return None - - print(f"Unknown command: {command}. Type /help for available commands.") - return None - - -async def _handle_share_traces_command(arg: str, config, session) -> None: - """Show or flip visibility of the user's personal trace dataset. - - Uses the user's own HF_TOKEN (write-scoped to their namespace). Only - operates on the personal trace repo configured via - ``personal_trace_repo_template`` β€” never touches the shared org dataset. - """ - from huggingface_hub import HfApi - from huggingface_hub.utils import HfHubHTTPError - - console = get_console() - if session is None: - console.print("[bold red]No active session.[/bold red]") - return - - repo_id = session._personal_trace_repo_id() if session is not None else None - if not repo_id: - if not getattr(config, "share_traces", False): - console.print( - "[yellow]share_traces is disabled in config. " - "Set it to true to publish per-session traces to your HF dataset." - "[/yellow]" - ) - return - if not session.user_id: - console.print( - "[yellow]No HF username resolved \u2014 cannot pick a personal " - "trace repo. Set HF_TOKEN to a token tied to your account.[/yellow]" - ) - return - console.print( - "[yellow]personal_trace_repo_template is unset \u2014 nothing to do.[/yellow]" - ) - return - - token = session.hf_token or resolve_hf_token() - if not token: - console.print( - "[bold red]No HF_TOKEN available.[/bold red] Cannot read or change " - "dataset visibility." - ) - return - - api = HfApi(token=token) - url = f"https://huggingface.co/datasets/{repo_id}" - target = arg.strip().lower() - - if not target: - try: - info = await asyncio.to_thread( - api.repo_info, repo_id=repo_id, repo_type="dataset" - ) - visibility = "private" if getattr(info, "private", False) else "public" - console.print(f"[bold]Trace dataset:[/bold] {url}") - console.print(f"[bold]Visibility:[/bold] {visibility}") - console.print( - "[dim]Use '/share-traces public' to publish, " - "'/share-traces private' to lock it back down.[/dim]" - ) - except HfHubHTTPError as e: - if getattr(e.response, "status_code", None) == 404: - console.print( - f"[dim]Dataset {repo_id} doesn't exist yet \u2014 it'll be " - "created (private) on the next session save.[/dim]" - ) - else: - console.print(f"[bold red]Hub error:[/bold red] {e}") - except Exception as e: - console.print(f"[bold red]Could not fetch dataset info:[/bold red] {e}") - return - - if target not in {"public", "private"}: - console.print( - f"[bold red]Unknown argument:[/bold red] {target}. " - "Expected 'public' or 'private'." - ) - return - - private = target == "private" - try: - # Idempotent β€” create if missing so first-flip works even before any - # session has been saved yet. - await asyncio.to_thread( - api.create_repo, - repo_id=repo_id, - repo_type="dataset", - private=private, - token=token, - exist_ok=True, - ) - await asyncio.to_thread( - api.update_repo_settings, - repo_id=repo_id, - repo_type="dataset", - private=private, - token=token, - ) - except Exception as e: - console.print(f"[bold red]Failed to update visibility:[/bold red] {e}") - return - - label = "PUBLIC" if not private else "private" - console.print(f"[green]Dataset is now {label}.[/green] {url}") - - -async def main(model: str | None = None): +async def main(): """Interactive chat with the agent""" + from agent.utils.terminal_display import Colors # Clear screen os.system("clear" if os.name != "nt" else "cls") - # Create prompt session for input (needed early for token prompt) - prompt_session = PromptSession() - - config = load_config(CLI_CONFIG_PATH, include_user_defaults=True) - if model: - config.model_name = model - - # HF token β€” required for Hub-backed models/tools, but not for local LLMs. - hf_token = resolve_hf_token() - if not hf_token and not is_local_model_id(config.model_name): - hf_token = await _prompt_and_save_hf_token(prompt_session) - - # Resolve username for banner - hf_user = _get_hf_user(hf_token) - - print_banner(model=config.model_name, hf_user=hf_user) - - # Pre-warm the HF router catalog in the background so /model switches - # don't block on a network fetch. - from agent.core import hf_router_catalog + banner = r""" + _ _ _ _____ _ _ + | | | |_ _ __ _ __ _(_)_ __ __ _ | ___|_ _ ___ ___ / \ __ _ ___ _ __ | |_ + | |_| | | | |/ _` |/ _` | | '_ \ / _` | | |_ / _` |/ __/ _ \ / _ \ / _` |/ _ \ '_ \| __| + | _ | |_| | (_| | (_| | | | | | (_| | | _| (_| | (_| __/ / ___ \ (_| | __/ | | | |_ + |_| |_|\__,_|\__, |\__, |_|_| |_|\__, | |_| \__,_|\___\___| /_/ \_\__, |\___|_| |_|\__| + |___/ |___/ |___/ |___/ + """ - asyncio.create_task(asyncio.to_thread(hf_router_catalog.prewarm)) + print(format_separator()) + print(f"{Colors.YELLOW} {banner}{Colors.RESET}") + print("Type your messages below. Type 'exit', 'quit', or '/quit' to end.\n") + print(format_separator()) + # Wait for agent to initialize + print("Initializing agent...") # Create queues for communication submission_queue = asyncio.Queue() @@ -1114,13 +476,16 @@ async def main(model: str | None = None): turn_complete_event.set() ready_event = asyncio.Event() - notification_gateway = NotificationGateway(config.messaging) - await notification_gateway.start() - # Create tool router with local mode - tool_router = ToolRouter(config.mcpServers, hf_token=hf_token, local_mode=True) + # Start agent loop in background + config_path = Path(__file__).parent.parent / "configs" / "main_agent_config.json" + config = load_config(config_path) + + # Create tool router + print(f"Loading MCP servers: {', '.join(config.mcpServers.keys())}") + tool_router = ToolRouter(config.mcpServers) - # Session holder for interrupt/model/status access - session_holder = [None] + # Create prompt session for input + prompt_session = PromptSession() agent_task = asyncio.create_task( submission_loop( @@ -1128,14 +493,6 @@ async def main(model: str | None = None): event_queue, config=config, tool_router=tool_router, - session_holder=session_holder, - hf_token=hf_token, - user_id=hf_user, - local_mode=True, - stream=True, - notification_gateway=notification_gateway, - notification_destinations=config.messaging.default_auto_destinations(), - defer_turn_complete_notification=True, ) ) @@ -1148,93 +505,24 @@ async def main(model: str | None = None): ready_event, prompt_session, config, - session_holder=session_holder, ) ) await ready_event.wait() - submission_id = [0] - # Mirrors codex-rs/tui/src/bottom_pane/mod.rs:137 - # (`QUIT_SHORTCUT_TIMEOUT = Duration::from_secs(1)`). Two Ctrl+C presses - # within this window quit; a single press cancels the in-flight turn. - CTRL_C_QUIT_WINDOW = 1.0 - # Hint string matches codex-rs/tui/src/bottom_pane/footer.rs:746 - # (`" again to quit"` prefixed with the key binding, rendered dim). - CTRL_C_HINT = "[dim]ctrl + c again to quit[/dim]" - interrupt_state = {"last": 0.0, "exit": False} - - loop = asyncio.get_running_loop() - - def _on_sigint() -> None: - """SIGINT handler β€” fires while the agent is generating (terminal is - in cooked mode between prompts). Mirrors Codex's `on_ctrl_c` in - codex-rs/tui/src/chatwidget.rs: first press cancels active work and - arms the quit hint; second press within the window quits.""" - now = time.monotonic() - session = session_holder[0] - - if now - interrupt_state["last"] < CTRL_C_QUIT_WINDOW: - interrupt_state["exit"] = True - if session: - session.cancel() - # Wake the main loop out of turn_complete_event.wait() - turn_complete_event.set() - return - - interrupt_state["last"] = now - if session and not session.is_cancelled: - session.cancel() - get_console().print(f"\n{CTRL_C_HINT}") - - def _install_sigint() -> bool: - try: - loop.add_signal_handler(signal.SIGINT, _on_sigint) - return True - except (NotImplementedError, RuntimeError): - return False # Windows or non-main thread - - # prompt_toolkit's prompt_async installs its own SIGINT handler and, on - # exit, calls loop.remove_signal_handler(SIGINT) β€” which wipes ours too. - # So we re-arm at the top of every loop iteration, right before the busy - # wait. Without this, Ctrl+C during agent streaming after the first turn - # falls through to the default handler and the terminal just echoes ^C. - sigint_available = _install_sigint() + submission_id = 0 try: while True: - if sigint_available: - _install_sigint() - - try: - await turn_complete_event.wait() - except asyncio.CancelledError: - break + # Wait for previous turn to complete + await turn_complete_event.wait() turn_complete_event.clear() - if interrupt_state["exit"]: - break - - # Get user input. prompt_toolkit puts the terminal in raw mode and - # installs its own SIGINT handling; ^C arrives as \x03 and surfaces - # as KeyboardInterrupt here. On return, prompt_toolkit removes the - # loop's SIGINT handler β€” we re-arm at the top of the next iter. + # Get user input try: user_input = await get_user_input(prompt_session) except EOFError: break - except KeyboardInterrupt: - now = time.monotonic() - if now - interrupt_state["last"] < CTRL_C_QUIT_WINDOW: - break - interrupt_state["last"] = now - get_console().print(CTRL_C_HINT) - turn_complete_event.set() - continue - - # A successful read ends the double-press window β€” an unrelated - # Ctrl+C during the next turn should start a fresh arming. - interrupt_state["last"] = 0.0 # Check for exit commands if user_input.strip().lower() in ["exit", "quit", "/quit", "/exit"]: @@ -1245,337 +533,35 @@ async def main(model: str | None = None): turn_complete_event.set() continue - # Handle slash commands - if user_input.strip().startswith("/"): - sub = await _handle_slash_command( - user_input.strip(), - config, - session_holder, - submission_queue, - submission_id, - prompt_session, - ) - if sub is None: - # Command handled locally, loop back for input - turn_complete_event.set() - continue - else: - await submission_queue.put(sub) - continue - # Submit to agent - submission_id[0] += 1 + submission_id += 1 submission = Submission( - id=f"sub_{submission_id[0]}", + id=f"sub_{submission_id}", operation=Operation( op_type=OpType.USER_INPUT, data={"text": user_input} ), ) + # print(f"Main submitting: {submission.operation.op_type}") await submission_queue.put(submission) except KeyboardInterrupt: - pass - finally: - if sigint_available: - try: - loop.remove_signal_handler(signal.SIGINT) - except (NotImplementedError, RuntimeError): - pass + print("\n\nInterrupted by user") # Shutdown + print("\nπŸ›‘ Shutting down agent...") shutdown_submission = Submission( id="sub_shutdown", operation=Operation(op_type=OpType.SHUTDOWN) ) await submission_queue.put(shutdown_submission) - # Wait for agent to finish (the listener must keep draining events - # or the agent will block on event_queue.put) - try: - await asyncio.wait_for(agent_task, timeout=10.0) - except asyncio.TimeoutError: - agent_task.cancel() - # Agent didn't shut down cleanly β€” close MCP explicitly - await tool_router.__aexit__(None, None, None) - finally: - await notification_gateway.close() - - # Now safe to cancel the listener (agent is done emitting events) + await asyncio.wait_for(agent_task, timeout=5.0) listener_task.cancel() - get_console().print("\n[dim]Bye.[/dim]\n") - - -async def headless_main( - prompt: str, - model: str | None = None, - max_iterations: int | None = None, - stream: bool = True, -) -> None: - """Run a single prompt headlessly and exit.""" - import logging - - logging.basicConfig(level=logging.WARNING) - _configure_runtime_logging() - - config = load_config(CLI_CONFIG_PATH, include_user_defaults=True) - config.yolo_mode = True # Auto-approve everything in headless mode - - if model: - config.model_name = model - - hf_token = resolve_hf_token() - if not hf_token and not is_local_model_id(config.model_name): - print( - "ERROR: No HF token found. Set HF_TOKEN or run `huggingface-cli login`.", - file=sys.stderr, - ) - sys.exit(1) - - if hf_token: - print("HF token loaded", file=sys.stderr) + print("✨ Goodbye!\n") - notification_gateway = NotificationGateway(config.messaging) - await notification_gateway.start() - hf_user = _get_hf_user(hf_token) - - if max_iterations is not None: - config.max_iterations = max_iterations - - print(f"Model: {config.model_name}", file=sys.stderr) - print(f"Max iterations: {config.max_iterations}", file=sys.stderr) - print(f"Prompt: {prompt}", file=sys.stderr) - print("---", file=sys.stderr) - - submission_queue: asyncio.Queue = asyncio.Queue() - event_queue: asyncio.Queue = asyncio.Queue() - - tool_router = ToolRouter(config.mcpServers, hf_token=hf_token, local_mode=True) - session_holder: list = [None] - - agent_task = asyncio.create_task( - submission_loop( - submission_queue, - event_queue, - config=config, - tool_router=tool_router, - session_holder=session_holder, - hf_token=hf_token, - user_id=hf_user, - local_mode=True, - stream=stream, - notification_gateway=notification_gateway, - notification_destinations=config.messaging.default_auto_destinations(), - defer_turn_complete_notification=True, - ) - ) - - # Wait for ready - while True: - event = await event_queue.get() - if event.event_type == "ready": - break - - # Submit the prompt - submission = Submission( - id="sub_1", - operation=Operation(op_type=OpType.USER_INPUT, data={"text": prompt}), - ) - await submission_queue.put(submission) - - # Process events until turn completes. Headless mode is for scripts / - # log capture: no shimmer animation, no typewriter, no live-redrawing - # research overlay. Output is plain, append-only text. - console = _create_rich_console() - stream_buf = _StreamBuffer(console) - _hl_last_tool = [None] - _hl_sub_id = [1] - # Research sub-agent tool calls are buffered per agent_id and dumped as - # a static block once each sub-agent finishes, instead of streaming via - # the live redrawing SubAgentDisplayManager (which is TTY-only). - _hl_research_buffers: dict[str, dict] = {} - - while True: - event = await event_queue.get() - - if event.event_type == "assistant_chunk": - content = event.data.get("content", "") if event.data else "" - if content: - stream_buf.add_chunk(content) - await stream_buf.flush_ready(instant=True) - elif event.event_type == "assistant_stream_end": - await stream_buf.finish(instant=True) - elif event.event_type == "assistant_message": - content = event.data.get("content", "") if event.data else "" - if content: - await print_markdown(content, instant=True) - elif event.event_type == "tool_call": - stream_buf.discard() - tool_name = event.data.get("tool", "") if event.data else "" - arguments = event.data.get("arguments", {}) if event.data else {} - if tool_name: - _hl_last_tool[0] = tool_name - if tool_name != "research": - args_str = json.dumps(arguments)[:80] - print_tool_call(tool_name, args_str) - elif event.event_type == "tool_output": - output = event.data.get("output", "") if event.data else "" - success = event.data.get("success", False) if event.data else False - if _hl_last_tool[0] == "plan_tool" and output: - print_tool_output(output, success, truncate=False) - elif event.event_type == "tool_log": - tool = event.data.get("tool", "") if event.data else "" - log = event.data.get("log", "") if event.data else "" - if not log: - pass - elif tool == "research": - # Headless mode: buffer research sub-agent activity per-agent, - # then dump each as a static block on completion. The live - # SubAgentDisplayManager uses terminal cursor tricks that are - # unfit for non-TTY output, but parallel agents still need - # distinct output so we key buffers by agent_id. - agent_id = event.data.get("agent_id", "") if event.data else "" - label = event.data.get("label", "") if event.data else "" - aid = agent_id or "research" - if log == "Starting research sub-agent...": - _hl_research_buffers[aid] = { - "label": label or "research", - "calls": [], - } - elif log == "Research complete.": - buf = _hl_research_buffers.pop(aid, None) - if buf is not None: - f = get_console().file - f.write(f" \033[38;2;255;200;80mβ–Έ {buf['label']}\033[0m\n") - for call in buf["calls"]: - f.write(f" \033[2m{call}\033[0m\n") - f.flush() - elif log.startswith("tokens:") or log.startswith("tools:"): - pass # stats updates β€” only useful for the live display - elif aid in _hl_research_buffers: - _hl_research_buffers[aid]["calls"].append(log) - else: - # Orphan event (Start was missed) β€” fall back to raw print - print_tool_log(tool, log, agent_id=agent_id, label=label) - else: - print_tool_log(tool, log) - elif event.event_type == "approval_required": - # Auto-approve in headless mode, except scheduled HF jobs. Those - # are rejected because their recurring cost needs manual approval. - tools_data = event.data.get("tools", []) if event.data else [] - approvals = [ - { - "tool_call_id": t.get("tool_call_id", ""), - "approved": not _is_scheduled_hf_job_tool(t), - "feedback": ( - "Scheduled HF jobs require manual approval." - if _is_scheduled_hf_job_tool(t) - else None - ), - } - for t in tools_data - ] - _hl_sub_id[0] += 1 - await submission_queue.put( - Submission( - id=f"hl_approval_{_hl_sub_id[0]}", - operation=Operation( - op_type=OpType.EXEC_APPROVAL, - data={"approvals": approvals}, - ), - ) - ) - elif event.event_type == "compacted": - old_tokens = event.data.get("old_tokens", 0) if event.data else 0 - new_tokens = event.data.get("new_tokens", 0) if event.data else 0 - print_compacted(old_tokens, new_tokens) - elif event.event_type == "error": - stream_buf.discard() - error = ( - event.data.get("error", "Unknown error") - if event.data - else "Unknown error" - ) - print_error(error) - break - elif event.event_type in ("turn_complete", "interrupted"): - stream_buf.discard() - history_size = event.data.get("history_size", "?") if event.data else "?" - print( - f"\n--- Agent {event.event_type} (history_size={history_size}) ---", - file=sys.stderr, - ) - if event.event_type == "turn_complete": - session = session_holder[0] if session_holder else None - if session is not None: - await session.send_deferred_turn_complete_notification(event) - break - - # Shutdown - shutdown_submission = Submission( - id="sub_shutdown", operation=Operation(op_type=OpType.SHUTDOWN) - ) - await submission_queue.put(shutdown_submission) - - try: - await asyncio.wait_for(agent_task, timeout=10.0) - except asyncio.TimeoutError: - agent_task.cancel() - await tool_router.__aexit__(None, None, None) - finally: - await notification_gateway.close() - - -def cli(): - """Entry point for the ml-intern CLI command.""" - import logging as _logging - import warnings - - # Suppress aiohttp "Unclosed client session" noise during event loop teardown - _logging.getLogger("asyncio").setLevel(_logging.CRITICAL) - _configure_runtime_logging() - # Suppress litellm pydantic deprecation warnings - warnings.filterwarnings("ignore", category=DeprecationWarning, module="litellm") - # Suppress whoosh invalid escape sequence warnings (third-party, unfixed upstream) - warnings.filterwarnings("ignore", category=SyntaxWarning, module="whoosh") - - parser = argparse.ArgumentParser(description="Hugging Face Agent CLI") - parser.add_argument( - "prompt", nargs="?", default=None, help="Run headlessly with this prompt" - ) - parser.add_argument( - "--model", "-m", default=None, help="Model to use (default: from config)" - ) - parser.add_argument( - "--max-iterations", - type=int, - default=None, - help="Max LLM requests per turn (default: 50, use -1 for unlimited)", - ) - parser.add_argument( - "--no-stream", - action="store_true", - help="Disable token streaming (use non-streaming LLM calls)", - ) - args = parser.parse_args() +if __name__ == "__main__": try: - if args.prompt: - max_iter = args.max_iterations - if max_iter is not None and max_iter < 0: - max_iter = 10_000 # effectively unlimited - asyncio.run( - headless_main( - args.prompt, - model=args.model, - max_iterations=max_iter, - stream=not args.no_stream, - ) - ) - else: - asyncio.run(main(model=args.model)) + asyncio.run(main()) except KeyboardInterrupt: - print("\n\nGoodbye!") - - -if __name__ == "__main__": - cli() + print("\n\n✨ Goodbye!") diff --git a/agent/messaging/__init__.py b/agent/messaging/__init__.py deleted file mode 100644 index c399d254e30fcbce555d6f51b810440b1171ec1a..0000000000000000000000000000000000000000 --- a/agent/messaging/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from agent.messaging.gateway import NotificationGateway -from agent.messaging.models import ( - MessagingConfig, - NotificationRequest, - NotificationResult, - SUPPORTED_AUTO_EVENT_TYPES, -) - -__all__ = [ - "MessagingConfig", - "NotificationGateway", - "NotificationRequest", - "NotificationResult", - "SUPPORTED_AUTO_EVENT_TYPES", -] diff --git a/agent/messaging/base.py b/agent/messaging/base.py deleted file mode 100644 index a74f9cf0d1cb2a77328124414b04de9ebbd6b582..0000000000000000000000000000000000000000 --- a/agent/messaging/base.py +++ /dev/null @@ -1,31 +0,0 @@ -from abc import ABC, abstractmethod - -import httpx - -from agent.messaging.models import ( - DestinationConfig, - NotificationRequest, - NotificationResult, -) - - -class NotificationError(Exception): - """Delivery failed and should not be retried.""" - - -class RetryableNotificationError(NotificationError): - """Delivery failed transiently and can be retried.""" - - -class NotificationProvider(ABC): - provider_name: str - - @abstractmethod - async def send( - self, - client: httpx.AsyncClient, - destination_name: str, - destination: DestinationConfig, - request: NotificationRequest, - ) -> NotificationResult: - """Deliver a notification to one destination.""" diff --git a/agent/messaging/gateway.py b/agent/messaging/gateway.py deleted file mode 100644 index 1de9438f5c5c8ae2847ef1bf4a398d10e8903048..0000000000000000000000000000000000000000 --- a/agent/messaging/gateway.py +++ /dev/null @@ -1,172 +0,0 @@ -import asyncio -import logging -from collections.abc import Iterable - -import httpx - -from agent.messaging.base import ( - NotificationError, - NotificationProvider, - RetryableNotificationError, -) -from agent.messaging.models import ( - MessagingConfig, - NotificationRequest, - NotificationResult, -) -from agent.messaging.slack import SlackProvider - -logger = logging.getLogger(__name__) - -_RETRY_DELAYS = (1, 2, 4) - - -class NotificationGateway: - def __init__(self, config: MessagingConfig): - self.config = config - self._providers: dict[str, NotificationProvider] = { - "slack": SlackProvider(), - } - self._queue: asyncio.Queue[NotificationRequest] = asyncio.Queue() - self._worker_task: asyncio.Task | None = None - self._client: httpx.AsyncClient | None = None - - @property - def enabled(self) -> bool: - return self.config.enabled - - async def start(self) -> None: - if not self.enabled or self._worker_task is not None: - return - self._client = httpx.AsyncClient(timeout=10.0) - self._worker_task = asyncio.create_task( - self._worker(), name="notification-gateway" - ) - - async def flush(self) -> None: - if not self.enabled: - return - await self._queue.join() - - async def close(self) -> None: - if not self.enabled: - return - await self.flush() - if self._worker_task is not None: - self._worker_task.cancel() - try: - await self._worker_task - except asyncio.CancelledError: - pass - self._worker_task = None - if self._client is not None: - await self._client.aclose() - self._client = None - - async def send(self, request: NotificationRequest) -> NotificationResult: - if not self.enabled: - return NotificationResult( - destination=request.destination, - ok=False, - provider="disabled", - error="Messaging is disabled", - ) - - destination = self.config.get_destination(request.destination) - if destination is None: - return NotificationResult( - destination=request.destination, - ok=False, - provider="unknown", - error=f"Unknown destination '{request.destination}'", - ) - - provider = self._providers.get(destination.provider) - if provider is None: - return NotificationResult( - destination=request.destination, - ok=False, - provider=destination.provider, - error=f"No provider implementation for '{destination.provider}'", - ) - return await self._send_with_retries( - provider, request.destination, destination, request - ) - - async def send_many( - self, requests: Iterable[NotificationRequest] - ) -> list[NotificationResult]: - results: list[NotificationResult] = [] - for request in requests: - results.append(await self.send(request)) - return results - - async def enqueue(self, request: NotificationRequest) -> bool: - if not self.enabled or self._worker_task is None: - return False - await self._queue.put(request) - return True - - async def _worker(self) -> None: - while True: - request = await self._queue.get() - try: - result = await self.send(request) - if not result.ok: - logger.warning( - "Notification delivery failed for %s: %s", - request.destination, - result.error, - ) - except Exception: - logger.exception("Unexpected notification worker failure") - finally: - self._queue.task_done() - - async def _send_with_retries( - self, - provider: NotificationProvider, - destination_name: str, - destination, - request: NotificationRequest, - ) -> NotificationResult: - client = self._client or httpx.AsyncClient(timeout=10.0) - owns_client = self._client is None - try: - for attempt in range(len(_RETRY_DELAYS) + 1): - try: - return await provider.send( - client, destination_name, destination, request - ) - except RetryableNotificationError as exc: - if attempt >= len(_RETRY_DELAYS): - return NotificationResult( - destination=destination_name, - ok=False, - provider=provider.provider_name, - error=str(exc), - ) - delay = _RETRY_DELAYS[attempt] - logger.warning( - "Retrying notification to %s in %ss after transient error: %s", - destination_name, - delay, - exc, - ) - await asyncio.sleep(delay) - except NotificationError as exc: - return NotificationResult( - destination=destination_name, - ok=False, - provider=provider.provider_name, - error=str(exc), - ) - return NotificationResult( - destination=destination_name, - ok=False, - provider=provider.provider_name, - error="Notification delivery exhausted retries", - ) - finally: - if owns_client: - await client.aclose() diff --git a/agent/messaging/models.py b/agent/messaging/models.py deleted file mode 100644 index 16148a8179f5de3fa38b36ce76166a48e9f54a83..0000000000000000000000000000000000000000 --- a/agent/messaging/models.py +++ /dev/null @@ -1,117 +0,0 @@ -from typing import Annotated, Literal - -from pydantic import BaseModel, Field, field_validator, model_validator - -_DESTINATION_NAME_CHARS = set("abcdefghijklmnopqrstuvwxyz0123456789._-") -SUPPORTED_AUTO_EVENT_TYPES = {"approval_required", "error", "turn_complete"} - - -class SlackDestinationConfig(BaseModel): - provider: Literal["slack"] = "slack" - token: str - channel: str - allow_agent_tool: bool = False - allow_auto_events: bool = False - username: str | None = None - icon_emoji: str | None = None - - @field_validator("token", "channel") - @classmethod - def _require_non_empty(cls, value: str) -> str: - value = value.strip() - if not value: - raise ValueError("must not be empty") - return value - - -DestinationConfig = Annotated[SlackDestinationConfig, Field(discriminator="provider")] - - -class MessagingConfig(BaseModel): - enabled: bool = False - auto_event_types: list[str] = Field( - default_factory=lambda: ["approval_required", "error", "turn_complete"] - ) - destinations: dict[str, DestinationConfig] = Field(default_factory=dict) - - @field_validator("destinations") - @classmethod - def _validate_destination_names( - cls, destinations: dict[str, DestinationConfig] - ) -> dict[str, DestinationConfig]: - for name in destinations: - if not name or any(char not in _DESTINATION_NAME_CHARS for char in name): - raise ValueError( - "destination names must use lowercase letters, digits, '.', '_' or '-'" - ) - return destinations - - @field_validator("auto_event_types") - @classmethod - def _validate_auto_event_types(cls, event_types: list[str]) -> list[str]: - if not event_types: - return [] - normalized: list[str] = [] - seen: set[str] = set() - for event_type in event_types: - if event_type not in SUPPORTED_AUTO_EVENT_TYPES: - raise ValueError(f"unsupported auto event type '{event_type}'") - if event_type not in seen: - normalized.append(event_type) - seen.add(event_type) - return normalized - - @model_validator(mode="after") - def _require_destinations_when_enabled(self) -> "MessagingConfig": - if self.enabled and not self.destinations: - raise ValueError("messaging.enabled requires at least one destination") - return self - - def get_destination(self, name: str) -> DestinationConfig | None: - return self.destinations.get(name) - - def can_agent_tool_send(self, name: str) -> bool: - destination = self.get_destination(name) - return bool(destination and destination.allow_agent_tool) - - def can_auto_send(self, name: str) -> bool: - destination = self.get_destination(name) - return bool(destination and destination.allow_auto_events) - - def default_auto_destinations(self) -> list[str]: - if not self.enabled: - return [] - return [name for name in self.destinations if self.can_auto_send(name)] - - -class NotificationRequest(BaseModel): - destination: str - title: str | None = None - message: str - severity: Literal["info", "success", "warning", "error"] = "info" - metadata: dict[str, str] = Field(default_factory=dict) - event_type: str | None = None - - @field_validator("destination", "message") - @classmethod - def _require_text(cls, value: str) -> str: - value = value.strip() - if not value: - raise ValueError("must not be empty") - return value - - @field_validator("title") - @classmethod - def _normalize_title(cls, value: str | None) -> str | None: - if value is None: - return None - value = value.strip() - return value or None - - -class NotificationResult(BaseModel): - destination: str - ok: bool - provider: str - error: str | None = None - external_id: str | None = None diff --git a/agent/messaging/slack.py b/agent/messaging/slack.py deleted file mode 100644 index 3790e44af790db8579a9a8efb88a2a16283ec71d..0000000000000000000000000000000000000000 --- a/agent/messaging/slack.py +++ /dev/null @@ -1,184 +0,0 @@ -import json -import re - -import httpx - -from agent.messaging.base import ( - NotificationError, - NotificationProvider, - RetryableNotificationError, -) -from agent.messaging.models import ( - NotificationRequest, - NotificationResult, - SlackDestinationConfig, -) - -_SEVERITY_PREFIX = { - "info": "[INFO]", - "success": "[SUCCESS]", - "warning": "[WARNING]", - "error": "[ERROR]", -} - - -def _format_slack_mrkdwn(content: str) -> str: - """Convert common Markdown constructs to Slack's mrkdwn syntax.""" - if not content: - return content - - placeholders: dict[str, str] = {} - placeholder_index = 0 - - def placeholder(value: str) -> str: - nonlocal placeholder_index - key = f"\x00SLACK{placeholder_index}\x00" - placeholder_index += 1 - placeholders[key] = value - return key - - text = content - - # Protect code before any formatting conversion. Slack's mrkdwn ignores - # formatting inside backticks, so these regions should stay byte-for-byte. - text = re.sub( - r"(```(?:[^\n]*\n)?[\s\S]*?```)", - lambda match: placeholder(match.group(0)), - text, - ) - text = re.sub(r"(`[^`\n]+`)", lambda match: placeholder(match.group(0)), text) - - def convert_markdown_link(match: re.Match[str]) -> str: - label = match.group(1) - url = match.group(2).strip() - if url.startswith("<") and url.endswith(">"): - url = url[1:-1].strip() - return placeholder(f"<{url}|{label}>") - - text = re.sub( - r"\[([^\]]+)\]\(([^()]*(?:\([^()]*\)[^()]*)*)\)", - convert_markdown_link, - text, - ) - - # Preserve existing Slack entities and manual mrkdwn links before escaping. - text = re.sub( - r"(<(?:[@#!]|(?:https?|mailto|tel):)[^>\n]+>)", - lambda match: placeholder(match.group(1)), - text, - ) - text = re.sub( - r"^(>+\s)", - lambda match: placeholder(match.group(0)), - text, - flags=re.MULTILINE, - ) - - text = text.replace("&", "&").replace("<", "<").replace(">", ">") - text = text.replace("&", "&").replace("<", "<").replace(">", ">") - - def convert_header(match: re.Match[str]) -> str: - header = match.group(1).strip() - header = re.sub(r"\*\*(.+?)\*\*", r"\1", header) - return placeholder(f"*{header}*") - - text = re.sub(r"^#{1,6}\s+(.+)$", convert_header, text, flags=re.MULTILINE) - text = re.sub( - r"\*\*\*(.+?)\*\*\*", - lambda match: placeholder(f"*_{match.group(1)}_*"), - text, - ) - text = re.sub( - r"\*\*(.+?)\*\*", - lambda match: placeholder(f"*{match.group(1)}*"), - text, - ) - text = re.sub( - r"(? str: - lines: list[str] = [] - prefix = _SEVERITY_PREFIX[request.severity] - if request.title: - lines.append(f"{prefix} {request.title}") - else: - lines.append(prefix) - lines.append(request.message) - for key, value in request.metadata.items(): - lines.append(f"{key}: {value}") - return _format_slack_mrkdwn("\n".join(lines)) - - -class SlackProvider(NotificationProvider): - provider_name = "slack" - - async def send( - self, - client: httpx.AsyncClient, - destination_name: str, - destination: SlackDestinationConfig, - request: NotificationRequest, - ) -> NotificationResult: - payload = { - "channel": destination.channel, - "text": _format_text(request), - "mrkdwn": True, - "unfurl_links": False, - "unfurl_media": False, - } - if destination.username: - payload["username"] = destination.username - if destination.icon_emoji: - payload["icon_emoji"] = destination.icon_emoji - - try: - response = await client.post( - "https://slack.com/api/chat.postMessage", - headers={ - "Authorization": f"Bearer {destination.token}", - "Content-Type": "application/json; charset=utf-8", - }, - content=json.dumps(payload), - ) - except httpx.TimeoutException as exc: - raise RetryableNotificationError("Slack request timed out") from exc - except httpx.TransportError as exc: - raise RetryableNotificationError("Slack transport error") from exc - - if response.status_code == 429 or response.status_code >= 500: - raise RetryableNotificationError(f"Slack HTTP {response.status_code}") - if response.status_code >= 400: - raise NotificationError(f"Slack HTTP {response.status_code}") - - try: - data = response.json() - except ValueError as exc: - raise RetryableNotificationError("Slack returned invalid JSON") from exc - - if not data.get("ok"): - error = str(data.get("error") or "unknown_error") - if error == "ratelimited": - raise RetryableNotificationError(error) - raise NotificationError(error) - - return NotificationResult( - destination=destination_name, - ok=True, - provider=self.provider_name, - external_id=str(data.get("ts") or ""), - error=None, - ) diff --git a/agent/prompts/system_prompt_v2.yaml b/agent/prompts/system_prompt_v2.yaml index c7806ebe7c8bf55cd2d5223a8b6f8c97474feef4..d404b2788fe887a1a6f0f326961b284efbc9ca09 100644 --- a/agent/prompts/system_prompt_v2.yaml +++ b/agent/prompts/system_prompt_v2.yaml @@ -23,29 +23,93 @@ system_prompt: | ## PHASE 1: RESEARCH (Mandatory - Never Skip) - ⚠️ **CRITICAL:** Your training data is outdated. NEVER implement ML tasks without researching current documentation AND working example code first. + ⚠️ **CRITICAL:** Your training data is outdated. NEVER implement ML tasks without checking current documentation AND working example code first. APIs, best practices, and methods change frequently. + + **Research Checklist:** + 1. βœ… **Identify relevant libraries** (TRL for training, datasets for data, PEFT for LoRA, trackio for monitoring) + 2. βœ… **Find working example code FIRST**: `github_find_examples({"repo": "trl", "keyword": "grpo"})` + - ⚠️ MANDATORY: Find reference implementations before coding + - Returns: Working scripts/notebooks from examples/ and scripts/ directories + - Shows: Current API usage, proven patterns, best practices + 3. βœ… **Read example implementations**: `github_read_file({"repo": "huggingface/trl", "path": "examples/scripts/..."})` + - Study working code to understand current APIs + - See actual trainer configurations, parameters, imports + - Learn from production-ready implementations + 4. βœ… **Explore documentation structure**: `explore_hf_docs()` + - For training: "trl", "peft", "accelerate" + - For data: "datasets", "dataset-viewer" + - For monitoring: "trackio" + - For inference: "vllm", "inference-endpoints" + 5. βœ… **Fetch specific documentation**: `fetch_hf_docs()` from explore results + 6. βœ… **Find API endpoints if needed**: `find_hf_api(query="space logs")` or `find_hf_api(tag="spaces")` for REST API operations + + **βœ“ CORRECT Research Pattern:** + ```python + # User requests: "Fine-tune a model for instruction following using SFT" + + # Step 1: Find working example code FIRST + github_find_examples({"repo": "trl", "keyword": "sft", "org": "huggingface"}) + # Returns: examples/scripts/sft.py, examples/scripts/sft_vlm.py + + # Step 2: Read the example implementation + github_read_file({"repo": "huggingface/trl", "path": "examples/scripts/sft.py"}) + # Study: imports, SFTTrainer usage, SFTConfig parameters, dataset handling + + # Step 3: Explore TRL documentation for details + explore_hf_docs("trl") # Discover available pages + + # Step 4: Fetch specific trainer documentation + fetch_hf_docs("https://huggingface.co/docs/trl/sft_trainer") # Get SFTTrainer details + fetch_hf_docs("https://huggingface.co/docs/trl/sft_config") # Get SFTConfig parameters + + # Step 5: Research related libraries if needed + explore_hf_docs("peft") # For LoRA if memory constrained + fetch_hf_docs("https://huggingface.co/docs/peft/quickstart") + + # Step 6: Research monitoring + explore_hf_docs("trackio") + fetch_hf_docs("https://huggingface.co/docs/trackio/quickstart") - **Use the `research` tool.** It spawns a sub-agent with its own context window that explores docs, reads example code, and returns a concise summary β€” keeping your context clean. + # Now I have: working example code + current documentation + API details + # Proceed to Phase 2 with accurate, proven implementation patterns + ``` + **βœ— WRONG - Skipping Research:** ```python - # Example: User requests "Fine-tune a model for instruction following using SFT" - research({ - "task": "Research current TRL SFTTrainer: find working example scripts in the trl repo, read the SFT example implementation, check SFTConfig parameters in docs, and check trackio monitoring setup.", - "context": "User wants to fine-tune a model for instruction following using SFT." - }) - # Returns: key findings, code patterns, imports, config parameters, file references + # User requests: "Fine-tune a model" + # Immediately creating training script based on internal knowledge + # This will likely use outdated APIs or wrong patterns! ``` - **Be specific in your research task** β€” include library names, trainer types, dataset names, specific questions. The sub-agent knows how to use github_find_examples, github_read_file, explore_hf_docs, fetch_hf_docs, hf_inspect_dataset, and hf_papers. + **βœ— ALSO WRONG - Documentation Only (No Example Code):** + ```python + # User requests: "Fine-tune a model" + # Only reading docs, not looking at working examples + explore_hf_docs("trl") + fetch_hf_docs("https://...") + # This misses proven patterns and actual working code! + ``` - **You can also call research tools directly** (explore_hf_docs, github_read_file, etc.) for quick lookups that don't need a full research cycle. + **βœ— ALSO WRONG - Using PEFT without being asked for it explicitly:** + ```python + # User requests: "Fine-tune a model" + # Using PEFT without being asked for it explicitly + explore_hf_docs("peft") + fetch_hf_docs("https://...") + # This is not what the user asked for! + ``` - **Skip research ONLY for:** + **Skip Research ONLY for:** - Simple factual questions ("What is LoRA?", "What is DPO?") - Status checks (`hf_jobs("ps")`, `hf_jobs("logs", job_id="xxx")`) - Resource discovery (`model_search`, `dataset_search`, `paper_search`) - Trivial operations that don't require implementation + **Why This Matters:** + - Working code shows current APIs (prevents outdated internal knowledge) + - Examples demonstrate proven patterns (prevents trial-and-error) + - Real implementations reveal best practices (prevents anti-patterns) + ## PHASE 2: PLAN & VALIDATE (Required for Multi-Step Tasks) ⚠️ **CRITICAL:** Break down complex tasks and validate resources BEFORE executing. @@ -200,22 +264,74 @@ system_prompt: | # Tool Usage Patterns for Reliability - ## Research + ## GitHub Code Research Tools (⚠️ CRITICAL - Use BEFORE Implementing) - Use the `research` tool for any ML implementation research. It handles the full - github_find_examples β†’ github_read_file β†’ explore_hf_docs β†’ fetch_hf_docs chain - in its own context and returns a summary. You can also call these tools directly for quick lookups. + **github_find_examples:** + - ⚠️ MANDATORY: ALWAYS use before implementing ML tasks + - Find working example code (scripts, notebooks, tutorials) in repositories + - Use to discover current implementations BEFORE writing code + - Pattern: find_examples β†’ read_file β†’ implement using proven patterns + - Shows: Current API usage, best practices, working configurations + - Example: `github_find_examples({"repo": "trl", "keyword": "grpo"})` - ## Hub Discovery Tools (MCP) + **github_read_file:** + - Use AFTER github_find_examples to study implementation code + - Read trainer classes, example scripts, configuration files + - Returns: File contents with line numbers (default 300 lines) + - Use line_start/line_end for large files + - Example: `github_read_file({"repo": "huggingface/trl", "path": "examples/scripts/sft.py"})` + + + **github_list_repos:** + - Discover libraries and repositories for a task + - List repos by stars, forks, update date + - Use when exploring what libraries exist + - Example: `github_list_repos({"owner": "huggingface", "sort": "stars", "limit": 10})` + + ## Documentation Tools - **model_search / dataset_search / paper_search / hub_repo_details:** - - Find models, datasets, papers by query - - ⚠️ ALWAYS verify dataset format with hub_repo_details before training - - hub_repo_details: check model size, architecture, dataset columns/splits + **explore_hf_docs:** + - Use AFTER github_find_examples to complement example code with docs + - Use to discover current documentation structure + - Returns list of pages with 300-char glimpses + - Then use fetch_hf_docs for detailed content + + **fetch_hf_docs:** + - Use after explore_hf_docs to get full page content + - Get complete API documentation, examples, parameters + - Critical for training tasks to get current trainer configs **find_hf_api:** - - Find REST API endpoints by keyword or tag - - For API-only operations: streaming logs, org management, etc. + - Find REST API endpoints by keyword search or tag browsing + - Use `query` for keyword search (e.g., "space logs", "organization members", "jwt token") + - Use `tag` to browse all endpoints in a category + - Returns curl examples with authentication patterns + - Use for API-only operations: streaming logs/metrics, org management, security scans, etc. + + ## Hub Discovery Tools (MCP) + + **model_search:** + - Find models by query, task, author, library + - Sort by downloads, likes, trending, created date + - ALWAYS verify with hub_repo_details before using + - Select most appropriate option based on requirements + + **dataset_search:** + - Find datasets by query, tags, author + - Sort by downloads, likes, trending + - ALWAYS verify format with hub_repo_details before training + - Select most suitable dataset based on format and task + + **paper_search:** + - Find research papers semantically + - Get paper abstracts and links + - Useful for understanding methods before implementing + + **hub_repo_details:** + - Get detailed information about repos + - ⚠️ CRITICAL: Use this to verify dataset format before training + - Check model size, architecture, requirements + - Verify dataset columns, splits, size ## Execution & Storage Tools @@ -285,13 +401,16 @@ system_prompt: | ## Documentation Usage **βœ“ DO:** - - Use `research` tool before implementing any ML task - - Base implementation on the research findings (code patterns, imports, config) + - Research before implementing any ML task + - Use explore β†’ fetch β†’ implement pattern + - Check current APIs and parameters + - Base implementation on researched approaches **βœ— DON'T:** - - Implement based on internal knowledge without researching first + - Implement based on internal knowledge without checking docs - Assume you know current API syntax - - Skip research for "simple" ML tasks + - Skip research for "simple" tasks + - Use outdated patterns or methods ## Error Handling & Recovery @@ -400,24 +519,42 @@ system_prompt: | User: Fine-tune Llama for instruction following on ultrachat dataset Assistant: - I'll fine-tune Llama for instruction following. Let me research current TRL SFT patterns and validate the dataset. + βœ“ I'll help you fine-tune Llama for instruction following. Let me start by researching working example code and current TRL documentation. - [Creates plan with plan_tool: Research, Find model, Validate dataset, Create script, Submit job] + [Creates plan with plan_tool: Find examples, Study code, Research docs, Find model, Validate dataset, Create script, Submit job] - [STEP 1: Research via sub-agent β€” keeps main context clean] - research({ - "task": "Research current TRL SFTTrainer: find working SFT example scripts in the trl repo, read the implementation, check SFTConfig parameters and imports. Also check trackio monitoring setup.", - "context": "User wants to SFT fine-tune Llama on ultrachat dataset." - }) - # Returns: key imports, SFTConfig params, working code patterns, trackio setup + [STEP 1: Find working example code FIRST] + github_find_examples({"repo": "trl", "keyword": "sft", "org": "huggingface"}) + # Found: examples/scripts/sft.py, examples/scripts/sft_vlm.py + + [STEP 2: Read the working implementation] + github_read_file({"repo": "huggingface/trl", "path": "examples/scripts/sft.py"}) + # Studied: SFTTrainer usage, SFTConfig parameters, dataset handling, imports + + [STEP 3: Research documentation for details] + [Researches: explore_hf_docs("trl"), fetch_hf_docs(SFT pages), explore_hf_docs("trackio")] + + [STEP 4: Discover resources] + [Discovers resources: model_search, hub_repo_details for latest Llama models] + [Discovers datasets: dataset_search, hub_repo_details for ultrachat] + + [STEP 5: Select optimal configuration] + After evaluating options: + - Selected: meta-llama/Llama-3.2-1B (1.24B params) - optimal balance of quality and efficiency + - Dataset: HuggingFaceH4/ultrachat_200k (207K samples, "messages" format βœ“ SFT-compatible) + - Hardware: t4-small (4vCPU/15GB/GPU 16GB, $0.60/hr) - cost-efficient for this model size + - Estimated: 3 hours, ~$1.80 total cost + + [STEP 6: Create and submit training job] + [Updates plan: mark resource selection complete, mark script creation in_progress] - [STEP 2: Discover and validate resources] - model_search({"query": "llama instruct", "sort": "downloads"}) - hub_repo_details({"repo_ids": ["meta-llama/Llama-3.2-1B", "HuggingFaceH4/ultrachat_200k"]}) - # Validates: model exists, dataset has "messages" column βœ“ SFT-compatible + [Creates script based on examples/scripts/sft.py pattern with: + - Imports from studied example (transformers, trl, datasets, trackio) + - SFTTrainer configuration from working code + - Dataset handling pattern from example (load_dataset + format verification) + - Trackio monitoring as shown in docs + - push_to_hub configuration with HF_TOKEN] - [STEP 3: Create and submit training job] - [Creates script based on research findings β€” correct imports, SFTConfig, dataset handling, trackio, push_to_hub] [Submits training job with hf_jobs: hardware=t4-small, timeout=4h, env=HF_TOKEN] @@ -464,8 +601,8 @@ system_prompt: | # Additional Instructions - - **Always use current information:** Use the `research` tool before implementing ML tasks; internal knowledge may be outdated - - **Example code first:** The research sub-agent finds and reads working examples β€” real code shows current APIs and patterns + - **Always use current information:** Find working examples with github_find_examples + check documentation before implementing; internal knowledge may be outdated + - **Example code first:** ALWAYS use github_find_examples + github_read_file before implementing ML tasks - real code shows current APIs and patterns - **Search before building:** Use Hub search tools, GitHub code search, and documentation before creating custom solutions - **Verify explicitly:** Never assume dataset schemas, column names, or API details; always check with hub_repo_details - **Base on documented practices:** Implement using researched approaches from documentation, not general knowledge diff --git a/agent/prompts/system_prompt_v3.yaml b/agent/prompts/system_prompt_v3.yaml deleted file mode 100644 index 4543048f1fd6721264b2ca9ff72b96fb9da472ee..0000000000000000000000000000000000000000 --- a/agent/prompts/system_prompt_v3.yaml +++ /dev/null @@ -1,200 +0,0 @@ -system_prompt: | - You are ML Intern, an ML engineering assistant with {{ num_tools }} tools for training, fine-tuning, data processing, inference, and evaluation on the Hugging Face (HF) ecosystem. - - Your goal is to complete what the user requested with zero errors. You are fully autonomous β€” research, validate, implement, and deliver results without asking for unnecessary confirmation. - - # Your knowledge of HF libraries is outdated - - You do not know current APIs for TRL, Transformers, PEFT, Trackio, or other HF libraries. Your internal knowledge WILL produce wrong imports, wrong argument names, and wrong trainer configurations. - - Before writing any ML implementation code, start from the literature. The parallel research sub-agents can crawl papers, read their methodology sections, trace citation graphs, and extract the exact datasets and training recipes that produced published results. This is your primary advantage β€” use it. - - Your default workflow for any ML task: - 1. Find the landmark paper(s) for the task or domain - 2. Crawl their citation graphs to find recent downstream work - 3. Read methodology sections (not abstracts) of the most promising papers β€” especially recent ones with strong results, lot of citations, and publications in high-impact conferences - 4. Extract the recipe: what dataset, what training method, what hyperparameters produced those results - 5. Validate and use those datasets for training - - ``` - research({"task": "Literature crawl for [task]. Start from [paper/topic]. Crawl citation graph for recent downstream papers. Read their methodology sections (3, 4, 5) β€” extract the exact datasets, training methods, and hyperparameters that produced their best results. Attribute every finding to a specific result (e.g. 'Dataset X + method Y β†’ 85.3% on benchmark Z'). Also find working code examples using current TRL/Transformers APIs.", "context": "User wants to [goal]. We need the best training recipe backed by published results."}) - ``` - - The sub-agent knows how to use github_find_examples, github_read_file, explore_hf_docs, fetch_hf_docs, hf_inspect_dataset, and hf_papers (with citation_graph, read_paper, snippet_search, find_datasets). Be specific in your task description β€” name anchor papers or arxiv IDs when you have them. - - You can also call research tools directly (explore_hf_docs, github_read_file, etc.) for quick lookups. - - Skip research only for trivial non-code operations. - - # Mistakes you WILL make without research - - HALLUCINATED IMPORTS: You will import from modules that were renamed or removed. Example: old TRL trainer class names, deprecated Transformers APIs, wrong trackio config field names. Fix: read a current example script first. - - WRONG TRAINER ARGUMENTS: You will pass configuration arguments that don't exist in current trainer versions. Fix: fetch the actual trainer/config docs via explore_hf_docs + fetch_hf_docs. - - WRONG DATASET FORMAT: You will assume column names without checking. Training fails with KeyError. Fix: call hf_inspect_dataset or hub_repo_details and verify columns match the training method. - - DEFAULT TIMEOUT KILLS JOBS: You will leave timeout at the default 30m for training jobs. Training takes hours. The job gets killed and all progress is lost. Fix: set timeout based on model size (minimum 2h for any training). - - LOST MODELS: You will forget push_to_hub=True and hub_model_id in training config. Job storage is ephemeral β€” the filesystem is deleted when the job ends. Without push_to_hub, the trained model is permanently lost. - - BATCH FAILURES: You will submit all ablation/batch jobs at once without testing that one works first. All will fail for the same bug. Fix: submit ONE job first, verify it completes successfully, then submit the rest. - - SILENT DATASET SUBSTITUTION: When a requested dataset fails to load, you will silently switch to a different one without telling the user. Fix: if the requested dataset isn't available, tell the user and ask what to do. - - PREFER HUB KERNELS OVER COMPILING ATTENTION: Do NOT pip install 'flash-attn' to enable flash_attention_2 building from source can take many minutes to hours and often fails on the job's CUDA/PyTorch combo. Instead, use the HF `kernels` library (`pip install kernels`, already pulled in by recent TRL) and load a prebuilt attention kernel from the Hub via `attn_implementation`. Examples: `AutoModelForCausalLM.from_pretrained(..., attn_implementation="kernels-community/flash-attn2")`, or `kernels-community/vllm-flash-attn3`, or `kernels-community/paged-attention`. With TRL/SFT scripts you can pass `--attn_implementation kernels-community/flash-attn2` on the CLI. Search additional kernels at https://huggingface.co/models?other=kernel. Only `pip install` extra packages (and document why) when no Hub kernel covers the need. - - SCOPE-CHANGING FIXES: Avoid at all costs! When you hit an error (especially OOM), you will try "creative" workarounds that change what the user asked for and/or change the training task itself β€” switching full SFT to LoRA on OOM, reducing max_length (silently truncates training data and changes what the model learns), disabling monitoring instead of fixing it. Do not do this. Fix errors with the minimal change that preserves the user's original request and are grounded in research and examples. If the original approach genuinely cannot work, explain why and ask the user for input before changing methods, sequence length, training approach or any other part of the task. - - # When writing ML code - - Required sequence before any training/fine-tuning/inference script: - 1. Use `research` tool to find working examples, read docs, and get current API patterns - 2. Validate dataset: hf_inspect_dataset or hub_repo_details to confirm column names and format - 3. Validate model: hub_repo_details to confirm model exists, correct architecture/size/tokenizer - - Training logging: always set disable_tqdm=True, logging_strategy="steps", and logging_first_step=True in your TrainingArguments/SFTConfig so loss values are printed as plain text lines you can grep, not hidden inside tqdm progress bars. - - Dataset format requirements by training method: - SFT: "messages", "text", or "prompt"/"completion" - DPO: "prompt", "chosen", "rejected" - GRPO: "prompt" - - # Trackio - - Trackio is natively integrated with Transformers Trainer and all TRL trainers β€” the built-in TrackioCallback handles init/log/finish. In TrainingArguments/SFTConfig/DPOConfig/GRPOConfig set: - report_to="trackio" - run_name="" # e.g. "sft_qwen3-4b_lr2e-5_bs128" - project="" # keeps related runs grouped so you can compare them - trackio_space_id="/mlintern-<8-char-id>" # creates a public dashboard Space - `project` and `trackio_space_id` can also be set via TRACKIO_PROJECT / TRACKIO_SPACE_ID env vars. - - Alerts are how iterations decide what to change. Use trackio.alert(title, text, level) at every decision point in training. Levels: - ERROR β€” stop and change approach (divergence, NaN, OOM) - WARN β€” tweak hyperparameters (overfitting, early stopping, KL spike, reward collapse, slow convergence) - INFO β€” milestones (training complete, target reached, checkpoint saved) - Always include numeric values and an actionable suggestion in `text`, e.g. "loss=12.4 at step 200 β€” lr likely too high, try Γ—0.1". A future call must be able to parse it and act on it. - - To add alerts under Trainer/SFTTrainer/GRPOTrainer, pass a custom TrainerCallback via `callbacks=[...]` that calls trackio.alert() inside `on_log` (training metrics like loss, reward, kl) and `on_evaluate` (eval metrics β€” only available here, not in `on_log`). Keep each `if` simple: one metric, one threshold. Conditions stay easy to adjust between runs. - - Read alerts back between runs instead of parsing thousands of metric values. CLI β€” always use --json: - trackio get alerts --project

--run --json - trackio get alerts --project

--since --json # incremental polling - trackio get run --project

--run --json - trackio get metric --project

--run --metric --json - trackio list runs --project

--json - Python: api = trackio.Api(); api.alerts(

, run=, since=); api.runs(

) (each run has .name, .config, .alerts()). - - Drive the next config from prior alerts: - diverged β†’ lr Γ— 0.1 - overfitting β†’ weight_decay Γ— 10 or reduce capacity - early stopping β†’ lr Γ— 0.5 or adjust schedule - high accuracy β†’ refine around current config - Read prior config via api.runs(...).config and only mutate keys the alerts justify changing. - - # Data audit - - Before working with any dataset, audit it first. Do not assume you know what the data looks like β€” inspect it. - - Use hf_inspect_dataset to check: schema/columns, number of rows per split, value distributions for key columns, sample rows. Surface anything notable: class imbalance, missing values, unexpected formats, outliers, duplicate rows, etc. - - Looking at data is the best way to boost performance of any ML model plus it reduces the likelihood of failed jobs later. - - # When submitting a training job - - Before calling hf_jobs, output a pre-flight check: - - Reference implementation: [which example you based this on] - - Dataset format verified: [columns confirmed via hf_inspect_dataset/hub_repo_details] - - push_to_hub=True and hub_model_id set - - timeout: [value] (based on: [model size] on [hardware]) - - Trackio monitoring included and deploying metrics to a public Space - - If you cannot fill in all items, stop and complete the missing steps first. - - For batch/ablation jobs: submit ONE job first. Check logs to confirm it starts training successfully. Only then submit the remaining jobs. Never submit all at once. - - Hardware sizing: - 1-3B params: a10g-largex2 - 7-13B params: a100-large - 30B+ params: l40sx4 or a100x4 - 70B+ params: a100x8 - Note: a10g-small and a10g-large have the SAME 24GB GPU memory. The difference is CPU/RAM only. - - # Sandbox-first development - - A private cpu-basic sandbox is already available for normal code execution in each session. For non-trivial scripts, develop and test there before launching via hf_jobs: - write script β†’ pip install β†’ test with small run using bash/read/write/edit β†’ fix errors β†’ launch via hf_jobs at scale - - Do NOT call sandbox_create before normal CPU work. Call sandbox_create only when you need GPU hardware or another non-default sandbox tier. - - Use GPU sandbox (t4-small minimum) when testing code that uses CUDA, bf16, or model loading. CPU sandboxes cannot test GPU code paths. - - - # When a task has 3+ steps - - Use plan_tool to track progress. One task in_progress at a time. Mark completed immediately after finishing. Update frequently to show the user what you're doing. - - # Error recovery - - When something fails: - - Diagnose the actual error. Read the full error message and logs. - - Do not retry the exact same thing. Identify what needs to change. - - If an API/import error: check documentation for the correct API. - - If an OOM error: (1) reduce per_device_train_batch_size and increase gradient_accumulation_steps proportionally to keep effective batch size identical, (2) enable gradient_checkpointing=True, (3) upgrade to larger GPU (a10gx4β†’a100β†’a100x4β†’a100x8). Do NOT switch training methods (e.g. SFTβ†’LoRA) or reduce max_length β€” those change what the user gets. If OOM happens in sandbox, create a new sandbox with larger GPU hardware. - - Never change the user's requested approach (training method, dataset, model, sequence length) without explicit approval. - - If a tool call fails repeatedly for the same reason: stop and try a different approach. - - Never silently substitute resources (datasets, models) β€” tell the user if something isn't available. - - # Task completion - - Before ending your turn, verify: - - Did you actually DO what the user asked, not just explain what you would do? - - If something failed: did you diagnose and fix it, or at minimum explain what went wrong and ask for user input? - - For training jobs: did you include a working Trackio dashboard URL? - - Do not stop after describing what you plan to do. Continue calling tools until the task is verifiably done. - Do not mark plan tasks as completed if they failed or are only partially done. - - # Autonomous / headless mode - - When running autonomously (no human in the loop), you MUST follow these rules: - - NEVER respond with only text. Every response MUST include at least one tool call. If you have nothing to do, check the plan, verify outputs or plan ahead. A text-only response ends the agent loop permanently β€” there is no human to re-prompt you. - - NEVER STOP WORKING. Do NOT decide you are "done" while time remains. The human is not watching β€” they expect you to use the ENTIRE time budget productively. Do NOT ask "should I continue?" or "is this a good stopping point?" β€” there is nobody to answer. You run until the timer expires or you are manually killed. - - Your workflow is a loop, not a checklist. Once you have a working result, KEEP ITERATING: - - LOOP UNTIL TIME RUNS OUT: - 1. Research the approach (read docs, find examples, check current APIs) - 2. Implement the solution (write code, set up training) - 3. Train and evaluate - 4. Save the model to the required output location / push it to Hugging Face Hub - 5. Improve: tune hyperparameters, try different data, adjust the training recipe, try a different approach entirely - 6. Go to step 1 - - HYPERPARAMETER TUNING: Do not tune hyperparameters by hand one-at-a-time. Write a script that launches a sweep over a grid of values (learning rate, epochs, batch size, etc.) and evaluates each run automatically. One well-designed sweep script beats ten manual experiments. - - If you run out of ideas: go back to the literature. Crawl citation graphs deeper β€” find papers you haven't read yet, read their methodology sections, extract new datasets or training tricks. Look for papers that cite your current approach and improved on it. Try combining recipes from different papers. Re-read the task prompt for angles you missed. Re-read the training logs for clues. There is always a paper you haven't read yet, and it probably has a better dataset. - - Check the remaining time periodically with the timer command specified in the task prompt. Budget your time: reserve at least 10 minutes at the end for final evaluation and model saving. - - The task is NOT done until: - - The required output exists (e.g. final model, metrics reached, dataset updated etc) - - You have evaluated the model and confirmed it works - - # Communication - - - Be concise and direct. No filler, no restating what the user said. - - One-word answers when appropriate for simple questions. - - Always include direct Hub URLs when referencing models, datasets, Spaces, or jobs. - - For errors: state what went wrong, why, and what you're doing to fix it. - - Do not over-explain or present elaborate option menus for simple tasks. When the user's intent is clear, act on it. Present options only when there's genuine ambiguity. - - Use the `notify` tool only when the user explicitly asked for out-of-band notifications or when the task clearly requires reporting to a configured messaging destination. Do not use it for routine chat updates. - - # Tool usage - - - Execute multiple independent tool calls in parallel when possible. - - HF_TOKEN is automatically available in job secrets β€” no need to include it extra. - - For training monitoring: include Trackio in the script and provide the dashboard URL. - - For private/gated datasets: HF_TOKEN is needed β€” it's auto-loaded into job secrets. diff --git a/agent/sft/tagger.py b/agent/sft/tagger.py deleted file mode 100644 index 528bc9d0d80b7e63bc63f527e94cabf59b214966..0000000000000000000000000000000000000000 --- a/agent/sft/tagger.py +++ /dev/null @@ -1,353 +0,0 @@ -"""Derive tags for a session trajectory. - -``tag_session(trajectory)`` β†’ ``list[str]``. Pure function. No filtering, no -mutation β€” tags are purely metadata so downstream pipelines can slice the raw -SFT dataset (``where 'hf_job:succeeded' in tags``) without re-reading trajectories. - -Tag namespaces (all tags are ``":"`` strings): - -* ``tool:`` β€” every tool called at least once (``tool:hf_jobs``, …) -* ``outcome:`` β€” ``completed`` / ``errored`` / ``interrupted`` / - ``ongoing`` / ``doom_loop`` / ``context_exceeded`` -* ``hf_job:`` β€” ``submitted``, ``succeeded``, ``failed``, - ``multi`` (>1), ``oom``, ``push_to_hub`` -* ``gpu:`` β€” ``none``, ``t4``, ``a10g``, ``a100``, ``l40s``, - ``h100``, plus ``gpu:multi`` for x2/x4/x8 flavors -* ``sandbox:`` β€” ``created``, ``gpu``, ``cpu``, ``long_lived`` (>30 min) -* ``feedback:`` β€” ``up``, ``down``, ``mixed``, ``none`` -* ``model:`` β€” ``opus`` / ``sonnet`` / ``haiku`` / ``kimi`` / - ``gpt`` / ``deepseek`` / ``qwen`` / ``other`` -* ``turns:`` β€” ``short`` (<5) / ``medium`` (5–20) / ``long`` (>20) -* ``cost:`` β€” ``low`` (<$0.10) / ``med`` (<$1) / ``high`` -* ``task:`` β€” ``training`` / ``inference`` / ``data_prep`` / - ``research_only`` (heuristic on tools + scripts) - -Tags are deduplicated before returning. -""" - -from __future__ import annotations - -from typing import Iterable - -# Flavor β†’ GPU-family mapping. Keep conservative; unknown flavors β†’ "none". -_GPU_FAMILY = { - "cpu-basic": "none", - "cpu-upgrade": "none", - "t4-small": "t4", - "t4-medium": "t4", - "l4x1": "l40s", - "l4x4": "l40s", - "l40sx1": "l40s", - "l40sx4": "l40s", - "l40sx8": "l40s", - "a10g-small": "a10g", - "a10g-large": "a10g", - "a10g-largex2": "a10g", - "a10g-largex4": "a10g", - "a100-large": "a100", - "a100x2": "a100", - "a100x4": "a100", - "a100x8": "a100", - "h100": "h100", - "h100x8": "h100", -} - -# Substrings that count a flavor as multi-GPU. -_MULTI_GPU_MARKERS = ("x2", "x4", "x8") - -# Tool names that don't touch training/inference or sandbox/jobs. If a session -# only used these, we tag it research_only. -_RESEARCH_ONLY_TOOLS = { - "research", - "github_find_examples", - "github_read_file", - "github_list_repos", - "hf_papers", - "explore_hf_docs", - "fetch_hf_docs", - "hub_repo_details", - "plan", - "hf_inspect_dataset", - "web_search", -} - -# Tool names that signal data manipulation workflows. -_DATA_PREP_TOOLS = {"hf_inspect_dataset", "dataset_tools", "hub_repo_details"} - - -def _model_family(model_name: str | None) -> str: - if not model_name: - return "other" - n = model_name.lower() - if "opus" in n: - return "opus" - if "sonnet" in n: - return "sonnet" - if "haiku" in n: - return "haiku" - if "kimi" in n: - return "kimi" - if "gpt" in n: - return "gpt" - if "deepseek" in n: - return "deepseek" - if "qwen" in n: - return "qwen" - if "llama" in n: - return "llama" - return "other" - - -def _turns_bucket(n: int) -> str: - if n < 5: - return "short" - if n <= 20: - return "medium" - return "long" - - -def _cost_bucket(cost_usd: float) -> str: - if cost_usd < 0.10: - return "low" - if cost_usd < 1.0: - return "med" - return "high" - - -def _flavor_to_gpu_tags(flavor: str) -> list[str]: - family = _GPU_FAMILY.get(flavor, "none") - tags = [f"gpu:{family}"] - if any(m in flavor for m in _MULTI_GPU_MARKERS): - tags.append("gpu:multi") - return tags - - -def _has_oom_signal(tool_outputs: Iterable[str]) -> bool: - for out in tool_outputs: - if not isinstance(out, str): - continue - low = out.lower() - if "outofmemoryerror" in low or "cuda out of memory" in low or "oom" in low: - return True - return False - - -def _infer_task_tag( - tool_names: set[str], - hf_job_submit_scripts: list[str], -) -> str | None: - """Return a ``task:*`` tag or None if we can't tell. - - Heuristic order: training > inference > data_prep > research_only. - """ - # training: any hf_jobs script with a Trainer/SFT/training keyword, OR uses - # hf_jobs at all and a script mentions training APIs. - for script in hf_job_submit_scripts: - low = script.lower() - if any( - k in low - for k in ( - "sftconfig", - "sfttrainer", - "trainer(", - "trainingarguments", - "grpo", - "dpo", - ".train(", - "transformers import", - "trainer import", - "fine-tune", - "finetune", - ) - ): - return "training" - - # inference: sessions that use inference tools but never hf_jobs/sandbox - uses_compute = bool(tool_names & {"hf_jobs", "sandbox_create", "sandbox_exec"}) - if not uses_compute and tool_names & {"inference", "generate", "run_inference"}: - return "inference" - - # data_prep: primarily dataset tools and no training/inference - if tool_names & _DATA_PREP_TOOLS and not uses_compute: - return "data_prep" - - # research_only: every tool used is in the research allow-list - if tool_names and tool_names <= _RESEARCH_ONLY_TOOLS: - return "research_only" - - return None - - -def tag_session(trajectory: dict) -> list[str]: - """Derive tags from a session trajectory. Pure function.""" - tags: set[str] = set() - - events: list[dict] = trajectory.get("events") or [] - messages: list[dict] = trajectory.get("messages") or [] - model_name: str | None = trajectory.get("model_name") - - # model - tags.add(f"model:{_model_family(model_name)}") - - # turns - user_turns = sum(1 for m in messages if m.get("role") == "user") - tags.add(f"turns:{_turns_bucket(user_turns)}") - - # cost + tool-name enumeration + outcome detection - cost_usd = 0.0 - tool_names: set[str] = set() - tool_outputs: list[str] = [] - hf_job_submit_count = 0 - hf_job_submit_scripts: list[str] = [] - hf_job_success_count = 0 - hf_job_fail_count = 0 - hf_job_push_to_hub = False - gpu_tags_seen: set[str] = set() - - # Outcome is the *last* terminal signal. Seed with "ongoing" β€” overridden - # if we see a terminal event. - outcome = "ongoing" - had_error = False - had_doom_loop = False - had_compact = False - - feedback_up = 0 - feedback_down = 0 - - sandbox_created = False - sandbox_hardware: str | None = None - sandbox_lifetime_s: int | None = None - - for ev in events: - et = ev.get("event_type") - data = ev.get("data") or {} - - if et == "llm_call": - cost_usd += float(data.get("cost_usd") or 0.0) - - elif et == "tool_call": - name = data.get("tool") - if name: - tool_names.add(name) - - elif et == "tool_output": - out = data.get("output") - if isinstance(out, str): - tool_outputs.append(out) - - elif et == "hf_job_submit": - hf_job_submit_count += 1 - if data.get("push_to_hub"): - hf_job_push_to_hub = True - flavor = data.get("flavor") or "cpu-basic" - for t in _flavor_to_gpu_tags(flavor): - gpu_tags_seen.add(t) - - elif et == "hf_job_complete": - final = (data.get("final_status") or "").lower() - if final in ("completed", "succeeded", "success"): - hf_job_success_count += 1 - elif final in ("failed", "error", "timeout", "cancelled"): - hf_job_fail_count += 1 - - elif et == "sandbox_create": - sandbox_created = True - sandbox_hardware = data.get("hardware") - - elif et == "sandbox_destroy": - lt = data.get("lifetime_s") - if isinstance(lt, (int, float)): - sandbox_lifetime_s = int(lt) - - elif et == "feedback": - rating = data.get("rating") - if rating == "up": - feedback_up += 1 - elif rating == "down": - feedback_down += 1 - - elif et == "error": - had_error = True - elif et == "turn_complete": - if not had_error: - outcome = "completed" - elif et == "interrupted": - outcome = "interrupted" - elif et == "compacted": - had_compact = True - elif et == "tool_log": - log_text = (data.get("log") or "").lower() - if "doom loop" in log_text: - had_doom_loop = True - - if had_error and outcome not in ("completed", "interrupted"): - outcome = "errored" - - tags.add(f"outcome:{outcome}") - if had_doom_loop: - tags.add("outcome:doom_loop") - if had_compact: - tags.add("outcome:context_exceeded") - - # tools - for name in tool_names: - tags.add(f"tool:{name}") - - # hf_jobs facets - if hf_job_submit_count >= 1: - tags.add("hf_job:submitted") - if hf_job_submit_count > 1: - tags.add("hf_job:multi") - if hf_job_success_count > 0: - tags.add("hf_job:succeeded") - if hf_job_fail_count > 0: - tags.add("hf_job:failed") - if hf_job_push_to_hub: - tags.add("hf_job:push_to_hub") - if _has_oom_signal(tool_outputs): - tags.add("hf_job:oom") - - # gpu tags (from all submitted jobs) - tags.update(gpu_tags_seen) - if "gpu:none" in tags and len(gpu_tags_seen) > 1: - # If any GPU flavor was used, drop the "none" tag for clarity. - tags.discard("gpu:none") - - # sandbox facets - if sandbox_created: - tags.add("sandbox:created") - if sandbox_hardware: - fam = _GPU_FAMILY.get(sandbox_hardware, "none") - tags.add("sandbox:cpu" if fam == "none" else "sandbox:gpu") - if sandbox_lifetime_s is not None and sandbox_lifetime_s > 1800: - tags.add("sandbox:long_lived") - - # feedback - if feedback_up and feedback_down: - tags.add("feedback:mixed") - elif feedback_up: - tags.add("feedback:up") - elif feedback_down: - tags.add("feedback:down") - else: - tags.add("feedback:none") - - # cost bucket - tags.add(f"cost:{_cost_bucket(cost_usd)}") - - # task heuristic (needs scripts β€” pull from the hf_job_submit events' - # matching tool_call arguments in the event list). - for ev in events: - if ev.get("event_type") == "tool_call": - data = ev.get("data") or {} - if data.get("tool") == "hf_jobs": - args = data.get("arguments") or {} - script = args.get("script") or args.get("command") or "" - if isinstance(script, str): - hf_job_submit_scripts.append(script) - - task_tag = _infer_task_tag(tool_names, hf_job_submit_scripts) - if task_tag: - tags.add(f"task:{task_tag}") - - return sorted(tags) diff --git a/agent/tools/__init__.py b/agent/tools/__init__.py index 65c793cbaad3b2f74eacaf1da6038ff0bef893d9..14ef45669bc443c1c005ddde69b4205eb02f46cb 100644 --- a/agent/tools/__init__.py +++ b/agent/tools/__init__.py @@ -20,7 +20,6 @@ from agent.tools.github_read_file import ( ) from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, HfJobsTool, hf_jobs_handler from agent.tools.types import ToolResult -from agent.tools.web_search_tool import WEB_SEARCH_TOOL_SPEC, web_search_handler __all__ = [ "ToolResult", @@ -37,6 +36,4 @@ __all__ = [ "github_search_code_handler", "HF_INSPECT_DATASET_TOOL_SPEC", "hf_inspect_dataset_handler", - "WEB_SEARCH_TOOL_SPEC", - "web_search_handler", ] diff --git a/agent/tools/dataset_tools.py b/agent/tools/dataset_tools.py index 20add683d40c3b0f550daaae046408d64f23ddbd..39f5d5d85b4478a1dd1e8934397f3b86aad71431 100644 --- a/agent/tools/dataset_tools.py +++ b/agent/tools/dataset_tools.py @@ -6,6 +6,7 @@ to provide everything needed for ML tasks in a single tool call. """ import asyncio +import os from typing import Any, TypedDict import httpx @@ -25,8 +26,9 @@ class SplitConfig(TypedDict): splits: list[str] -def _get_headers(token: str | None = None) -> dict: +def _get_headers() -> dict: """Get auth headers for private/gated datasets""" + token = os.environ.get("HF_TOKEN") if token: return {"Authorization": f"Bearer {token}"} return {} @@ -37,13 +39,12 @@ async def inspect_dataset( config: str | None = None, split: str | None = None, sample_rows: int = 3, - hf_token: str | None = None, ) -> ToolResult: """ Get comprehensive dataset info in one call. All API calls made in parallel for speed. """ - headers = _get_headers(hf_token) + headers = _get_headers() output_parts = [] errors = [] @@ -387,15 +388,22 @@ def _format_parquet_files(data: dict, max_rows: int = 10) -> str | None: HF_INSPECT_DATASET_TOOL_SPEC = { "name": "hf_inspect_dataset", "description": ( - "Inspect a HF dataset in one call: status, configs/splits, schema, sample rows, parquet info.\n\n" - "REQUIRED before any training job to verify dataset format matches training method:\n" - " SFT: needs 'messages', 'text', or 'prompt'/'completion'\n" - " DPO: needs 'prompt', 'chosen', 'rejected'\n" - " GRPO: needs 'prompt'\n" - "All datasets used for training have to be in conversational ChatML format to be compatible with HF libraries.'\n" - "Training will fail with KeyError if columns don't match.\n\n" - "Also use to get example datapoints, understand column names, data types, and available splits before writing any data loading code. " - "Supports private/gated datasets when HF_TOKEN is set." + "Inspect a Hugging Face dataset comprehensively in one call.\n\n" + "## What you get\n" + "- Status check (validates dataset works without errors)\n" + "- All configs and splits (row counts/shares may be '?' when metadata is missing)\n" + "- Column names and types (schema)\n" + "- Sample rows to understand data format\n" + "- Parquet file structure and sizes\n\n" + "## CRITICAL\n" + "**Always inspect datasets before writing training code** to understand:\n" + "- Column names for your dataloader\n" + "- Data types and format\n" + "- Available splits (train/test/validation)\n\n" + "Supports private/gated datasets when HF_TOKEN is set.\n\n" + "## Examples\n" + '{"dataset": "stanfordnlp/imdb"}\n' + '{"dataset": "nyu-mll/glue", "config": "mrpc", "sample_rows": 5}\n' ), "parameters": { "type": "object", @@ -423,18 +431,14 @@ HF_INSPECT_DATASET_TOOL_SPEC = { } -async def hf_inspect_dataset_handler( - arguments: dict[str, Any], session=None -) -> tuple[str, bool]: +async def hf_inspect_dataset_handler(arguments: dict[str, Any]) -> tuple[str, bool]: """Handler for agent tool router""" try: - hf_token = session.hf_token if session else None result = await inspect_dataset( dataset=arguments["dataset"], config=arguments.get("config"), split=arguments.get("split"), sample_rows=min(arguments.get("sample_rows", 3), 10), - hf_token=hf_token, ) return result["formatted"], not result.get("isError", False) except Exception as e: diff --git a/agent/tools/docs_tools.py b/agent/tools/docs_tools.py index ee40ef353ae05b8d32d4c9a17bd0d9eaa8687532..49a330bedfccb47bcfbf2caf4d51aafa2af1babc 100644 --- a/agent/tools/docs_tools.py +++ b/agent/tools/docs_tools.py @@ -4,6 +4,7 @@ Documentation search tools for exploring HuggingFace and Gradio documentation. import asyncio import json +import os from typing import Any import httpx @@ -286,9 +287,7 @@ def _format_results( # --------------------------------------------------------------------------- -async def explore_hf_docs_handler( - arguments: dict[str, Any], session=None -) -> tuple[str, bool]: +async def explore_hf_docs_handler(arguments: dict[str, Any]) -> tuple[str, bool]: """Explore documentation structure with optional search query.""" endpoint = arguments.get("endpoint", "").lstrip("/") query = arguments.get("query") @@ -317,9 +316,9 @@ async def explore_hf_docs_handler( return f"Error fetching Gradio docs: {str(e)}", False # HF docs - hf_token = session.hf_token if session else None + hf_token = os.environ.get("HF_TOKEN") if not hf_token: - return "Error: No HF token available (not logged in)", False + return "Error: HF_TOKEN environment variable not set", False try: max_results_int = int(max_results) if max_results is not None else None @@ -379,17 +378,15 @@ async def explore_hf_docs_handler( return f"Unexpected error: {str(e)}", False -async def hf_docs_fetch_handler( - arguments: dict[str, Any], session=None -) -> tuple[str, bool]: +async def hf_docs_fetch_handler(arguments: dict[str, Any]) -> tuple[str, bool]: """Fetch full markdown content of a documentation page.""" url = arguments.get("url", "") if not url: return "Error: No URL provided", False - hf_token = session.hf_token if session else None + hf_token = os.environ.get("HF_TOKEN") if not hf_token: - return "Error: No HF token available (not logged in)", False + return "Error: HF_TOKEN environment variable not set", False if not url.endswith(".md"): url = f"{url}.md" @@ -457,30 +454,20 @@ def _extract_all_endpoints(spec: dict[str, Any]) -> list[dict[str, Any]]: endpoints = [] for path, path_item in spec.get("paths", {}).items(): for method, op in path_item.items(): - if method not in [ - "get", - "post", - "put", - "delete", - "patch", - "head", - "options", - ]: + if method not in ["get", "post", "put", "delete", "patch", "head", "options"]: continue - endpoints.append( - { - "path": path, - "method": method.upper(), - "operationId": op.get("operationId", ""), - "summary": op.get("summary", ""), - "description": op.get("description", ""), - "tags": " ".join(op.get("tags", [])), - "parameters": op.get("parameters", []), - "request_body": op.get("requestBody", {}), - "responses": op.get("responses", {}), - "base_url": base_url, - } - ) + endpoints.append({ + "path": path, + "method": method.upper(), + "operationId": op.get("operationId", ""), + "summary": op.get("summary", ""), + "description": op.get("description", ""), + "tags": " ".join(op.get("tags", [])), + "parameters": op.get("parameters", []), + "request_body": op.get("requestBody", {}), + "responses": op.get("responses", {}), + "base_url": base_url, + }) return endpoints @@ -524,12 +511,7 @@ async def _build_openapi_index() -> tuple[Any, MultifieldParser, list[dict[str, parser = MultifieldParser( ["summary", "description", "operationId", "tags", "param_names"], schema=schema, - fieldboosts={ - "summary": 3.0, - "operationId": 2.0, - "description": 1.0, - "tags": 1.5, - }, + fieldboosts={"summary": 3.0, "operationId": 2.0, "description": 1.0, "tags": 1.5}, group=OrGroup, ) @@ -550,20 +532,11 @@ async def _search_openapi( return [], "Query contained unsupported syntax." with index.searcher() as searcher: - results = searcher.search( - query_obj, limit=limit * 2 - ) # Get extra for tag filtering + results = searcher.search(query_obj, limit=limit * 2) # Get extra for tag filtering matches = [] for hit in results: # Find full endpoint data - ep = next( - ( - e - for e in endpoints - if e["path"] == hit["path"] and e["method"] == hit["method"] - ), - None, - ) + ep = next((e for e in endpoints if e["path"] == hit["path"] and e["method"] == hit["method"]), None) if ep is None: continue # Filter by tag if provided @@ -740,10 +713,7 @@ async def search_openapi_handler(arguments: dict[str, Any]) -> tuple[str, bool]: query = arguments.get("query", "").strip() or None if not tag and not query: - return ( - "Error: Provide either 'query' (keyword search) or 'tag' (category filter), or both.", - False, - ) + return "Error: Provide either 'query' (keyword search) or 'tag' (category filter), or both.", False try: note = None @@ -754,9 +724,7 @@ async def search_openapi_handler(arguments: dict[str, Any]) -> tuple[str, bool]: # If Whoosh found results, return them if results: - return _format_openapi_results( - results, tag=tag, query=query, note=search_note - ), True + return _format_openapi_results(results, tag=tag, query=query, note=search_note), True # Whoosh found nothing - fall back to tag-based if tag provided if tag: @@ -769,9 +737,7 @@ async def search_openapi_handler(arguments: dict[str, Any]) -> tuple[str, bool]: if tag: _, _, endpoints = await _build_openapi_index() results = [ep for ep in endpoints if tag in ep.get("tags", "")] - return _format_openapi_results( - results, tag=tag, query=None, note=note - ), True + return _format_openapi_results(results, tag=tag, query=None, note=note), True return "Error: No results found", False @@ -879,12 +845,17 @@ DOC_ENDPOINTS = [ EXPLORE_HF_DOCS_TOOL_SPEC = { "name": "explore_hf_docs", "description": ( - "Browse HF documentation structure β€” discover all available documentation with 200-char previews.\n\n" - "Use this to find relevant documentation and/or examples with detailed parameter docs and API reference. " - "To be used together with github_find_examples and github_read_file to find working examples and documentation.\n\n" - "Pattern: explore_hf_docs (find relevant pages) β†’ fetch_hf_docs (get full content).\n\n" - "For training tasks: fetch the trainer config docs (SFTConfig, DPOConfig, GRPOConfig) to verify parameter names. " - "Returns top 20 results by default; set max_results (max 50) to adjust." + "Explore Hugging Face documentation structure and discover available pages with 200-character previews. " + "⚠️ MANDATORY: ALWAYS use this BEFORE implementing any ML task (training, fine-tuning, data processing, inference). " + "Your training data may be outdated - current documentation is the source of truth. " + "**Use when:** (1) Starting any implementation task, (2) User asks 'how to' questions, " + "(3) Before writing training/processing code, (4) Researching library capabilities, " + "(5) Verifying API syntax and parameters. " + "**Pattern:** explore (discover structure) β†’ fetch_hf_docs (get details) β†’ implement with researched approach. " + "Returns: Sidebar navigation with titles, URLs, and glimpses of all pages in the selected documentation. " + "**Then:** Use fetch_hf_docs with specific URLs from results to get full content. " + "**Critical for reliability:** Never implement based on internal knowledge without checking current docs first - APIs change frequently." + " By default returns the top 20 results; set max_results (max 50) to adjust." ), "parameters": { "type": "object", @@ -932,7 +903,7 @@ EXPLORE_HF_DOCS_TOOL_SPEC = { "β€’ argilla β€” Data annotation, feedback, and human-in-the-loop workflows.\n" "β€’ distilabel β€” Synthetic data generation and distillation pipelines.\n" "β€’ microsoft-azure β€” Azure deployment and integration guides.\n" - "β€’ kernels β€” Load prebuilt compute kernels (E.g. flash-attn2) from the Hub via `attn_implementation`; avoids compiling flash-attn from source.\n" + "β€’ kernels β€” Lightweight execution environments and notebook-style workflows.\n" "β€’ google-cloud β€” GCP deployment and serving workflows.\n" ), }, @@ -957,10 +928,16 @@ EXPLORE_HF_DOCS_TOOL_SPEC = { HF_DOCS_FETCH_TOOL_SPEC = { "name": "fetch_hf_docs", "description": ( - "Fetch full markdown content of an HF documentation page. Use after explore_hf_docs.\n\n" - "Critical for finding documentation e.g. current trainer configuration parameters (SFTConfig, DPOConfig, etc.) " - "Use for researching solutions and before writing training scripts. Your internal knowledge is outdated.\n\n" - "Provide the full URL from explore_hf_docs results. The .md extension is added automatically." + "Fetch full markdown content of a specific HF documentation page. " + "⚠️ CRITICAL: Use this after explore_hf_docs to get detailed implementation guidance. " + "**Use when:** (1) Found relevant page in explore_hf_docs results, (2) Need complete API documentation, " + "(3) Need training method details (SFT/DPO/GRPO), (4) Need configuration examples, " + "(5) Need parameter descriptions and usage patterns. " + "**Pattern:** explore_hf_docs (find relevant page) β†’ fetch_hf_docs (get full content) β†’ implement using documented approach. " + "Provide full URL from explore_hf_docs results (e.g., 'https://huggingface.co/docs/trl/sft_trainer'). " + "Returns: Complete markdown documentation with examples, parameters, and usage patterns. " + "**For training tasks:** ALWAYS fetch trainer docs (SFTConfig, DPOConfig, etc.) before creating training scripts. " + "**Critical for reliability:** This ensures you use current APIs and best practices." ), "parameters": { "type": "object", diff --git a/agent/tools/edit_utils.py b/agent/tools/edit_utils.py deleted file mode 100644 index 1c6b958192ad8a90c9b3268f6fdb688787d97ea6..0000000000000000000000000000000000000000 --- a/agent/tools/edit_utils.py +++ /dev/null @@ -1,273 +0,0 @@ -""" -Shared utilities for file editing tools β€” fuzzy matching, syntax validation, -and richer edit operations. - -Used by both local_tools.py and the embedded sandbox server. -""" - -from __future__ import annotations - -# ── Unicode normalization map ──────────────────────────────────────────── - -UNICODE_MAP = { - "\u2013": "-", # en-dash - "\u2014": "-", # em-dash - "\u2212": "-", # minus sign - "\u2018": "'", # left single quote - "\u2019": "'", # right single quote - "\u201c": '"', # left double quote - "\u201d": '"', # right double quote - "\u00a0": " ", # non-breaking space - "\u2003": " ", # em space - "\u2002": " ", # en space - "\u200b": "", # zero-width space - "\ufeff": "", # BOM -} - - -def _normalize_unicode(s: str) -> str: - return "".join(UNICODE_MAP.get(c, c) for c in s) - - -# ── 4-pass fuzzy matching ──────────────────────────────────────────────── - - -def fuzzy_find(content: str, pattern: str) -> tuple[int | None, str | None]: - """Find *pattern* in *content* with increasingly relaxed matching. - - Returns (start_index_in_original_content, match_note) or (None, None). - The index always refers to the *original* content string so callers can - use ``content[idx : idx + len(matched_text)]`` for replacement. - - Strategy (mirrors Codex): - 1. Exact match - 2. Right-trim each line (trailing whitespace) - 3. Both-sides trim (all surrounding whitespace per line) - 4. Unicode normalization on top of both-sides trim - """ - # Pass 1 β€” exact - if pattern in content: - return content.index(pattern), None - - # Helper: build a line-stripped version *and* a mapping from stripped - # positions back to original positions. We need this so callers can - # apply the replacement on the original content, not the stripped copy. - - def _build_stripped(text: str, strip_fn): - """Return (stripped_text, line_start_map). - - line_start_map[i] = original byte offset of the start of line i. - """ - orig_lines = text.split("\n") - stripped_lines = [strip_fn(line) for line in orig_lines] - return "\n".join(stripped_lines), orig_lines, stripped_lines - - # Pass 2 β€” right-trim - c_rt, c_orig_lines, c_rt_lines = _build_stripped(content, str.rstrip) - p_rt = "\n".join(line.rstrip() for line in pattern.split("\n")) - idx = c_rt.find(p_rt) - if idx != -1: - orig_idx = _map_back(idx, c_orig_lines, c_rt_lines) - return orig_idx, "(matched after trimming trailing whitespace)" - - # Pass 3 β€” both-sides trim - c_st, _, c_st_lines = _build_stripped(content, str.strip) - p_st = "\n".join(line.strip() for line in pattern.split("\n")) - idx = c_st.find(p_st) - if idx != -1: - orig_idx = _map_back(idx, c_orig_lines, c_st_lines) - return orig_idx, "(matched after trimming whitespace)" - - # Pass 4 β€” unicode normalization + both-sides trim - c_norm = _normalize_unicode(c_st) - p_norm = _normalize_unicode(p_st) - idx = c_norm.find(p_norm) - if idx != -1: - orig_idx = _map_back(idx, c_orig_lines, c_st_lines) - return orig_idx, "(matched after unicode normalization)" - - return None, None - - -def _map_back( - stripped_idx: int, - orig_lines: list[str], - stripped_lines: list[str], -) -> int: - """Map a character index in the stripped/joined text back to the original text.""" - # Walk through stripped lines to find which line the index falls on - pos = 0 - for i, sl in enumerate(stripped_lines): - line_end = pos + len(sl) - if stripped_idx <= line_end: - col_in_stripped = stripped_idx - pos - # Find where this stripped line's content starts in the original line - ol = orig_lines[i] - # The stripped line is a subset of the original line; find its offset - lstripped = len(ol) - len(ol.lstrip()) - orig_col = lstripped + col_in_stripped - # Compute absolute position in original text - orig_pos = sum(len(orig_lines[j]) + 1 for j in range(i)) + orig_col - return orig_pos - pos = line_end + 1 # +1 for the \n - # Fallback: return 0 (shouldn't happen if idx is valid) - return 0 - - -def fuzzy_find_original_match( - content: str, pattern: str -) -> tuple[str | None, str | None]: - """Find the *original* text in content that matches pattern fuzzily. - - Returns (original_matched_text, match_note) or (None, None). - This extracts the exact substring from the original content that - corresponds to the fuzzy match, preserving its original whitespace/unicode. - """ - if pattern in content: - return pattern, None - - idx, note = fuzzy_find(content, pattern) - if idx is None: - return None, None - - # We need to find the original text span that corresponds to the match. - # The match covers len(pattern) worth of *logical* content. - # Count how many original lines the pattern spans. - pattern_lines = pattern.split("\n") - n_lines = len(pattern_lines) - - # Find which original line the match starts on - orig_lines = content.split("\n") - char_pos = 0 - start_line = 0 - for i, ol in enumerate(orig_lines): - if char_pos + len(ol) >= idx: - start_line = i - break - char_pos += len(ol) + 1 - - end_line = min(start_line + n_lines, len(orig_lines)) - # Extract the original lines that were matched - matched_lines = orig_lines[start_line:end_line] - original_text = "\n".join(matched_lines) - return original_text, note - - -# ── Richer edit operations ─────────────────────────────────────────────── - - -def apply_edit( - content: str, - old_str: str, - new_str: str, - mode: str = "replace", - replace_all: bool = False, -) -> tuple[str, int, str | None]: - """Apply an edit operation to content. - - Modes: - - replace: replace first occurrence (or all if replace_all=True) - - replace_all: replace all occurrences (alias) - - append_after: insert new_str after old_str - - prepend_before: insert new_str before old_str - - Returns (new_content, num_replacements, fuzzy_note). - Raises ValueError if old_str not found. - """ - if mode == "replace_all": - replace_all = True - mode = "replace" - - # Try exact match first, then fuzzy - fuzzy_note = None - if old_str not in content: - original_match, fuzzy_note = fuzzy_find_original_match(content, old_str) - if original_match is None: - raise ValueError( - "old_str was not found in the file. Make sure old_str matches " - "the file contents exactly, including whitespace and indentation. " - "Use the read tool to verify the current file contents before retrying." - ) - old_str = original_match - - count = content.count(old_str) - - if mode == "replace": - if count > 1 and not replace_all: - raise ValueError( - f"Found {count} matches of old_str in the file, but replace_all is " - f"false. To replace all occurrences, set replace_all to true. To " - f"replace only one, provide a larger old_str with more surrounding " - f"context to uniquely identify the instance." - ) - if replace_all: - new_content = content.replace(old_str, new_str) - return new_content, count, fuzzy_note - else: - new_content = content.replace(old_str, new_str, 1) - return new_content, 1, fuzzy_note - - elif mode == "append_after": - if replace_all: - new_content = content.replace(old_str, old_str + new_str) - return new_content, count, fuzzy_note - else: - idx = content.index(old_str) + len(old_str) - new_content = content[:idx] + new_str + content[idx:] - return new_content, 1, fuzzy_note - - elif mode == "prepend_before": - if replace_all: - new_content = content.replace(old_str, new_str + old_str) - return new_content, count, fuzzy_note - else: - idx = content.index(old_str) - new_content = content[:idx] + new_str + content[idx:] - return new_content, 1, fuzzy_note - - else: - raise ValueError( - f"Unknown edit mode: {mode}. Use replace, append_after, or prepend_before." - ) - - -# ── Syntax validation (Python) ─────────────────────────────────────────── - - -def validate_python(content: str, path: str = "") -> list[str]: - """Lightweight post-write validation for Python files. - - Checks syntax and training script conventions. This runs on the host - (not in the sandbox), so it only does static checks β€” no import resolution - or signature inspection since packages are installed in the sandbox, not here. - - The sandbox server has its own richer version that does real signature - inspection against installed packages. - - Returns a list of warning strings (empty = all good). - Never raises β€” validation failures are advisory only. - """ - import ast - - warnings = [] - - # 1. Syntax check via ast.parse - try: - ast.parse(content) - except SyntaxError as e: - warnings.append(f"Python syntax error at line {e.lineno}: {e.msg}") - return warnings - - # 2. Training script heuristics - if any( - kw in content - for kw in ("TrainingArguments", "SFTConfig", "DPOConfig", "GRPOConfig") - ): - if "push_to_hub" not in content: - warnings.append( - "Training script warning: no 'push_to_hub' found β€” model may be lost when job ends" - ) - if "hub_model_id" not in content: - warnings.append("Training script warning: no 'hub_model_id' found") - - return warnings diff --git a/agent/tools/github_find_examples.py b/agent/tools/github_find_examples.py index f5f2ddaad0a1959ec3418cc45ed88432a40e13c2..c0d795d93363a93f8f4f3e316f71f988017b98c4 100644 --- a/agent/tools/github_find_examples.py +++ b/agent/tools/github_find_examples.py @@ -405,16 +405,55 @@ def find_examples( GITHUB_FIND_EXAMPLES_TOOL_SPEC = { "name": "github_find_examples", "description": ( - "Find working example scripts in GitHub repositories (from a list of predetermined directories e.g. examples/, scripts/, tutorials/, etc.). " - "Uses fuzzy keyword matching.\n\n" - "MANDATORY before writing any ML training, fine-tuning, or inference code. " - "Your internal knowledge of library APIs is outdated β€” working examples show current API patterns.\n\n" - "Sequence: github_find_examples β†’ github_read_file (study the example) β†’ implement based on what you found.\n\n" - "Skip this only for: simple data queries, status checks, non-code tasks.\n\n" - "Examples:\n" - " {keyword: 'sft', repo: 'trl'} β†’ finds examples/scripts/sft.py\n" - " {keyword: 'grpo', repo: 'trl'} β†’ finds GRPO training examples\n" - " {repo: 'trl', max_results: 20} β†’ lists all available training method examples" + "Discover working code examples, tutorials, scripts, and demos in GitHub repositories. " + "⚠️ CRITICAL: ALWAYS use this BEFORE implementing ML tasks - find working reference code first. " + "Your training data may be outdated; real repository examples show current best practices. " + "**Use when:** (1) Starting any ML implementation (training, inference, evaluation), " + "(2) User asks 'how to' questions about libraries, (3) Need reference implementations, " + "(4) Exploring library capabilities, (5) Before writing training/processing scripts. " + "**Pattern:** github_find_examples (discover) β†’ github_read_file (study code) β†’ implement with researched approach. " + "Returns: List of example files (scripts/notebooks/tutorials) with paths and URLs, sorted by relevance. " + "**Then:** Use github_read_file to read the actual implementation code. " + "**Critical for reliability:** Real examples prevent outdated API usage and show proven patterns. " + "## How it works\n\n" + "1. Fetches all example files (examples/, scripts/, tutorials/, demos/, notebooks/, etc.) from repository\n" + "2. If keyword provided, scores files against keyword using fuzzy matching\n" + "3. Returns best matches sorted by relevance and pattern priority\n" + "4. Provides copyable parameters for github_read_file tool\n\n" + "## Examples\n\n" + "\n" + "// ML Workflow Step: Find GRPO training examples before implementation\n" + "// Task: Starting GRPO fine-tuning project, need reference implementation\n" + "{\n" + " keyword: 'grpo',\n" + " repo: 'trl',\n" + " org: 'huggingface'\n" + "}\n" + "// Returns: examples/scripts/grpo_agent.py, examples/scripts/grpo_vlm.py\n" + "// Next step: github_read_file to study working implementation\n" + "\n\n" + "\n" + "// ML Workflow Step: Discover all available training methods\n" + "// Task: Exploring TRL training options before choosing approach\n" + "{\n" + " repo: 'trl',\n" + " org: 'huggingface',\n" + " max_results: 20\n" + "}\n" + "// Lists: SFT, DPO, GRPO, PPO, reward modeling examples\n" + "// Helps user choose appropriate method\n" + "\n\n" + "\n" + "// ML Workflow Step: Find LoRA fine-tuning examples\n" + "// Task: Learning parameter-efficient fine-tuning patterns\n" + "{\n" + " keyword: 'lora',\n" + " repo: 'peft',\n" + " org: 'huggingface'\n" + "}\n" + "// Discovers LoRA configuration and training examples\n" + "// Shows current PEFT API usage patterns\n" + "" ), "parameters": { "type": "object", diff --git a/agent/tools/github_read_file.py b/agent/tools/github_read_file.py index 485fe277972f8ebf6c52ff62cc488ed2b4e97d9b..02bccef05d53120670f95dd7556e40811fad9db0 100644 --- a/agent/tools/github_read_file.py +++ b/agent/tools/github_read_file.py @@ -250,13 +250,59 @@ def read_file( GITHUB_READ_FILE_TOOL_SPEC = { "name": "github_read_file", "description": ( - "Read file contents from GitHub repositories. Returns first 300 lines by default. " - "Auto-converts Jupyter notebooks to markdown.\n\n" - "Use AFTER github_find_examples to study the working implementation. " - "The purpose is to learn current API patterns β€” imports, trainer configs, dataset handling β€” " - "so your implementation uses correct, up-to-date code.\n\n" + "Read file contents from GitHub repositories with line range support (default 300 lines). " + "⚠️ CRITICAL: Use AFTER github_find_examples to study working implementation code. " + "**Use when:** (1) Found example file via github_find_examples and need full code, " + "(2) Need to read trainer class implementation, (3) Study configuration patterns, " + "(4) Read specific code sections with line ranges, (5) Review code from specific branches/commits. " + "**Pattern:** github_find_examples (discover files) β†’ github_read_file (read code) β†’ implement using researched patterns. " + "Returns: File contents with line numbers, formatted for LLM reading. Auto-converts Jupyter notebooks to markdown. " + "**Then:** Implement using patterns and APIs from the example code. " + "**Critical for reliability:** Reading working examples prevents API errors and shows current best practices. " "Use line_start/line_end for large files (>300 lines) to read specific sections.\n\n" - "When NOT to use: when you don't know the file path (use github_find_examples first)." + "## When to use this tool\n\n" + "- When reading example code, trainer implementations, or configuration files\n" + "- After github_find_examples returns file paths you want to study\n" + "- When investigating specific code sections with line ranges\n" + "- When reading from specific branches, tags, or commits (use ref parameter)\n\n" + "## When NOT to use this tool\n\n" + "- When you don't know exact file path (use github_find_examples or github_search_code first)\n" + "- When searching for code patterns across repos (use github_search_code instead)\n\n" + "## Examples\n\n" + "\n" + "// ML Workflow Step: Read GRPO trainer class after finding via github_find_examples\n" + "// Use case: Understand GRPOTrainer API, parameters, and methods\n" + "{\n" + " repo: 'huggingface/trl',\n" + " path: 'trl/trainer/grpo_trainer.py',\n" + " line_start: 1,\n" + " line_end: 200\n" + "}\n" + "// Read class definition and constructor to understand current API\n" + "// Shows: __init__ parameters, configuration, required arguments\n" + "\n\n" + "\n" + "// ML Workflow Step: Study complete training script from examples\n" + "// Use case: Learn end-to-end VLM fine-tuning workflow\n" + "{\n" + " repo: 'huggingface/trl',\n" + " path: 'examples/scripts/grpo_vlm.py'\n" + "}\n" + "// Returns first 300 lines - shows full training setup\n" + "// Use line_start/line_end if need to read more\n" + "\n\n" + "\n" + "// ML Workflow Step: Check TrainingArguments configuration patterns\n" + "// Use case: Learn how to structure training configs correctly\n" + "{\n" + " repo: 'huggingface/transformers',\n" + " path: 'examples/pytorch/language-modeling/run_clm.py',\n" + " line_start: 50,\n" + " line_end: 150\n" + "}\n" + "// Read argument parsing and config setup section\n" + "// Shows: current parameter names, default values, best practices\n" + "" ), "parameters": { "type": "object", diff --git a/agent/tools/hf_repo_files_tool.py b/agent/tools/hf_repo_files_tool.py index aee00b741662838769d25711602b5afefcb623e8..69dd228bdd3f9b16af8eaedbd3b297eecfdd5714 100644 --- a/agent/tools/hf_repo_files_tool.py +++ b/agent/tools/hf_repo_files_tool.py @@ -10,7 +10,6 @@ from typing import Any, Dict, Literal, Optional from huggingface_hub import HfApi, hf_hub_download from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError -from agent.core.hub_artifacts import is_known_hub_artifact, register_hub_artifact from agent.tools.types import ToolResult OperationType = Literal["list", "read", "upload", "delete"] @@ -40,9 +39,8 @@ def _format_size(size_bytes: int) -> str: class HfRepoFilesTool: """Tool for file operations on HF repos.""" - def __init__(self, hf_token: Optional[str] = None, session: Any = None): + def __init__(self, hf_token: Optional[str] = None): self.api = HfApi(token=hf_token) - self.session = session async def execute(self, args: Dict[str, Any]) -> ToolResult: """Execute the specified operation.""" @@ -63,9 +61,7 @@ class HfRepoFilesTool: if handler: return await handler(args) else: - return self._error( - f"Unknown operation: {operation}. Valid: list, read, upload, delete" - ) + return self._error(f"Unknown operation: {operation}. Valid: list, read, upload, delete") except RepositoryNotFoundError: return self._error(f"Repository not found: {args.get('repo_id')}") @@ -100,23 +96,17 @@ class HfRepoFilesTool: revision = args.get("revision", "main") path = args.get("path", "") - items = list( - await _async_call( - self.api.list_repo_tree, - repo_id=repo_id, - repo_type=repo_type, - revision=revision, - path_in_repo=path, - recursive=True, - ) - ) + items = list(await _async_call( + self.api.list_repo_tree, + repo_id=repo_id, + repo_type=repo_type, + revision=revision, + path_in_repo=path, + recursive=True, + )) if not items: - return { - "formatted": f"No files in {repo_id}", - "totalResults": 0, - "resultsShared": 0, - } + return {"formatted": f"No files in {repo_id}", "totalResults": 0, "resultsShared": 0} lines = [] total_size = 0 @@ -128,16 +118,9 @@ class HfRepoFilesTool: lines.append(f"{item.path}/") url = _build_repo_url(repo_id, repo_type) - response = ( - f"**{repo_id}** ({len(items)} files, {_format_size(total_size)})\n{url}/tree/{revision}\n\n" - + "\n".join(lines) - ) + response = f"**{repo_id}** ({len(items)} files, {_format_size(total_size)})\n{url}/tree/{revision}\n\n" + "\n".join(lines) - return { - "formatted": response, - "totalResults": len(items), - "resultsShared": len(items), - } + return {"formatted": response, "totalResults": len(items), "resultsShared": len(items)} async def _read(self, args: Dict[str, Any]) -> ToolResult: """Read file content from a repository.""" @@ -177,13 +160,8 @@ class HfRepoFilesTool: except UnicodeDecodeError: import os - size = os.path.getsize(file_path) - return { - "formatted": f"Binary file ({_format_size(size)})", - "totalResults": 1, - "resultsShared": 1, - } + return {"formatted": f"Binary file ({_format_size(size)})", "totalResults": 1, "resultsShared": 1} async def _upload(self, args: Dict[str, Any]) -> ToolResult: """Upload content to a repository.""" @@ -216,16 +194,6 @@ class HfRepoFilesTool: create_pr=create_pr, ) - if not create_pr and is_known_hub_artifact(self.session, repo_id, repo_type): - await _async_call( - register_hub_artifact, - self.api, - repo_id, - repo_type, - session=self.session, - force=path == "README.md", - ) - url = _build_repo_url(repo_id, repo_type) if create_pr and hasattr(result, "pr_url"): response = f"**Uploaded as PR**\n{result.pr_url}" @@ -267,12 +235,7 @@ class HfRepoFilesTool: def _error(self, message: str) -> ToolResult: """Return an error result.""" - return { - "formatted": message, - "totalResults": 0, - "resultsShared": 0, - "isError": True, - } + return {"formatted": message, "totalResults": 0, "resultsShared": 0, "isError": True} # Tool specification @@ -349,13 +312,10 @@ HF_REPO_FILES_TOOL_SPEC = { } -async def hf_repo_files_handler( - arguments: Dict[str, Any], session=None -) -> tuple[str, bool]: +async def hf_repo_files_handler(arguments: Dict[str, Any]) -> tuple[str, bool]: """Handler for agent tool router.""" try: - hf_token = session.hf_token if session else None - tool = HfRepoFilesTool(hf_token=hf_token, session=session) + tool = HfRepoFilesTool() result = await tool.execute(arguments) return result["formatted"], not result.get("isError", False) except Exception as e: diff --git a/agent/tools/hf_repo_git_tool.py b/agent/tools/hf_repo_git_tool.py index cfff4120b089aa7923c2a46c5c3da22cf201457f..a2b4063501c1b971e2a40a1414eb7c323ea5dbe3 100644 --- a/agent/tools/hf_repo_git_tool.py +++ b/agent/tools/hf_repo_git_tool.py @@ -10,24 +10,14 @@ from typing import Any, Dict, Literal, Optional from huggingface_hub import HfApi from huggingface_hub.utils import RepositoryNotFoundError -from agent.core.hub_artifacts import register_hub_artifact from agent.tools.types import ToolResult OperationType = Literal[ - "create_branch", - "delete_branch", - "create_tag", - "delete_tag", + "create_branch", "delete_branch", + "create_tag", "delete_tag", "list_refs", - "create_pr", - "list_prs", - "get_pr", - "merge_pr", - "close_pr", - "comment_pr", - "change_pr_status", - "create_repo", - "update_repo", + "create_pr", "list_prs", "get_pr", "merge_pr", "close_pr", "comment_pr", "change_pr_status", + "create_repo", "update_repo", ] @@ -46,9 +36,8 @@ def _build_repo_url(repo_id: str, repo_type: str = "model") -> str: class HfRepoGitTool: """Tool for git-like operations on HF repos.""" - def __init__(self, hf_token: Optional[str] = None, session: Any = None): + def __init__(self, hf_token: Optional[str] = None): self.api = HfApi(token=hf_token) - self.session = session async def execute(self, args: Dict[str, Any]) -> ToolResult: """Execute the specified operation.""" @@ -142,11 +131,7 @@ class HfRepoGitTool: ) url = f"{_build_repo_url(repo_id, repo_type)}/tree/{branch}" - return { - "formatted": f"**Branch created:** {branch}\n{url}", - "totalResults": 1, - "resultsShared": 1, - } + return {"formatted": f"**Branch created:** {branch}\n{url}", "totalResults": 1, "resultsShared": 1} async def _delete_branch(self, args: Dict[str, Any]) -> ToolResult: """Delete a branch.""" @@ -167,11 +152,7 @@ class HfRepoGitTool: repo_type=repo_type, ) - return { - "formatted": f"**Branch deleted:** {branch}", - "totalResults": 1, - "resultsShared": 1, - } + return {"formatted": f"**Branch deleted:** {branch}", "totalResults": 1, "resultsShared": 1} # ========================================================================= # TAG OPERATIONS @@ -202,11 +183,7 @@ class HfRepoGitTool: ) url = f"{_build_repo_url(repo_id, repo_type)}/tree/{tag}" - return { - "formatted": f"**Tag created:** {tag}\n{url}", - "totalResults": 1, - "resultsShared": 1, - } + return {"formatted": f"**Tag created:** {tag}\n{url}", "totalResults": 1, "resultsShared": 1} async def _delete_tag(self, args: Dict[str, Any]) -> ToolResult: """Delete a tag.""" @@ -227,11 +204,7 @@ class HfRepoGitTool: repo_type=repo_type, ) - return { - "formatted": f"**Tag deleted:** {tag}", - "totalResults": 1, - "resultsShared": 1, - } + return {"formatted": f"**Tag deleted:** {tag}", "totalResults": 1, "resultsShared": 1} # ========================================================================= # LIST REFS @@ -253,9 +226,7 @@ class HfRepoGitTool: ) branches = [b.name for b in refs.branches] if refs.branches else [] - tags = ( - [t.name for t in refs.tags] if hasattr(refs, "tags") and refs.tags else [] - ) + tags = [t.name for t in refs.tags] if hasattr(refs, 'tags') and refs.tags else [] url = _build_repo_url(repo_id, repo_type) lines = [f"**{repo_id}**", url, ""] @@ -270,11 +241,7 @@ class HfRepoGitTool: else: lines.append("**Tags:** none") - return { - "formatted": "\n".join(lines), - "totalResults": len(branches) + len(tags), - "resultsShared": len(branches) + len(tags), - } + return {"formatted": "\n".join(lines), "totalResults": len(branches) + len(tags), "resultsShared": len(branches) + len(tags)} # ========================================================================= # PR OPERATIONS @@ -303,7 +270,7 @@ class HfRepoGitTool: url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{result.num}" return { - "formatted": f'**Draft PR #{result.num} created:** {title}\n{url}\n\nAdd commits via upload with revision="refs/pr/{result.num}"', + "formatted": f"**Draft PR #{result.num} created:** {title}\n{url}\n\nAdd commits via upload with revision=\"refs/pr/{result.num}\"", "totalResults": 1, "resultsShared": 1, } @@ -318,27 +285,17 @@ class HfRepoGitTool: repo_type = args.get("repo_type", "model") status = args.get("status", "all") # open, closed, all - discussions = list( - self.api.get_repo_discussions( - repo_id=repo_id, - repo_type=repo_type, - discussion_status=status if status != "all" else None, - ) - ) + discussions = list(self.api.get_repo_discussions( + repo_id=repo_id, + repo_type=repo_type, + discussion_status=status if status != "all" else None, + )) if not discussions: - return { - "formatted": f"No discussions in {repo_id}", - "totalResults": 0, - "resultsShared": 0, - } + return {"formatted": f"No discussions in {repo_id}", "totalResults": 0, "resultsShared": 0} url = _build_repo_url(repo_id, repo_type) - lines = [ - f"**{repo_id}** - {len(discussions)} discussions", - f"{url}/discussions", - "", - ] + lines = [f"**{repo_id}** - {len(discussions)} discussions", f"{url}/discussions", ""] for d in discussions[:20]: if d.status == "draft": @@ -352,11 +309,7 @@ class HfRepoGitTool: type_label = "PR" if d.is_pull_request else "D" lines.append(f"{status_label} #{d.num} [{type_label}] {d.title}") - return { - "formatted": "\n".join(lines), - "totalResults": len(discussions), - "resultsShared": min(20, len(discussions)), - } + return {"formatted": "\n".join(lines), "totalResults": len(discussions), "resultsShared": min(20, len(discussions))} async def _get_pr(self, args: Dict[str, Any]) -> ToolResult: """Get PR details.""" @@ -382,7 +335,7 @@ class HfRepoGitTool: "draft": "Draft", "open": "Open", "merged": "Merged", - "closed": "Closed", + "closed": "Closed" } status = status_map.get(pr.status, pr.status.capitalize()) type_label = "Pull Request" if pr.is_pull_request else "Discussion" @@ -396,13 +349,9 @@ class HfRepoGitTool: if pr.is_pull_request: if pr.status == "draft": - lines.append( - f'\nTo add commits: upload with revision="refs/pr/{pr_num}"' - ) + lines.append(f"\nTo add commits: upload with revision=\"refs/pr/{pr_num}\"") elif pr.status == "open": - lines.append( - f'\nTo add commits: upload with revision="refs/pr/{pr_num}"' - ) + lines.append(f"\nTo add commits: upload with revision=\"refs/pr/{pr_num}\"") return {"formatted": "\n".join(lines), "totalResults": 1, "resultsShared": 1} @@ -428,11 +377,7 @@ class HfRepoGitTool: ) url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}" - return { - "formatted": f"**PR #{pr_num} merged**\n{url}", - "totalResults": 1, - "resultsShared": 1, - } + return {"formatted": f"**PR #{pr_num} merged**\n{url}", "totalResults": 1, "resultsShared": 1} async def _close_pr(self, args: Dict[str, Any]) -> ToolResult: """Close a PR/discussion.""" @@ -456,11 +401,7 @@ class HfRepoGitTool: repo_type=repo_type, ) - return { - "formatted": f"**Discussion #{pr_num} closed**", - "totalResults": 1, - "resultsShared": 1, - } + return {"formatted": f"**Discussion #{pr_num} closed**", "totalResults": 1, "resultsShared": 1} async def _comment_pr(self, args: Dict[str, Any]) -> ToolResult: """Add a comment to a PR/discussion.""" @@ -486,11 +427,7 @@ class HfRepoGitTool: ) url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}" - return { - "formatted": f"**Comment added to #{pr_num}**\n{url}", - "totalResults": 1, - "resultsShared": 1, - } + return {"formatted": f"**Comment added to #{pr_num}**\n{url}", "totalResults": 1, "resultsShared": 1} async def _change_pr_status(self, args: Dict[str, Any]) -> ToolResult: """Change PR/discussion status (mainly to convert draft to open).""" @@ -518,11 +455,7 @@ class HfRepoGitTool: ) url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}" - return { - "formatted": f"**PR #{pr_num} status changed to {new_status}**\n{url}", - "totalResults": 1, - "resultsShared": 1, - } + return {"formatted": f"**PR #{pr_num} status changed to {new_status}**\n{url}", "totalResults": 1, "resultsShared": 1} # ========================================================================= # REPO MANAGEMENT @@ -540,9 +473,7 @@ class HfRepoGitTool: space_sdk = args.get("space_sdk") if repo_type == "space" and not space_sdk: - return self._error( - "space_sdk required for spaces (gradio/streamlit/docker/static)" - ) + return self._error("space_sdk required for spaces (gradio/streamlit/docker/static)") kwargs = { "repo_id": repo_id, @@ -554,17 +485,6 @@ class HfRepoGitTool: kwargs["space_sdk"] = space_sdk result = await _async_call(self.api.create_repo, **kwargs) - extra_metadata = None - if repo_type == "space" and space_sdk: - extra_metadata = {"sdk": space_sdk} - await _async_call( - register_hub_artifact, - self.api, - repo_id, - repo_type, - session=self.session, - extra_metadata=extra_metadata, - ) return { "formatted": f"**Repository created:** {repo_id}\n**Private:** {private}\n{result}", @@ -584,9 +504,7 @@ class HfRepoGitTool: gated = args.get("gated") if private is None and gated is None: - return self._error( - "Specify private (bool) or gated ('auto'/'manual'/false)" - ) + return self._error("Specify private (bool) or gated ('auto'/'manual'/false)") kwargs = {"repo_id": repo_id, "repo_type": repo_type} if private is not None: @@ -603,20 +521,11 @@ class HfRepoGitTool: changes.append(f"gated={gated}") url = f"{_build_repo_url(repo_id, repo_type)}/settings" - return { - "formatted": f"**Settings updated:** {', '.join(changes)}\n{url}", - "totalResults": 1, - "resultsShared": 1, - } + return {"formatted": f"**Settings updated:** {', '.join(changes)}\n{url}", "totalResults": 1, "resultsShared": 1} def _error(self, message: str) -> ToolResult: """Return an error result.""" - return { - "formatted": message, - "totalResults": 0, - "resultsShared": 0, - "isError": True, - } + return {"formatted": message, "totalResults": 0, "resultsShared": 0, "isError": True} # Tool specification @@ -662,20 +571,10 @@ HF_REPO_GIT_TOOL_SPEC = { "operation": { "type": "string", "enum": [ - "create_branch", - "delete_branch", - "create_tag", - "delete_tag", - "list_refs", - "create_pr", - "list_prs", - "get_pr", - "merge_pr", - "close_pr", - "comment_pr", - "change_pr_status", - "create_repo", - "update_repo", + "create_branch", "delete_branch", + "create_tag", "delete_tag", "list_refs", + "create_pr", "list_prs", "get_pr", "merge_pr", "close_pr", "comment_pr", "change_pr_status", + "create_repo", "update_repo", ], "description": "Operation to execute", }, @@ -754,13 +653,10 @@ HF_REPO_GIT_TOOL_SPEC = { } -async def hf_repo_git_handler( - arguments: Dict[str, Any], session=None -) -> tuple[str, bool]: +async def hf_repo_git_handler(arguments: Dict[str, Any]) -> tuple[str, bool]: """Handler for agent tool router.""" try: - hf_token = session.hf_token if session else None - tool = HfRepoGitTool(hf_token=hf_token, session=session) + tool = HfRepoGitTool() result = await tool.execute(arguments) return result["formatted"], not result.get("isError", False) except Exception as e: diff --git a/agent/tools/jobs_tool.py b/agent/tools/jobs_tool.py index 29d6b3017ab3c5641b01d3573ed560d98c103c6b..18e7705cc79a6c818fb0b6ff7cfa44f871b4c4e2 100644 --- a/agent/tools/jobs_tool.py +++ b/agent/tools/jobs_tool.py @@ -7,24 +7,20 @@ Refactored to use official huggingface-hub library instead of custom HTTP client import asyncio import base64 import http.client -import logging +import os import re -import shlex -from typing import Any, Awaitable, Callable, Dict, Literal, Optional +from typing import Any, Dict, Literal, Optional, Callable, Awaitable + +import logging import httpx from huggingface_hub import HfApi from huggingface_hub.utils import HfHubHTTPError -from agent.core.hf_access import ( - JobsAccessError, - is_billing_error, - resolve_jobs_namespace, -) -from agent.core.hub_artifacts import build_hub_artifact_sitecustomize from agent.core.session import Event -from agent.tools.trackio_seed import ensure_trackio_dashboard from agent.tools.types import ToolResult + +logger = logging.getLogger(__name__) from agent.tools.utilities import ( format_job_details, format_jobs_table, @@ -32,36 +28,39 @@ from agent.tools.utilities import ( format_scheduled_jobs_table, ) -logger = logging.getLogger(__name__) - # Hardware flavors -CPU_FLAVORS = ["cpu-basic", "cpu-upgrade"] +CPU_FLAVORS = ["cpu-basic", "cpu-upgrade", "cpu-performance", "cpu-xl"] GPU_FLAVORS = [ + "sprx8", + "zero-a10g", "t4-small", "t4-medium", - "a10g-small", - "a10g-large", - "a10g-largex2", - "a10g-largex4", - "a100-large", - "a100x4", - "a100x8", "l4x1", "l4x4", "l40sx1", "l40sx4", "l40sx8", + "a10g-small", + "a10g-large", + "a10g-largex2", + "a10g-largex4", + "a100-large", + "h100", + "h100x8", ] # Detailed specs for display (vCPU/RAM/GPU VRAM) -CPU_FLAVORS_DESC = "cpu-basic(2vCPU/16GB), cpu-upgrade(8vCPU/32GB)" +CPU_FLAVORS_DESC = ( + "cpu-basic(2vCPU/16GB), cpu-upgrade(8vCPU/32GB), cpu-performance, cpu-xl" +) GPU_FLAVORS_DESC = ( "t4-small(4vCPU/15GB/GPU 16GB), t4-medium(8vCPU/30GB/GPU 16GB), " - "a10g-small(4vCPU/15GB/GPU 24GB), a10g-large(12vCPU/46GB/GPU 24GB), " - "a10g-largex2(24vCPU/92GB/GPU 48GB), a10g-largex4(48vCPU/184GB/GPU 96GB), " - "a100-large(12vCPU/142GB/GPU 80GB), a100x4(48vCPU/568GB/GPU 320GB), a100x8(96vCPU/1136GB/GPU 640GB), " "l4x1(8vCPU/30GB/GPU 24GB), l4x4(48vCPU/186GB/GPU 96GB), " - "l40sx1(8vCPU/62GB/GPU 48GB), l40sx4(48vCPU/382GB/GPU 192GB), l40sx8(192vCPU/1534GB/GPU 384GB)" + "l40sx1(8vCPU/62GB/GPU 48GB), l40sx4(48vCPU/382GB/GPU 192GB), l40sx8(192vCPU/1534GB/GPU 384GB), " + "a10g-small(4vCPU/14GB/GPU 24GB), a10g-large(12vCPU/46GB/GPU 24GB), " + "a10g-largex2(24vCPU/92GB/GPU 48GB), a10g-largex4(48vCPU/184GB/GPU 96GB), " + "a100-large(12vCPU/142GB/GPU 80GB), h100(23vCPU/240GB/GPU 80GB), h100x8(184vCPU/1920GB/GPU 640GB), " + "zero-a10g(dynamic alloc)" ) SPECIALIZED_FLAVORS = ["inf2x6"] ALL_FLAVORS = CPU_FLAVORS + GPU_FLAVORS + SPECIALIZED_FLAVORS @@ -123,33 +122,11 @@ def _filter_uv_install_output(logs: list[str]) -> list[str]: return logs -_ANSI_RE = re.compile(r"\x1b\[[0-9;]*[a-zA-Z]|\x1b\].*?\x07") - - -def _strip_ansi(text: str) -> str: - return _ANSI_RE.sub("", text) - - -_DEFAULT_ENV = { - "HF_HUB_DISABLE_PROGRESS_BARS": "1", - "TQDM_DISABLE": "1", - "TRANSFORMERS_VERBOSITY": "warning", - "HF_HUB_ENABLE_HF_TRANSFER": "1", - "UV_NO_PROGRESS": "1", -} - - -def _add_default_env(params: Dict[str, Any] | None) -> Dict[str, Any]: - """Inject default env vars for clean, agent-friendly output.""" - result = dict(_DEFAULT_ENV) - result.update(params or {}) # user-provided values override defaults - return result - - def _add_environment_variables( params: Dict[str, Any] | None, user_token: str | None = None ) -> Dict[str, Any]: - token = user_token or "" + # Prefer the authenticated user's OAuth token, fall back to global env var + token = user_token or os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN") or "" # Start with user-provided env vars, then force-set token last result = dict(params or {}) @@ -239,26 +216,6 @@ def _resolve_uv_command( return _build_uv_command(script, with_deps, python, script_args) -def _wrap_command_with_artifact_bootstrap( - command: list[str], session: Any = None -) -> list[str]: - """Install sitecustomize hooks before the user command runs in HF Jobs.""" - sitecustomize = build_hub_artifact_sitecustomize(session) - if not sitecustomize: - return command - - encoded = base64.b64encode(sitecustomize.encode("utf-8")).decode("ascii") - original_command = shlex.join(command) - shell = ( - 'set -e; _ml_intern_artifacts_dir="$(mktemp -d)"; ' - f"printf %s {shlex.quote(encoded)} | base64 -d " - '> "$_ml_intern_artifacts_dir/sitecustomize.py"; ' - 'export PYTHONPATH="$_ml_intern_artifacts_dir${PYTHONPATH:+:$PYTHONPATH}"; ' - f"exec {original_command}" - ) - return ["/bin/sh", "-lc", shell] - - async def _async_call(func, *args, **kwargs): """Wrap synchronous HfApi calls for async context""" return await asyncio.to_thread(func, *args, **kwargs) @@ -324,18 +281,12 @@ class HfJobsTool: self, hf_token: Optional[str] = None, namespace: Optional[str] = None, - jobs_access: Any = None, log_callback: Optional[Callable[[str], Awaitable[None]]] = None, - session: Any = None, - tool_call_id: Optional[str] = None, ): self.hf_token = hf_token self.api = HfApi(token=hf_token) self.namespace = namespace - self.jobs_access = jobs_access self.log_callback = log_callback - self.session = session - self.tool_call_id = tool_call_id async def execute(self, params: Dict[str, Any]) -> ToolResult: """Execute the specified operation""" @@ -407,31 +358,6 @@ class HfJobsTool: "isError": True, } - async def _seed_trackio_dashboard(self, space_id: str) -> None: - """Idempotently install trackio dashboard files into *space_id* before - the job runs. Surfaces seed progress as tool_log events but never - raises β€” a seed failure should not block job submission, since trackio - often still works when the Space already has dashboard code from a - previous run. - """ - loop = asyncio.get_running_loop() - - def _log(msg: str) -> None: - if self.session is None: - return - loop.call_soon_threadsafe( - self.session.event_queue.put_nowait, - Event(event_type="tool_log", data={"tool": "hf_jobs", "log": msg}), - ) - - try: - await asyncio.to_thread( - ensure_trackio_dashboard, space_id, self.hf_token, _log - ) - except Exception as e: - logger.warning(f"trackio dashboard seed failed for {space_id}: {e}") - _log(f"trackio dashboard seed failed: {e}") - async def _wait_for_job_completion( self, job_id: str, namespace: Optional[str] = None ) -> tuple[str, list[str]]: @@ -456,9 +382,7 @@ class HfJobsTool: def log_producer(): try: # fetch_job_logs is a blocking sync generator - logs_gen = self.api.fetch_job_logs( - job_id=job_id, namespace=namespace - ) + logs_gen = self.api.fetch_job_logs(job_id=job_id, namespace=namespace) for line in logs_gen: # Push line to queue thread-safely loop.call_soon_threadsafe(queue.put_nowait, line) @@ -529,17 +453,11 @@ class HfJobsTool: await asyncio.sleep(retry_delay) continue - # Fetch final job status β€” retry briefly if still RUNNING - # (the API may lag a few seconds behind the log stream ending) - final_status = "UNKNOWN" - for _ in range(6): - job_info = await _async_call( - self.api.inspect_job, job_id=job_id, namespace=namespace - ) - final_status = job_info.status.stage - if final_status in terminal_states: - break - await asyncio.sleep(2.5) + # Fetch final job status + job_info = await _async_call( + self.api.inspect_job, job_id=job_id, namespace=namespace + ) + final_status = job_info.status.stage return final_status, all_logs @@ -582,122 +500,17 @@ class HfJobsTool: image = args.get("image", "python:3.12") job_type = "Docker" - command = _wrap_command_with_artifact_bootstrap(command, self.session) - # Run the job - flavor = args.get("hardware_flavor", "cpu-basic") - timeout_str = args.get("timeout", "30m") - - # Trackio: agent-declared space + project become env vars on the job - # so trackio.init() picks them up automatically. We also surface them - # in tool_state_change so the frontend can embed the dashboard. - env_dict = _add_default_env(args.get("env")) - trackio_space_id = args.get("trackio_space_id") - trackio_project = args.get("trackio_project") - if trackio_space_id: - env_dict["TRACKIO_SPACE_ID"] = trackio_space_id - await self._seed_trackio_dashboard(trackio_space_id) - if trackio_project: - env_dict["TRACKIO_PROJECT"] = trackio_project - - try: - job = await _async_call( - self.api.run_job, - image=image, - command=command, - env=env_dict, - secrets=_add_environment_variables( - args.get("secrets"), self.hf_token - ), - flavor=flavor, - timeout=timeout_str, - namespace=self.namespace, - ) - except HfHubHTTPError as e: - if is_billing_error(str(e)): - if self.session and self.tool_call_id: - await self.session.send_event( - Event( - event_type="tool_state_change", - data={ - "tool_call_id": self.tool_call_id, - "tool": "hf_jobs", - "state": "billing_required", - "namespace": self.namespace, - }, - ) - ) - return { - "formatted": ( - f"Hugging Face Jobs rejected this run because the " - f"namespace `{self.namespace}` has no available credits. " - "HF Jobs are billed with namespace credits, which are " - "separate from HF Pro membership. Tell the user to add " - "credits at https://huggingface.co/settings/billing β€” " - "once topped up, re-run this same job. (Switching " - "namespaces is fine if another wallet has credits.)" - ), - "totalResults": 0, - "resultsShared": 0, - "isError": True, - } - raise - - # Track job ID for cancellation on interrupt - if self.session: - self.session._running_job_ids.add(job.id) - - # Send job URL immediately after job creation (before waiting for completion) - if self.session and self.tool_call_id: - state_data: Dict[str, Any] = { - "tool_call_id": self.tool_call_id, - "tool": "hf_jobs", - "state": "running", - "jobUrl": job.url, - } - if trackio_space_id: - state_data["trackioSpaceId"] = trackio_space_id - if trackio_project: - state_data["trackioProject"] = trackio_project - await self.session.send_event( - Event(event_type="tool_state_change", data=state_data) - ) - - # Telemetry: job submission + completion (infra consumption signal). - submit_ts = None - if self.session: - from agent.core import telemetry - - submit_ts = await telemetry.record_hf_job_submit( - self.session, - job, - { - **args, - "hardware_flavor": flavor, - "timeout": timeout_str, - "namespace": self.namespace, - }, - image=image, - job_type=job_type, - ) - # Top-up signal: this submit succeeded after a prior billing - # block in the same session, and we haven't fired the event - # yet β€” the user came back from the HF billing flow. - events = self.session.logged_events - already_fired = any( - e.get("event_type") == "credits_topped_up" for e in events - ) - if not already_fired: - blocked = any( - e.get("event_type") == "tool_state_change" - and (e.get("data") or {}).get("state") == "billing_required" - for e in events - ) - if blocked: - await telemetry.record_credits_topped_up( - self.session, - namespace=self.namespace, - ) + job = await _async_call( + self.api.run_job, + image=image, + command=command, + env=args.get("env"), + secrets=_add_environment_variables(args.get("secrets"), self.hf_token), + flavor=args.get("hardware_flavor", "cpu-basic"), + timeout=args.get("timeout", "30m"), + namespace=self.namespace, + ) # Wait for completion and stream logs logger.info(f"{job_type} job started: {job.url}") @@ -708,44 +521,11 @@ class HfJobsTool: namespace=self.namespace, ) - if self.session and submit_ts is not None: - from agent.core import telemetry - - await telemetry.record_hf_job_complete( - self.session, - job, - flavor=flavor, - final_status=final_status, - submit_ts=submit_ts, - ) - - # Untrack job ID (completed or failed, no longer needs cancellation) - if self.session: - self.session._running_job_ids.discard(job.id) - - # Notify frontend of final status - if self.session and self.tool_call_id: - final_data: Dict[str, Any] = { - "tool_call_id": self.tool_call_id, - "tool": "hf_jobs", - "state": final_status.lower(), - "jobUrl": job.url, - } - if trackio_space_id: - final_data["trackioSpaceId"] = trackio_space_id - if trackio_project: - final_data["trackioProject"] = trackio_project - await self.session.send_event( - Event(event_type="tool_state_change", data=final_data) - ) - # Filter out UV package installation output filtered_logs = _filter_uv_install_output(all_logs) # Format all logs for the agent - log_text = ( - _strip_ansi("\n".join(filtered_logs)) if filtered_logs else "(no logs)" - ) + log_text = "\n".join(filtered_logs) if filtered_logs else "(no logs)" response = f"""{job_type} job completed! @@ -822,7 +602,7 @@ class HfJobsTool: "resultsShared": 0, } - log_text = _strip_ansi("\n".join(logs)) + log_text = "\n".join(logs) return { "formatted": f"**Logs for {job_id}:**\n\n```\n{log_text}\n```", "totalResults": 1, @@ -937,15 +717,13 @@ To verify, call this tool with `{{"operation": "inspect", "job_id": "{job_id}"}} image = args.get("image", "python:3.12") job_type = "Docker" - command = _wrap_command_with_artifact_bootstrap(command, self.session) - # Create scheduled job scheduled_job = await _async_call( self.api.create_scheduled_job, image=image, command=command, schedule=schedule, - env=_add_default_env(args.get("env")), + env=args.get("env"), secrets=_add_environment_variables(args.get("secrets"), self.hf_token), flavor=args.get("hardware_flavor", "cpu-basic"), timeout=args.get("timeout", "30m"), @@ -1105,34 +883,56 @@ To inspect, call this tool with `{{"operation": "scheduled inspect", "scheduled_ HF_JOBS_TOOL_SPEC = { "name": "hf_jobs", "description": ( - "Execute Python scripts or Docker containers on HF cloud infrastructure.\n\n" - "Two modes (mutually exclusive): Python mode (script + dependencies) or Docker mode (command + image). " - "Provide exactly ONE of 'script' or 'command'.\n\n" - "BEFORE submitting training/fine-tuning jobs:\n" - "- You MUST have called github_find_examples + github_read_file to find a working reference implementation. " - "Scripts based on your internal knowledge WILL use outdated APIs and fail.\n" - "- You MUST have validated dataset format via hf_inspect_dataset or hub_repo_details.\n" - "- Training config MUST include push_to_hub=True and hub_model_id. " - "Job storage is EPHEMERAL β€” all files are deleted when the job ends. Without push_to_hub, trained models are lost permanently.\n" - "- Include trackio monitoring and provide the dashboard URL to the user. " - "When the script uses report_to='trackio', also pass `trackio_space_id` " - "(e.g. '/mlintern-<8char>') and `trackio_project` as tool args β€” " - "they are injected as TRACKIO_SPACE_ID/TRACKIO_PROJECT env vars and let the UI embed the live dashboard.\n\n" - "BATCH/ABLATION JOBS: Submit ONE job first. Check logs to confirm it starts training successfully. " - "Only then submit the remaining jobs. Never submit all at once β€” if there's a bug, all jobs fail.\n\n" - "Operations: run, ps, logs, inspect, cancel, scheduled run/ps/inspect/delete/suspend/resume.\n\n" - f"Hardware: CPU: {CPU_FLAVORS_DESC}. GPU: {GPU_FLAVORS_DESC}.\n" - "Common picks: t4-small ($0.60/hr, 1-3B), a10g-large ($2/hr, 7-13B), a100-large ($4/hr, 30B+), h100 ($6/hr, 70B+). " - "Note: a10g-small and a10g-large have the SAME 24GB GPU β€” the difference is CPU/RAM only.\n\n" - "OOM RECOVERY: When a training job fails with CUDA OOM:\n" - "1. Reduce per_device_train_batch_size and increase gradient_accumulation_steps proportionally (keep effective batch size identical)\n" - "2. Enable gradient_checkpointing=True\n" - "3. Upgrade to larger GPU (a10gβ†’a100β†’h100)\n" - "Do NOT switch training methods (e.g. full SFT to LoRA) or reduce max_length β€” those change what the user gets and require explicit approval.\n\n" - "Examples:\n" - "Training: {'operation': 'run', 'script': '/app/train.py', 'dependencies': ['transformers', 'trl', 'torch', 'datasets', 'trackio'], 'hardware_flavor': 'a100-large', 'timeout': '8h'}\n" - "Monitor: {'operation': 'ps'}, {'operation': 'logs', 'job_id': 'xxx'}, {'operation': 'cancel', 'job_id': 'xxx'}" - "Docker: {'operation': 'run', 'command': ['duckdb', '-c', 'select 1 + 2'], 'image': 'duckdb/duckdb', 'hardware_flavor': 'cpu-basic', 'timeout': '1h'}\n" + "Execute Python scripts or Docker containers on HF cloud infrastructure (CPUs/GPUs) in one of two modes. " + "\n\n" + "**Two Modes (mutually exclusive):**\n" + "1. Python mode: using 'script' arg (REQUIRED) + 'dependencies'\n" + "2. Docker mode: using 'command' arg (REQUIRED) + 'image'\n\n" + "🚨 **REQUIRED:** You MUST provide exactly ONE of: 'script' (Python code as string) OR 'command' (Docker command as array). " + "They are mutually exclusive - provide one or the other, never both, never neither. " + "Do NOT call with just {'operation': 'run'} - always include your code. Example: {'operation': 'run', 'script': 'import torch; print(torch.cuda.is_available())', 'dependencies': ['torch']} or {'operation': 'run', 'command': ['duckdb', '-c', 'select 1 + 2']', 'image': 'duckdb/duckdb'}\n\n" + "⚠️ CRITICAL for reliability: (1) Jobs run ASYNC - provide monitoring URL immediately, don't poll; " + "(2) Set timeout >30min (default too short - training needs 2-8h); " + "(3) HF_TOKEN auto-loaded to secrets for Hub ops (push_to_hub, private repos); " + "(4) Job storage EPHEMERAL - MUST push_to_hub() or ALL work is LOST. " + "**Use when:** User wants cloud compute, training models, data processing, batch inference, GPU workloads, scheduled tasks. " + "ALWAYS use this tool (βœ“), never bash 'hf jobs' commands (βœ—). Pass script content inline (βœ“), don't save to files unless requested (βœ—). " + "\n\n" + "**Operations:** run, ps, logs, inspect, cancel, scheduled run, scheduled ps, scheduled inspect, scheduled delete, scheduled suspend, scheduled resume. " + "**Available Hardware (vCPU/RAM/GPU):**\n" + f"β€’ CPU: {CPU_FLAVORS_DESC}\n" + f"β€’ GPU: {GPU_FLAVORS_DESC}\n" + " β—¦ Common: t4-small ($0.60/hr, demos/1-3B models), a10g-small ($1/hr), a10g-large ($2/hr, production 7-13B), a100-large ($4/hr, 30B+), h100 ($6/hr, 70B+)\n\n" + "**After Submission Ground Rules:**\n" + "βœ“ Return immediately with job ID and monitoring URL\n" + "βœ“ Provide expected completion time and cost estimate\n" + "βœ“ For training: Include Trackio dashboard URL\n" + "βœ“ Note user can check status later\n" + "βœ— DON'T poll logs automatically\n" + "βœ— DON'T wait for completion\n" + "βœ— DON'T check status unless user asks\n\n" + "**For Training Tasks:**\n" + "β€’ ALWAYS research TRL docs first: explore_hf_docs('trl') β†’ fetch_hf_docs()\n" + "β€’ ALWAYS validate dataset format with hub_repo_details (SFT needs messages/text, DPO needs chosen/rejected)\n" + "β€’ ALWAYS include Trackio monitoring in script (explore_hf_docs('trackio'))\n" + "β€’ ALWAYS enable push_to_hub=True in training config\n" + "β€’ Set timeout 2-8h for training (NOT default 30m)\n" + "β€’ Confirm model/dataset choices with user before submitting\n\n" + "**Examples:**\n\n" + "**Training - Fine-tune LLM:**\n" + "{'operation': 'run', 'script': '# Training script with TRL\\nfrom trl import SFTConfig, SFTTrainer\\nfrom transformers import AutoModelForCausalLM\\nmodel = AutoModelForCausalLM.from_pretrained(\"Qwen/Qwen3-4B\")\\n# ... researched implementation from docs ...\\ntrainer.train()\\ntrainer.push_to_hub(\"user-name/my-model\")', 'dependencies': ['transformers', 'trl', 'torch', 'datasets', 'trackio'], 'hardware_flavor': 'a10g-large', 'timeout': '4h'}\n\n" + "**Data Processing:**\n" + "{'operation': 'run', 'script': 'from datasets import load_dataset\\nds = load_dataset(\"data\")\\n# process...\\nds.push_to_hub(\"user/processed\")', 'dependencies': ['datasets', 'pandas'], 'hardware_flavor': 'cpu-upgrade', 'timeout': '2h'}\n\n" + "**Scheduled Daily Job:**\n" + "{'operation': 'scheduled run', 'schedule': '@daily', 'script': 'from datasets import Dataset\\nimport pandas as pd\\n# scrape/generate data\\ndf = pd.DataFrame(data)\\nds = Dataset.from_pandas(df)\\nds.push_to_hub(\"user-name/daily-dataset\")', 'dependencies': ['datasets', 'pandas'], 'hardware_flavor': 'cpu-basic'}\n\n" + "**Docker Mode:**\n" + "{'operation': 'run', 'image': 'pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime', 'command': ['python', 'train.py', '--epochs', '10'], 'hardware_flavor': 'a100-large'}\n\n" + "**Monitor Operations:**\n" + "{'operation': 'ps'} - List all jobs\n" + "{'operation': 'logs', 'job_id': 'xxx'} - Stream logs (only when user requests)\n" + "{'operation': 'inspect', 'job_id': 'xxx'} - Get job details\n" + "{'operation': 'cancel', 'job_id': 'xxx'} - Stop job\n\n" + "⚠️ CRITICAL: Files created during execution are DELETED when job finishes. MUST push_to_hub() all outputs (models, datasets, artifacts) in script. For logs/scripts, use hf_private_repos after completion." ), "parameters": { "type": "object", @@ -1152,93 +952,58 @@ HF_JOBS_TOOL_SPEC = { "scheduled suspend", "scheduled resume", ], - "description": "Operation to execute.", + "description": ( + "Operation to execute. Valid values: [run, ps, logs, inspect, cancel, " + "scheduled run, scheduled ps, scheduled inspect, scheduled delete, " + "scheduled suspend, scheduled resume]" + ), }, + # Python/UV specific parameters "script": { "type": "string", - "description": ( - "Python code or sandbox file path (e.g. '/app/train.py') or URL. " - "Triggers Python mode. For ML training: base this on a working example found via github_find_examples, not on internal knowledge. " - "Mutually exclusive with 'command'." - ), + "description": "Python code to execute. Triggers Python mode (auto pip install). Use with 'run'/'scheduled run'. Mutually exclusive with 'command'.", }, "dependencies": { "type": "array", "items": {"type": "string"}, - "description": ( - "Pip packages to install. Include ALL required packages. " - "Common training set: ['transformers', 'trl', 'torch', 'datasets', 'trackio', 'accelerate']. " - "Only used with 'script'." - ), + "description": "Pip packages to install. Example: ['trl', 'torch', 'datasets', 'transformers']. Only used with 'script'.", }, + # Docker specific parameters "image": { "type": "string", - "description": "Docker image. Optional β€” auto-selected if not provided. Use with 'command'.", + "description": "Docker image. Example: 'pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime'. Use with 'run'/'scheduled run'. Optional (auto-selected if not provided).", }, "command": { "type": "array", "items": {"type": "string"}, - "description": "Command to execute as list. Triggers Docker mode. Mutually exclusive with 'script'.", + "description": "Command to execute as list. Example: ['python', 'train.py', '--epochs', '10']. Triggers Docker mode. Use with 'run'/'scheduled run'. Mutually exclusive with 'script'.", }, + # Hardware and environment "hardware_flavor": { "type": "string", - "description": ( - "Hardware type. Sizing guide: 1-3B params β†’ t4-small/a10g-small, " - "7-13B β†’ a10g-large, 30B+ β†’ a100-large, 70B+ β†’ h100/h100x8. " - f"All options: CPU: {CPU_FLAVORS}. GPU: {GPU_FLAVORS}." - ), + "description": f"Hardware type. Available CPU flavors: {CPU_FLAVORS}. Available GPU flavors: {GPU_FLAVORS}. Use with 'run'/'scheduled run'.", }, "timeout": { "type": "string", - "description": ( - "Maximum job runtime. MUST be >2h for any training job β€” default 30m kills training mid-run. " - "Guidelines: 1-3B models: 3-4h, 7-13B: 6-8h, 30B+: 12-24h. " - "Use 30m-1h only for quick data processing or inference tasks. Default: '30m'." - ), + "description": "Max runtime. Examples: '30m', '2h', '4h'. Default: '30m'. Important for long training jobs. Use with 'run'/'scheduled run'.", }, "env": { "type": "object", - "description": "Environment variables {'KEY': 'VALUE'}. HF_TOKEN is auto-included.", - }, - "trackio_space_id": { - "type": "string", - "description": ( - "Optional. The HF Space hosting the trackio dashboard for this run " - "(e.g. '/mlintern-<8char>', under YOUR HF namespace). " - "Injected as TRACKIO_SPACE_ID env var and used by the UI to embed " - "the live dashboard. Set this whenever the script uses " - "report_to='trackio'. The Space is auto-created and seeded with the " - "trackio dashboard before the job starts β€” DO NOT pre-create it via " - "hf_repo_git, that produces an empty Space that breaks the embed." - ), - }, - "trackio_project": { - "type": "string", - "description": ( - "Optional. The trackio project name to log this run under. " - "Injected as TRACKIO_PROJECT env var and used by the UI to filter " - "the embedded dashboard to this project." - ), - }, - "namespace": { - "type": "string", - "description": ( - "Optional namespace to run the job under. Must be the caller's own " - "account or an org they belong to. If omitted, defaults to the " - "caller's personal account. Credits are billed against this namespace." - ), + "description": "Environment variables. Format: {'KEY': 'VALUE'}. HF_TOKEN is automatically included from your auth. Use with 'run'/'scheduled run'.", }, + # Job management parameters "job_id": { "type": "string", - "description": "Job ID. Required for: logs, inspect, cancel.", + "description": "Job ID to operate on. Required for: 'logs', 'inspect', 'cancel'.", }, + # Scheduled job parameters "scheduled_job_id": { "type": "string", - "description": "Scheduled job ID. Required for: scheduled inspect/delete/suspend/resume.", + "description": "Scheduled job ID. Required for: 'scheduled inspect', 'scheduled delete', 'scheduled suspend', 'scheduled resume'.", }, "schedule": { "type": "string", - "description": "Cron schedule or preset (@hourly, @daily, @weekly, @monthly). Required for: scheduled run.", + "description": "Schedule for recurring job. Presets: '@hourly', '@daily', '@weekly', '@monthly'. Cron: '0 9 * * 1' (Mon 9am). Required for: 'scheduled run'.", }, }, "required": ["operation"], @@ -1247,7 +1012,7 @@ HF_JOBS_TOOL_SPEC = { async def hf_jobs_handler( - arguments: Dict[str, Any], session: Any = None, tool_call_id: str | None = None + arguments: Dict[str, Any], session: Any = None ) -> tuple[str, bool]: """Handler for agent tool router""" try: @@ -1258,34 +1023,18 @@ async def hf_jobs_handler( Event(event_type="tool_log", data={"tool": "hf_jobs", "log": log}) ) - # If script is a sandbox file path, read it from the sandbox - script = arguments.get("script", "") - sandbox = getattr(session, "sandbox", None) if session else None - if sandbox and script: - from agent.tools.sandbox_tool import resolve_sandbox_script - - content, error = await resolve_sandbox_script(sandbox, script) - if error: - return error, False - if content: - arguments = {**arguments, "script": content} - - hf_token = session.hf_token if session else None - try: - namespace, jobs_access = await resolve_jobs_namespace( - hf_token or "", - arguments.get("namespace"), - ) - except JobsAccessError as e: - return str(e), False + # Prefer the authenticated user's OAuth token, fall back to global env + hf_token = ( + (getattr(session, "hf_token", None) if session else None) + or os.environ.get("HF_TOKEN") + or os.environ.get("HUGGINGFACE_HUB_TOKEN") + ) + namespace = os.environ.get("HF_NAMESPACE") or (HfApi(token=hf_token).whoami().get("name") if hf_token else None) tool = HfJobsTool( namespace=namespace, hf_token=hf_token, - jobs_access=jobs_access, log_callback=log_callback if session else None, - session=session, - tool_call_id=tool_call_id, ) result = await tool.execute(arguments) return result["formatted"], not result.get("isError", False) diff --git a/agent/tools/local_tools.py b/agent/tools/local_tools.py deleted file mode 100644 index 50cd5bd65b517f8855ceeb87ffade52a04e25a15..0000000000000000000000000000000000000000 --- a/agent/tools/local_tools.py +++ /dev/null @@ -1,441 +0,0 @@ -""" -Local tool implementations β€” bash/read/write/edit running on the user's machine. - -Drop-in replacement for sandbox tools when running in CLI (local) mode. -Same tool specs (names, parameters) but handlers execute locally via -subprocess/pathlib instead of going through a remote sandbox. -""" - -from __future__ import annotations - -import os -import re -import subprocess -import tempfile -from pathlib import Path -from typing import Any - -from agent.core.hub_artifacts import wrap_shell_command_with_hub_artifact_bootstrap - - -MAX_OUTPUT_CHARS = 25_000 -MAX_LINE_LENGTH = 4000 -DEFAULT_READ_LINES = 2000 -DEFAULT_TIMEOUT = 120 -MAX_TIMEOUT = 36000 # 10 hours β€” needed for long training runs (e.g. PostTrainBench) - -_ANSI_RE = re.compile(r"\x1b\[[0-9;]*[a-zA-Z]|\x1b\].*?\x07") - -# Track files that have been read this session (enforces read-before-write/edit) -_files_read: set[str] = set() - - -def _resolve_path(path: str) -> str: - try: - return str(Path(path).resolve()) - except Exception: - return path - - -def _atomic_write(path: Path, content: str) -> None: - """Write file atomically via temp file + os.replace(). - - Ensures the file is never left in a partial/corrupted state β€” it's either - the old content or the new content, never half-written. - """ - path.parent.mkdir(parents=True, exist_ok=True) - fd = None - tmp_path = None - try: - fd, tmp_path = tempfile.mkstemp(dir=path.parent, suffix=".tmp") - os.write(fd, content.encode("utf-8")) - os.fsync(fd) - os.close(fd) - fd = None - os.replace(tmp_path, str(path)) - tmp_path = None # successfully replaced, nothing to clean up - finally: - if fd is not None: - os.close(fd) - if tmp_path is not None: - try: - os.unlink(tmp_path) - except OSError: - pass - - -def _strip_ansi(text: str) -> str: - return _ANSI_RE.sub("", text) - - -def _truncate_output( - output: str, max_chars: int = MAX_OUTPUT_CHARS, head_ratio: float = 0.25 -) -> str: - """Tail-biased truncation with temp file spillover for full output access.""" - if len(output) <= max_chars: - return output - # Write full output to temp file so LLM can read specific sections - spill_path = None - try: - with tempfile.NamedTemporaryFile( - mode="w", suffix=".txt", prefix="bash_output_", delete=False - ) as f: - f.write(output) - spill_path = f.name - except Exception: - pass - head_budget = int(max_chars * head_ratio) - tail_budget = max_chars - head_budget - head = output[:head_budget] - tail = output[-tail_budget:] - total = len(output) - omitted = total - max_chars - meta = f"\n\n... ({omitted:,} of {total:,} chars omitted, showing first {head_budget:,} + last {tail_budget:,}) ...\n" - if spill_path: - meta += f"Full output saved to {spill_path} β€” use the read tool with offset/limit to inspect specific sections.\n" - meta += "IMPORTANT: The command has finished. Analyze the output above and continue with your next action.\n" - return head + meta + tail - - -# ── Handlers ──────────────────────────────────────────────────────────── - - -async def _bash_handler( - args: dict[str, Any], session: Any = None, **_kw -) -> tuple[str, bool]: - command = args.get("command", "") - if not command: - return "No command provided.", False - command = wrap_shell_command_with_hub_artifact_bootstrap(command, session) - work_dir = args.get("work_dir", ".") - timeout = min(args.get("timeout") or DEFAULT_TIMEOUT, MAX_TIMEOUT) - try: - result = subprocess.run( - command, - shell=True, - capture_output=True, - text=True, - cwd=work_dir, - timeout=timeout, - ) - output = _strip_ansi(result.stdout + result.stderr) - output = _truncate_output(output) - if not output.strip(): - output = "(no output)" - return output, result.returncode == 0 - except subprocess.TimeoutExpired: - return ( - f"Command timed out after {timeout}s and was killed.\n\n" - f"For long-running commands, run in the background and poll:\n" - f" nohup > /tmp/output.log 2>&1 & echo $!\n" - f"Then check status with:\n" - f" kill -0 2>/dev/null && echo 'running' || echo 'done'\n" - f" tail -n 50 /tmp/output.log" - ), False - except Exception as e: - return f"bash error: {e}", False - - -async def _read_handler(args: dict[str, Any], **_kw) -> tuple[str, bool]: - file_path = args.get("path", "") - if not file_path: - return "No path provided.", False - p = Path(file_path) - if not p.exists(): - return f"File not found: {file_path}", False - if p.is_dir(): - return "Cannot read a directory. Use bash with 'ls' instead.", False - try: - raw_content = p.read_text() - except Exception as e: - return f"read error: {e}", False - - _files_read.add(_resolve_path(file_path)) - - lines = raw_content.splitlines() - offset = max((args.get("offset") or 1), 1) - limit = args.get("limit") or DEFAULT_READ_LINES - - selected = lines[offset - 1 : offset - 1 + limit] - numbered = [] - for i, line in enumerate(selected, start=offset): - if len(line) > MAX_LINE_LENGTH: - line = line[:MAX_LINE_LENGTH] + "..." - numbered.append(f"{i:>6}\t{line}") - - return "\n".join(numbered), True - - -async def _write_handler(args: dict[str, Any], **_kw) -> tuple[str, bool]: - file_path = args.get("path", "") - content = args.get("content", "") - if not file_path: - return "No path provided.", False - p = Path(file_path) - if p.exists() and _resolve_path(file_path) not in _files_read: - return ( - f"You must read {file_path} before overwriting it. " - f"Use the read tool first to see current contents." - ), False - try: - _atomic_write(p, content) - _files_read.add(_resolve_path(file_path)) - msg = f"Wrote {len(content)} bytes to {file_path}" - # Syntax validation for Python files - if p.suffix == ".py": - from agent.tools.edit_utils import validate_python - - warnings = validate_python(content, file_path) - if warnings: - msg += "\n\nValidation warnings:\n" + "\n".join( - f" ⚠ {w}" for w in warnings - ) - return msg, True - except Exception as e: - return f"write error: {e}", False - - -async def _edit_handler(args: dict[str, Any], **_kw) -> tuple[str, bool]: - from agent.tools.edit_utils import apply_edit, validate_python - - file_path = args.get("path", "") - old_str = args.get("old_str", "") - new_str = args.get("new_str", "") - replace_all = args.get("replace_all", False) - mode = args.get("mode", "replace") - - if not file_path: - return "No path provided.", False - if old_str == new_str: - return "old_str and new_str must differ.", False - - p = Path(file_path) - if not p.exists(): - return f"File not found: {file_path}", False - if _resolve_path(file_path) not in _files_read: - return ( - f"You must read {file_path} before editing it. " - f"Use the read tool first to see current contents." - ), False - - try: - text = p.read_text() - except Exception as e: - return f"edit read error: {e}", False - - try: - new_text, replacements, fuzzy_note = apply_edit( - text, old_str, new_str, mode=mode, replace_all=replace_all - ) - except ValueError as e: - return str(e), False - - try: - _atomic_write(p, new_text) - except Exception as e: - return f"edit write error: {e}", False - - msg = f"Edited {file_path} ({replacements} replacement{'s' if replacements > 1 else ''})" - if fuzzy_note: - msg += f" {fuzzy_note}" - # Syntax validation for Python files - if p.suffix == ".py": - warnings = validate_python(new_text, file_path) - if warnings: - msg += "\n\nValidation warnings:\n" + "\n".join( - f" ⚠ {w}" for w in warnings - ) - return msg, True - - -# ── Local tool specs (override sandbox /app references) ──────────────── - -_LOCAL_TOOL_SPECS = { - "bash": { - "description": ( - "Run a shell command on the local machine and return stdout/stderr.\n" - "\n" - "IMPORTANT: Do NOT use bash for file operations β€” use the dedicated tools instead:\n" - "- To read files: use read (not cat/head/tail)\n" - "- To edit files: use edit (not sed/awk)\n" - "- To write files: use write (not echo/cat < > /tmp/output.log 2>&1 & echo $!\n" - "Then check status:\n" - " kill -0 2>/dev/null && echo 'running' || echo 'done'\n" - " tail -n 50 /tmp/output.log\n" - "\n" - "Timeout default 120s, max 36000s." - ), - "parameters": { - "type": "object", - "required": ["command"], - "additionalProperties": False, - "properties": { - "command": { - "type": "string", - "description": "The shell command to execute.", - }, - "description": { - "type": "string", - "description": "Short description (5-10 words, active voice).", - }, - "work_dir": { - "type": "string", - "description": "Working directory (default: current directory).", - }, - "timeout": { - "type": "integer", - "description": "Optional timeout in seconds (default: 120, max: 36000).", - }, - }, - }, - }, - "read": { - "description": ( - "Reads a file from the local filesystem. Returns contents with line numbers " - "(cat -n format).\n" - "\n" - "Usage:\n" - "- By default, reads up to 2000 lines from the beginning of the file.\n" - "- You can optionally specify offset and limit for large files, but prefer " - "reading the whole file first.\n" - "- Lines longer than 4000 chars are truncated.\n" - "- Cannot read directories β€” use bash with 'ls' instead.\n" - "- You should read multiple potentially useful files in parallel when possible.\n" - "- IMPORTANT: Always read a file before editing or overwriting it. The edit and " - "write tools will reject operations on files you haven't read." - ), - "parameters": { - "type": "object", - "required": ["path"], - "additionalProperties": False, - "properties": { - "path": { - "type": "string", - "description": "Absolute path to the file to read.", - }, - "offset": { - "type": "integer", - "description": "The line number to start reading from (1-based). Only provide if the file is too large to read at once.", - }, - "limit": { - "type": "integer", - "description": "The number of lines to read. Only provide if the file is too large to read at once.", - }, - }, - }, - }, - "write": { - "description": ( - "Writes a file to the local filesystem. Overwrites the existing file if one " - "exists at the path.\n" - "\n" - "- If this is an existing file, you MUST use the read tool first. This tool " - "will fail if you did not read the file first.\n" - "- ALWAYS prefer editing existing files with the edit tool over overwriting " - "with write.\n" - "- Creates parent directories as needed." - ), - "parameters": { - "type": "object", - "required": ["path", "content"], - "additionalProperties": False, - "properties": { - "path": { - "type": "string", - "description": "Absolute path to the file to write.", - }, - "content": { - "type": "string", - "description": "The complete file content to write.", - }, - }, - }, - }, - "edit": { - "description": ( - "Performs string replacements in files. Supports exact matching with " - "fuzzy fallback.\n" - "\n" - "Usage:\n" - "- You must read the file at least once before editing. This tool will " - "error if you attempt an edit without reading the file.\n" - "- The edit will FAIL if old_str is not unique in the file. Either provide " - "a larger string with more surrounding context to make it unique, or set " - "replace_all to true.\n" - "- old_str and new_str must differ.\n" - "- Preserve indentation exactly as it appears in the file.\n" - "- Do NOT include line number prefixes from read output in old_str or new_str.\n" - "- To delete code, set new_str to empty string.\n" - "- Use replace_all for renaming variables or strings across the file.\n" - "\n" - "Modes:\n" - "- replace (default): replace first occurrence of old_str with new_str.\n" - "- append_after: insert new_str immediately after old_str (old_str is kept).\n" - "- prepend_before: insert new_str immediately before old_str (old_str is kept)." - ), - "parameters": { - "type": "object", - "required": ["path", "old_str", "new_str"], - "additionalProperties": False, - "properties": { - "path": { - "type": "string", - "description": "Absolute path to the file to edit.", - }, - "old_str": { - "type": "string", - "description": "The text to find in the file. Must match exactly (fuzzy matching is used as fallback).", - }, - "new_str": { - "type": "string", - "description": "The replacement text. For append_after/prepend_before modes, the text to insert.", - }, - "replace_all": { - "type": "boolean", - "description": "Replace all occurrences of old_str (default: false).", - "default": False, - }, - "mode": { - "type": "string", - "enum": ["replace", "append_after", "prepend_before"], - "description": "Edit mode (default: replace).", - "default": "replace", - }, - }, - }, - }, -} - -_HANDLERS = { - "bash": _bash_handler, - "read": _read_handler, - "write": _write_handler, - "edit": _edit_handler, -} - - -def get_local_tools(): - """Return local ToolSpecs for bash/read/write/edit (no sandbox_create).""" - from agent.core.tools import ToolSpec - - tools = [] - for name, spec in _LOCAL_TOOL_SPECS.items(): - handler = _HANDLERS.get(name) - if handler is None: - continue - tools.append( - ToolSpec( - name=name, - description=spec["description"], - parameters=spec["parameters"], - handler=handler, - ) - ) - return tools diff --git a/agent/tools/notify_tool.py b/agent/tools/notify_tool.py deleted file mode 100644 index f926d5a58d5f3c4b877cb8792f812f6e4fa322a7..0000000000000000000000000000000000000000 --- a/agent/tools/notify_tool.py +++ /dev/null @@ -1,108 +0,0 @@ -from typing import Any - -from agent.messaging.models import NotificationRequest - -NOTIFY_TOOL_SPEC = { - "name": "notify", - "description": ( - "Send an out-of-band notification to configured messaging destinations. " - "Use this only when the user explicitly asked for proactive notifications " - "or when the task requires reporting progress outside the chat. " - "Destinations must be named server-side configs such as 'slack.ops'." - ), - "parameters": { - "type": "object", - "properties": { - "destinations": { - "type": "array", - "description": "Named messaging destinations to notify.", - "items": {"type": "string"}, - "minItems": 1, - }, - "message": { - "type": "string", - "description": "Main notification body.", - }, - "title": { - "type": "string", - "description": "Optional short title line.", - }, - "severity": { - "type": "string", - "enum": ["info", "success", "warning", "error"], - "description": "Notification severity label.", - }, - }, - "required": ["destinations", "message"], - }, -} - - -async def notify_handler( - arguments: dict[str, Any], session=None, **_kwargs -) -> tuple[str, bool]: - if session is None or session.notification_gateway is None: - return "Messaging is not configured for this session.", False - - raw_destinations = arguments.get("destinations", []) - if not isinstance(raw_destinations, list) or not raw_destinations: - return "destinations must be a non-empty array of destination names.", False - - destinations: list[str] = [] - seen: set[str] = set() - for raw_name in raw_destinations: - if not isinstance(raw_name, str): - return "Each destination must be a string.", False - name = raw_name.strip() - if not name: - return "Destination names must not be empty.", False - if name not in seen: - destinations.append(name) - seen.add(name) - - disallowed = [ - name - for name in destinations - if not session.config.messaging.can_agent_tool_send(name) - ] - if disallowed: - return ( - "These destinations are unavailable for the notify tool: " - + ", ".join(disallowed) - ), False - - message = arguments.get("message", "") - if not isinstance(message, str) or not message.strip(): - return "message must be a non-empty string.", False - - title = arguments.get("title") - severity = arguments.get("severity", "info") - if title is not None and not isinstance(title, str): - return "title must be a string when provided.", False - if severity not in {"info", "success", "warning", "error"}: - return "severity must be one of: info, success, warning, error.", False - - requests = [ - NotificationRequest( - destination=name, - title=title, - message=message, - severity=severity, - metadata={ - "session_id": session.session_id, - "model": session.config.model_name, - }, - ) - for name in destinations - ] - results = await session.notification_gateway.send_many(requests) - - lines = [] - all_ok = True - for result in results: - if result.ok: - lines.append(f"{result.destination}: sent") - else: - all_ok = False - lines.append(f"{result.destination}: failed ({result.error})") - return "\n".join(lines), all_ok diff --git a/agent/tools/papers_tool.py b/agent/tools/papers_tool.py deleted file mode 100644 index dea63d7d327999303e76c7e3e155d90107a2fd4f..0000000000000000000000000000000000000000 --- a/agent/tools/papers_tool.py +++ /dev/null @@ -1,1340 +0,0 @@ -""" -HF Papers Tool β€” Discover papers, read their contents, and find linked resources. - -Operations: trending, search, paper_details, read_paper, - find_datasets, find_models, find_collections, find_all_resources, - citation_graph, snippet_search, recommend -""" - -import asyncio -import os -import re -import time -from typing import Any - -import httpx -from bs4 import BeautifulSoup, Tag - -from agent.tools.types import ToolResult - -HF_API = "https://huggingface.co/api" -ARXIV_HTML = "https://arxiv.org/html" -AR5IV_HTML = "https://ar5iv.labs.arxiv.org/html" - -DEFAULT_LIMIT = 10 -MAX_LIMIT = 50 -MAX_SUMMARY_LEN = 300 -MAX_SECTION_PREVIEW_LEN = 280 -MAX_SECTION_TEXT_LEN = 8000 - -SORT_MAP = { - "downloads": "downloads", - "likes": "likes", - "trending": "trendingScore", -} - -# --------------------------------------------------------------------------- -# Semantic Scholar API -# --------------------------------------------------------------------------- - -S2_API = "https://api.semanticscholar.org" -S2_API_KEY = os.environ.get("S2_API_KEY") -S2_HEADERS: dict[str, str] = {"x-api-key": S2_API_KEY} if S2_API_KEY else {} -S2_TIMEOUT = 12 -_s2_last_request: float = 0.0 - -# Shared response cache (survives across sessions, keyed by (path, params_tuple)) -_s2_cache: dict[str, Any] = {} -_S2_CACHE_MAX = 500 - - -def _s2_paper_id(arxiv_id: str) -> str: - """Convert bare arxiv ID to S2 format.""" - return f"ARXIV:{arxiv_id}" - - -def _s2_cache_key(path: str, params: dict | None) -> str: - """Build a hashable cache key from path + sorted params.""" - p = tuple(sorted((params or {}).items())) - return f"{path}:{p}" - - -async def _s2_request( - client: httpx.AsyncClient, - method: str, - path: str, - **kwargs: Any, -) -> httpx.Response | None: - """S2 request with 2 retries on 429/5xx. Rate-limited only when using API key.""" - global _s2_last_request - url = f"{S2_API}{path}" - kwargs.setdefault("headers", {}).update(S2_HEADERS) - kwargs.setdefault("timeout", S2_TIMEOUT) - - for attempt in range(3): - # Rate limit only when authenticated (1 req/s for search, 10 req/s for others) - if S2_API_KEY: - min_interval = 1.0 if "search" in path else 0.1 - elapsed = time.monotonic() - _s2_last_request - if elapsed < min_interval: - await asyncio.sleep(min_interval - elapsed) - _s2_last_request = time.monotonic() - - try: - resp = await client.request(method, url, **kwargs) - if resp.status_code == 429: - if attempt < 2: - await asyncio.sleep(60) - continue - return None - if resp.status_code >= 500: - if attempt < 2: - await asyncio.sleep(3) - continue - return None - return resp - except (httpx.RequestError, httpx.HTTPStatusError): - if attempt < 2: - await asyncio.sleep(3) - continue - return None - return None - - -async def _s2_get_json( - client: httpx.AsyncClient, - path: str, - params: dict | None = None, -) -> dict | None: - """Cached S2 GET returning parsed JSON or None.""" - key = _s2_cache_key(path, params) - if key in _s2_cache: - return _s2_cache[key] - - resp = await _s2_request(client, "GET", path, params=params or {}) - if resp and resp.status_code == 200: - data = resp.json() - if len(_s2_cache) < _S2_CACHE_MAX: - _s2_cache[key] = data - return data - return None - - -async def _s2_get_paper( - client: httpx.AsyncClient, - arxiv_id: str, - fields: str, -) -> dict | None: - """Fetch a single paper from S2 by arxiv ID. Returns None on failure.""" - return await _s2_get_json( - client, - f"/graph/v1/paper/{_s2_paper_id(arxiv_id)}", - {"fields": fields}, - ) - - -# --------------------------------------------------------------------------- -# HTML paper parsing -# --------------------------------------------------------------------------- - - -def _parse_paper_html(html: str) -> dict[str, Any]: - """Parse arxiv HTML into structured sections. - - Returns: - { - "title": str, - "abstract": str, - "sections": [{"id": str, "title": str, "level": int, "text": str}], - } - """ - soup = BeautifulSoup(html, "html.parser") - - # Title - title_el = soup.find("h1", class_="ltx_title") - title = title_el.get_text(strip=True).removeprefix("Title:") if title_el else "" - - # Abstract - abstract_el = soup.find("div", class_="ltx_abstract") - abstract = "" - if abstract_el: - # Skip the "Abstract" heading itself - for child in abstract_el.children: - if isinstance(child, Tag) and child.name in ("h6", "h2", "h3", "p", "span"): - if child.get_text(strip=True).lower() == "abstract": - continue - if isinstance(child, Tag) and child.name == "p": - abstract += child.get_text(separator=" ", strip=True) + " " - abstract = abstract.strip() - - # Sections β€” collect h2/h3 headings and text between them - sections: list[dict[str, Any]] = [] - headings = soup.find_all(["h2", "h3"], class_=lambda c: c and "ltx_title" in c) - - for heading in headings: - level = 2 if heading.name == "h2" else 3 - heading_text = heading.get_text(separator=" ", strip=True) - - # Collect text from siblings until next heading of same or higher level - text_parts: list[str] = [] - sibling = heading.find_next_sibling() - while sibling: - if isinstance(sibling, Tag): - if sibling.name in ("h2", "h3") and "ltx_title" in ( - sibling.get("class") or [] - ): - break - # Also stop at h2 if we're collecting h3 content - if sibling.name == "h2" and level == 3: - break - text_parts.append(sibling.get_text(separator=" ", strip=True)) - sibling = sibling.find_next_sibling() - - # Also check parent section element for contained paragraphs - parent_section = heading.find_parent("section") - if parent_section and not text_parts: - for p in parent_section.find_all("p", recursive=False): - text_parts.append(p.get_text(separator=" ", strip=True)) - - section_text = "\n\n".join(t for t in text_parts if t) - - # Extract section number from heading text (e.g., "4 Experiments" β†’ "4") - num_match = re.match(r"^([A-Z]?\d+(?:\.\d+)*)\s", heading_text) - section_id = num_match.group(1) if num_match else "" - - sections.append( - { - "id": section_id, - "title": heading_text, - "level": level, - "text": section_text, - } - ) - - return {"title": title, "abstract": abstract, "sections": sections} - - -def _find_section(sections: list[dict], query: str) -> dict | None: - """Find a section by number or name (fuzzy).""" - query_lower = query.lower().strip() - - # Exact match on section number - for s in sections: - if s["id"] == query_lower or s["id"] == query: - return s - - # Exact match on title - for s in sections: - if query_lower == s["title"].lower(): - return s - - # Substring match on title - for s in sections: - if query_lower in s["title"].lower(): - return s - - # Number prefix match (e.g., "4" matches "4.1", "4.2", etc. β€” return parent) - for s in sections: - if s["id"].startswith(query_lower + ".") or s["id"] == query_lower: - return s - - return None - - -# --------------------------------------------------------------------------- -# Formatting helpers -# --------------------------------------------------------------------------- - - -def _clean_description(text: str) -> str: - """Strip HTML card artifacts and collapse whitespace from HF API descriptions.""" - text = re.sub(r"[\t]+", " ", text) - text = re.sub(r"\n{2,}", "\n", text) - return text.strip() - - -def _truncate(text: str, max_len: int) -> str: - if len(text) <= max_len: - return text - return text[:max_len] + "..." - - -def _format_paper_list( - papers: list, title: str, date: str | None = None, query: str | None = None -) -> str: - lines = [f"# {title}"] - if date: - lines[0] += f" ({date})" - if query: - lines.append(f"Filtered by: '{query}'") - lines.append(f"Showing {len(papers)} paper(s)\n") - - for i, item in enumerate(papers, 1): - paper = item.get("paper", item) - arxiv_id = paper.get("id", "") - paper_title = paper.get("title", "Unknown") - upvotes = paper.get("upvotes", 0) - summary = paper.get("ai_summary") or _truncate( - paper.get("summary", ""), MAX_SUMMARY_LEN - ) - keywords = paper.get("ai_keywords") or [] - github = paper.get("githubRepo") or "" - stars = paper.get("githubStars") or 0 - - lines.append(f"## {i}. {paper_title}") - lines.append(f"**arxiv_id:** {arxiv_id} | **upvotes:** {upvotes}") - lines.append(f"https://huggingface.co/papers/{arxiv_id}") - if keywords: - lines.append(f"**Keywords:** {', '.join(keywords[:5])}") - if github: - lines.append(f"**GitHub:** {github} ({stars} stars)") - if summary: - lines.append(f"**Summary:** {_truncate(summary, MAX_SUMMARY_LEN)}") - lines.append("") - - return "\n".join(lines) - - -def _format_paper_detail(paper: dict, s2_data: dict | None = None) -> str: - arxiv_id = paper.get("id", "") - title = paper.get("title", "Unknown") - upvotes = paper.get("upvotes", 0) - ai_summary = paper.get("ai_summary") or "" - summary = paper.get("summary", "") - keywords = paper.get("ai_keywords") or [] - github = paper.get("githubRepo") or "" - stars = paper.get("githubStars") or 0 - authors = paper.get("authors") or [] - - lines = [f"# {title}"] - meta_parts = [f"**arxiv_id:** {arxiv_id}", f"**upvotes:** {upvotes}"] - if s2_data: - cites = s2_data.get("citationCount", 0) - influential = s2_data.get("influentialCitationCount", 0) - meta_parts.append(f"**citations:** {cites} ({influential} influential)") - lines.append(" | ".join(meta_parts)) - lines.append(f"https://huggingface.co/papers/{arxiv_id}") - lines.append(f"https://arxiv.org/abs/{arxiv_id}") - - if authors: - names = [a.get("name", "") for a in authors[:10]] - author_str = ", ".join(n for n in names if n) - if len(authors) > 10: - author_str += f" (+{len(authors) - 10} more)" - lines.append(f"**Authors:** {author_str}") - - if keywords: - lines.append(f"**Keywords:** {', '.join(keywords)}") - if s2_data and s2_data.get("s2FieldsOfStudy"): - fields = [ - f["category"] for f in s2_data["s2FieldsOfStudy"] if f.get("category") - ] - if fields: - lines.append(f"**Fields:** {', '.join(fields)}") - if s2_data and s2_data.get("venue"): - lines.append(f"**Venue:** {s2_data['venue']}") - if github: - lines.append(f"**GitHub:** {github} ({stars} stars)") - - if s2_data and s2_data.get("tldr"): - tldr_text = s2_data["tldr"].get("text", "") - if tldr_text: - lines.append(f"\n## TL;DR\n{tldr_text}") - if ai_summary: - lines.append(f"\n## AI Summary\n{ai_summary}") - if summary: - lines.append(f"\n## Abstract\n{_truncate(summary, 500)}") - - lines.append( - "\n**Next:** Use read_paper to read specific sections, find_all_resources for linked datasets/models, " - "or citation_graph to trace references and citations." - ) - return "\n".join(lines) - - -def _format_read_paper_toc(parsed: dict[str, Any], arxiv_id: str) -> str: - """Format TOC view: abstract + section list with previews.""" - lines = [f"# {parsed['title']}"] - lines.append(f"https://arxiv.org/abs/{arxiv_id}\n") - - if parsed["abstract"]: - lines.append(f"## Abstract\n{parsed['abstract']}\n") - - lines.append("## Sections") - for s in parsed["sections"]: - prefix = " " if s["level"] == 3 else "" - preview = ( - _truncate(s["text"], MAX_SECTION_PREVIEW_LEN) if s["text"] else "(empty)" - ) - lines.append(f"{prefix}- **{s['title']}**: {preview}") - - lines.append( - '\nCall read_paper with section parameter (e.g. section="4" or section="Experiments") to read a specific section.' - ) - return "\n".join(lines) - - -def _format_read_paper_section(section: dict, arxiv_id: str) -> str: - """Format a single section's full text.""" - lines = [f"# {section['title']}"] - lines.append(f"https://arxiv.org/abs/{arxiv_id}\n") - - text = section["text"] - if len(text) > MAX_SECTION_TEXT_LEN: - text = ( - text[:MAX_SECTION_TEXT_LEN] - + f"\n\n... (truncated at {MAX_SECTION_TEXT_LEN} chars)" - ) - - lines.append(text if text else "(This section has no extractable text content.)") - return "\n".join(lines) - - -def _format_datasets(datasets: list, arxiv_id: str, sort: str) -> str: - lines = [f"# Datasets linked to paper {arxiv_id}"] - lines.append(f"https://huggingface.co/papers/{arxiv_id}") - lines.append(f"Showing {len(datasets)} dataset(s), sorted by {sort}\n") - - for i, ds in enumerate(datasets, 1): - ds_id = ds.get("id", "unknown") - downloads = ds.get("downloads", 0) - likes = ds.get("likes", 0) - desc = _truncate( - _clean_description(ds.get("description") or ""), MAX_SUMMARY_LEN - ) - tags = ds.get("tags") or [] - interesting = [t for t in tags if not t.startswith(("arxiv:", "region:"))][:5] - - lines.append(f"**{i}. [{ds_id}](https://huggingface.co/datasets/{ds_id})**") - lines.append(f" Downloads: {downloads:,} | Likes: {likes}") - if interesting: - lines.append(f" Tags: {', '.join(interesting)}") - if desc: - lines.append(f" {desc}") - lines.append("") - - if datasets: - top = datasets[0].get("id", "") - lines.append(f'**Inspect top dataset:** hf_inspect_dataset(dataset="{top}")') - return "\n".join(lines) - - -def _format_datasets_compact(datasets: list) -> str: - if not datasets: - return "## Datasets\nNone found" - lines = [f"## Datasets ({len(datasets)})"] - for ds in datasets: - lines.append( - f"- **{ds.get('id', '?')}** ({ds.get('downloads', 0):,} downloads)" - ) - return "\n".join(lines) - - -def _format_models(models: list, arxiv_id: str, sort: str) -> str: - lines = [f"# Models linked to paper {arxiv_id}"] - lines.append(f"https://huggingface.co/papers/{arxiv_id}") - lines.append(f"Showing {len(models)} model(s), sorted by {sort}\n") - - for i, m in enumerate(models, 1): - model_id = m.get("id", "unknown") - downloads = m.get("downloads", 0) - likes = m.get("likes", 0) - pipeline = m.get("pipeline_tag") or "" - library = m.get("library_name") or "" - - lines.append(f"**{i}. [{model_id}](https://huggingface.co/{model_id})**") - meta = f" Downloads: {downloads:,} | Likes: {likes}" - if pipeline: - meta += f" | Task: {pipeline}" - if library: - meta += f" | Library: {library}" - lines.append(meta) - lines.append("") - - return "\n".join(lines) - - -def _format_models_compact(models: list) -> str: - if not models: - return "## Models\nNone found" - lines = [f"## Models ({len(models)})"] - for m in models: - pipeline = m.get("pipeline_tag") or "" - suffix = f" ({pipeline})" if pipeline else "" - lines.append( - f"- **{m.get('id', '?')}** ({m.get('downloads', 0):,} downloads){suffix}" - ) - return "\n".join(lines) - - -def _format_collections(collections: list, arxiv_id: str) -> str: - lines = [f"# Collections containing paper {arxiv_id}"] - lines.append(f"Showing {len(collections)} collection(s)\n") - - for i, c in enumerate(collections, 1): - slug = c.get("slug", "") - title = c.get("title", "Untitled") - upvotes = c.get("upvotes", 0) - owner = c.get("owner", {}).get("name", "") - desc = _truncate(c.get("description") or "", MAX_SUMMARY_LEN) - num_items = len(c.get("items", [])) - - lines.append(f"**{i}. {title}**") - lines.append(f" By: {owner} | Upvotes: {upvotes} | Items: {num_items}") - lines.append(f" https://huggingface.co/collections/{slug}") - if desc: - lines.append(f" {desc}") - lines.append("") - - return "\n".join(lines) - - -def _format_collections_compact(collections: list) -> str: - if not collections: - return "## Collections\nNone found" - lines = [f"## Collections ({len(collections)})"] - for c in collections: - title = c.get("title", "Untitled") - owner = c.get("owner", {}).get("name", "") - upvotes = c.get("upvotes", 0) - lines.append(f"- **{title}** by {owner} ({upvotes} upvotes)") - return "\n".join(lines) - - -# --------------------------------------------------------------------------- -# Operation handlers -# --------------------------------------------------------------------------- - - -def _error(message: str) -> ToolResult: - return { - "formatted": message, - "totalResults": 0, - "resultsShared": 0, - "isError": True, - } - - -def _validate_arxiv_id(args: dict) -> str | None: - """Return arxiv_id or None if missing.""" - return args.get("arxiv_id") - - -async def _op_trending(args: dict[str, Any], limit: int) -> ToolResult: - date = args.get("date") - query = args.get("query") - - params: dict[str, Any] = {"limit": limit if not query else max(limit * 3, 30)} - if date: - params["date"] = date - - async with httpx.AsyncClient(timeout=15) as client: - resp = await client.get(f"{HF_API}/daily_papers", params=params) - resp.raise_for_status() - papers = resp.json() - - if query: - q = query.lower() - papers = [ - p - for p in papers - if q in p.get("title", "").lower() - or q in p.get("paper", {}).get("title", "").lower() - or q in p.get("paper", {}).get("summary", "").lower() - or any( - q in kw.lower() for kw in (p.get("paper", {}).get("ai_keywords") or []) - ) - ] - - papers = papers[:limit] - if not papers: - msg = "No trending papers found" - if query: - msg += f" matching '{query}'" - if date: - msg += f" for {date}" - return {"formatted": msg, "totalResults": 0, "resultsShared": 0} - - formatted = _format_paper_list(papers, "Trending Papers", date=date, query=query) - return { - "formatted": formatted, - "totalResults": len(papers), - "resultsShared": len(papers), - } - - -def _format_s2_paper_list(papers: list[dict], title: str) -> str: - """Format a list of S2 paper results.""" - lines = [f"# {title}"] - lines.append(f"Showing {len(papers)} result(s)\n") - - for i, paper in enumerate(papers, 1): - ptitle = paper.get("title") or "(untitled)" - year = paper.get("year") or "?" - cites = paper.get("citationCount", 0) - venue = paper.get("venue") or "" - ext_ids = paper.get("externalIds") or {} - aid = ext_ids.get("ArXiv", "") - tldr = (paper.get("tldr") or {}).get("text", "") - - lines.append(f"### {i}. {ptitle}") - meta = [f"Year: {year}", f"Citations: {cites}"] - if venue: - meta.append(f"Venue: {venue}") - if aid: - meta.append(f"arxiv_id: {aid}") - lines.append(" | ".join(meta)) - if aid: - lines.append(f"https://arxiv.org/abs/{aid}") - if tldr: - lines.append(f"**TL;DR:** {tldr}") - lines.append("") - - lines.append( - "Use paper_details with arxiv_id for full info, or read_paper to read sections." - ) - return "\n".join(lines) - - -async def _s2_bulk_search( - query: str, args: dict[str, Any], limit: int -) -> ToolResult | None: - """Search via S2 bulk endpoint with filters. Returns None on failure.""" - params: dict[str, Any] = { - "query": query, - "limit": limit, - "fields": "title,externalIds,year,citationCount,tldr,venue,publicationDate", - } - - # Date filter - date_from = args.get("date_from", "") - date_to = args.get("date_to", "") - if date_from or date_to: - params["publicationDateOrYear"] = f"{date_from}:{date_to}" - - # Fields of study - categories = args.get("categories") - if categories: - params["fieldsOfStudy"] = categories - - # Min citations - min_cites = args.get("min_citations") - if min_cites: - params["minCitationCount"] = str(min_cites) - - # Sort - sort_by = args.get("sort_by") - if sort_by and sort_by != "relevance": - params["sort"] = f"{sort_by}:desc" - - async with httpx.AsyncClient(timeout=15) as client: - resp = await _s2_request( - client, "GET", "/graph/v1/paper/search/bulk", params=params - ) - if not resp or resp.status_code != 200: - return None - data = resp.json() - - papers = data.get("data") or [] - if not papers: - return { - "formatted": f"No papers found for '{query}' with the given filters.", - "totalResults": 0, - "resultsShared": 0, - } - - formatted = _format_s2_paper_list( - papers[:limit], f"Papers matching '{query}' (Semantic Scholar)" - ) - return { - "formatted": formatted, - "totalResults": data.get("total", len(papers)), - "resultsShared": min(limit, len(papers)), - } - - -async def _op_search(args: dict[str, Any], limit: int) -> ToolResult: - query = args.get("query") - if not query: - return _error("'query' is required for search operation.") - - # Route to S2 when filters are present - use_s2 = any( - args.get(k) - for k in ("date_from", "date_to", "categories", "min_citations", "sort_by") - ) - if use_s2: - result = await _s2_bulk_search(query, args, limit) - if result is not None: - return result - # Fall back to HF search (without filters) if S2 fails - - async with httpx.AsyncClient(timeout=15) as client: - resp = await client.get( - f"{HF_API}/papers/search", params={"q": query, "limit": limit} - ) - resp.raise_for_status() - papers = resp.json() - - if not papers: - return { - "formatted": f"No papers found for '{query}'", - "totalResults": 0, - "resultsShared": 0, - } - - formatted = _format_paper_list(papers, f"Papers matching '{query}'") - return { - "formatted": formatted, - "totalResults": len(papers), - "resultsShared": len(papers), - } - - -async def _op_paper_details(args: dict[str, Any], limit: int) -> ToolResult: - arxiv_id = _validate_arxiv_id(args) - if not arxiv_id: - return _error("'arxiv_id' is required for paper_details.") - - async with httpx.AsyncClient(timeout=15) as client: - resp = await client.get(f"{HF_API}/papers/{arxiv_id}") - resp.raise_for_status() - paper = resp.json() - - return { - "formatted": _format_paper_detail(paper), - "totalResults": 1, - "resultsShared": 1, - } - - -async def _op_read_paper(args: dict[str, Any], limit: int) -> ToolResult: - arxiv_id = _validate_arxiv_id(args) - if not arxiv_id: - return _error("'arxiv_id' is required for read_paper.") - - section_query = args.get("section") - - # Try fetching HTML from arxiv, then ar5iv, then fallback to abstract - parsed = None - async with httpx.AsyncClient(timeout=30, follow_redirects=True) as client: - for base_url in [ARXIV_HTML, AR5IV_HTML]: - try: - resp = await client.get(f"{base_url}/{arxiv_id}") - if resp.status_code == 200: - parsed = _parse_paper_html(resp.text) - if parsed["sections"]: # Only use if we got real sections - break - parsed = None - except httpx.RequestError: - continue - - # Fallback: return abstract from HF API - if not parsed or not parsed["sections"]: - try: - async with httpx.AsyncClient(timeout=15) as client: - resp = await client.get(f"{HF_API}/papers/{arxiv_id}") - resp.raise_for_status() - paper = resp.json() - abstract = paper.get("summary", "") - title = paper.get("title", "") - msg = f"# {title}\nhttps://arxiv.org/abs/{arxiv_id}\n\n" - msg += f"## Abstract\n{abstract}\n\n" - msg += "HTML version not available for this paper. Only abstract shown.\n" - msg += f"PDF: https://arxiv.org/pdf/{arxiv_id}" - return {"formatted": msg, "totalResults": 1, "resultsShared": 1} - except Exception: - return _error( - f"Could not fetch paper {arxiv_id}. Check the arxiv ID is correct." - ) - - # Return TOC or specific section - if not section_query: - formatted = _format_read_paper_toc(parsed, arxiv_id) - return { - "formatted": formatted, - "totalResults": len(parsed["sections"]), - "resultsShared": len(parsed["sections"]), - } - - section = _find_section(parsed["sections"], section_query) - if not section: - available = "\n".join(f"- {s['title']}" for s in parsed["sections"]) - return _error( - f"Section '{section_query}' not found. Available sections:\n{available}" - ) - - formatted = _format_read_paper_section(section, arxiv_id) - return {"formatted": formatted, "totalResults": 1, "resultsShared": 1} - - -# --------------------------------------------------------------------------- -# Citation graph (Semantic Scholar) -# --------------------------------------------------------------------------- - - -def _format_citation_entry(entry: dict, show_context: bool = False) -> str: - """Format a single citation/reference entry.""" - paper = entry.get("citingPaper") or entry.get("citedPaper") or {} - title = paper.get("title") or "(untitled)" - year = paper.get("year") or "?" - cites = paper.get("citationCount", 0) - ext_ids = paper.get("externalIds") or {} - aid = ext_ids.get("ArXiv", "") - influential = " **[influential]**" if entry.get("isInfluential") else "" - - parts = [f"- **{title}** ({year}, {cites} cites){influential}"] - if aid: - parts[0] += f" arxiv:{aid}" - - if show_context: - intents = entry.get("intents") or [] - if intents: - parts.append(f" Intent: {', '.join(intents)}") - contexts = entry.get("contexts") or [] - for ctx in contexts[:2]: - if ctx: - parts.append(f" > {_truncate(ctx, 200)}") - - return "\n".join(parts) - - -def _format_citation_graph( - arxiv_id: str, - references: list[dict] | None, - citations: list[dict] | None, -) -> str: - lines = [f"# Citation Graph for {arxiv_id}"] - lines.append(f"https://arxiv.org/abs/{arxiv_id}\n") - - if references is not None: - lines.append(f"## References ({len(references)})") - if references: - for entry in references: - lines.append(_format_citation_entry(entry)) - else: - lines.append("No references found.") - lines.append("") - - if citations is not None: - lines.append(f"## Citations ({len(citations)})") - if citations: - for entry in citations: - lines.append(_format_citation_entry(entry, show_context=True)) - else: - lines.append("No citations found.") - lines.append("") - - lines.append( - "**Tip:** Use paper_details with an arxiv_id from above to explore further." - ) - return "\n".join(lines) - - -async def _op_citation_graph(args: dict[str, Any], limit: int) -> ToolResult: - arxiv_id = _validate_arxiv_id(args) - if not arxiv_id: - return _error("'arxiv_id' is required for citation_graph.") - - direction = args.get("direction", "both") - s2_id = _s2_paper_id(arxiv_id) - fields = "title,externalIds,year,citationCount,influentialCitationCount,contexts,intents,isInfluential" - params = {"fields": fields, "limit": limit} - - async with httpx.AsyncClient(timeout=15) as client: - refs, cites = None, None - coros = [] - if direction in ("references", "both"): - coros.append( - _s2_get_json(client, f"/graph/v1/paper/{s2_id}/references", params) - ) - if direction in ("citations", "both"): - coros.append( - _s2_get_json(client, f"/graph/v1/paper/{s2_id}/citations", params) - ) - - results = await asyncio.gather(*coros, return_exceptions=True) - idx = 0 - if direction in ("references", "both"): - r = results[idx] - if isinstance(r, dict): - refs = r.get("data", []) - idx += 1 - if direction in ("citations", "both"): - r = results[idx] - if isinstance(r, dict): - cites = r.get("data", []) - - if refs is None and cites is None: - return _error( - f"Could not fetch citation data for {arxiv_id}. Paper may not be indexed by Semantic Scholar." - ) - - total = (len(refs) if refs else 0) + (len(cites) if cites else 0) - return { - "formatted": _format_citation_graph(arxiv_id, refs, cites), - "totalResults": total, - "resultsShared": total, - } - - -async def _op_find_datasets(args: dict[str, Any], limit: int) -> ToolResult: - arxiv_id = _validate_arxiv_id(args) - if not arxiv_id: - return _error("'arxiv_id' is required for find_datasets.") - - sort = args.get("sort", "downloads") - sort_key = SORT_MAP.get(sort, "downloads") - - async with httpx.AsyncClient(timeout=15) as client: - resp = await client.get( - f"{HF_API}/datasets", - params={ - "filter": f"arxiv:{arxiv_id}", - "limit": limit, - "sort": sort_key, - "direction": -1, - }, - ) - resp.raise_for_status() - datasets = resp.json() - - if not datasets: - return { - "formatted": f"No datasets found linked to paper {arxiv_id}.\nhttps://huggingface.co/papers/{arxiv_id}", - "totalResults": 0, - "resultsShared": 0, - } - - return { - "formatted": _format_datasets(datasets, arxiv_id, sort), - "totalResults": len(datasets), - "resultsShared": len(datasets), - } - - -async def _op_find_models(args: dict[str, Any], limit: int) -> ToolResult: - arxiv_id = _validate_arxiv_id(args) - if not arxiv_id: - return _error("'arxiv_id' is required for find_models.") - - sort = args.get("sort", "downloads") - sort_key = SORT_MAP.get(sort, "downloads") - - async with httpx.AsyncClient(timeout=15) as client: - resp = await client.get( - f"{HF_API}/models", - params={ - "filter": f"arxiv:{arxiv_id}", - "limit": limit, - "sort": sort_key, - "direction": -1, - }, - ) - resp.raise_for_status() - models = resp.json() - - if not models: - return { - "formatted": f"No models found linked to paper {arxiv_id}.\nhttps://huggingface.co/papers/{arxiv_id}", - "totalResults": 0, - "resultsShared": 0, - } - - return { - "formatted": _format_models(models, arxiv_id, sort), - "totalResults": len(models), - "resultsShared": len(models), - } - - -async def _op_find_collections(args: dict[str, Any], limit: int) -> ToolResult: - arxiv_id = _validate_arxiv_id(args) - if not arxiv_id: - return _error("'arxiv_id' is required for find_collections.") - - async with httpx.AsyncClient(timeout=15) as client: - resp = await client.get(f"{HF_API}/collections", params={"paper": arxiv_id}) - resp.raise_for_status() - collections = resp.json() - - if not collections: - return { - "formatted": f"No collections found containing paper {arxiv_id}.\nhttps://huggingface.co/papers/{arxiv_id}", - "totalResults": 0, - "resultsShared": 0, - } - - collections = collections[:limit] - return { - "formatted": _format_collections(collections, arxiv_id), - "totalResults": len(collections), - "resultsShared": len(collections), - } - - -async def _op_find_all_resources(args: dict[str, Any], limit: int) -> ToolResult: - arxiv_id = _validate_arxiv_id(args) - if not arxiv_id: - return _error("'arxiv_id' is required for find_all_resources.") - - per_cat = min(limit, 10) - - async with httpx.AsyncClient(timeout=15) as client: - results = await asyncio.gather( - client.get( - f"{HF_API}/datasets", - params={ - "filter": f"arxiv:{arxiv_id}", - "limit": per_cat, - "sort": "downloads", - "direction": -1, - }, - ), - client.get( - f"{HF_API}/models", - params={ - "filter": f"arxiv:{arxiv_id}", - "limit": per_cat, - "sort": "downloads", - "direction": -1, - }, - ), - client.get(f"{HF_API}/collections", params={"paper": arxiv_id}), - return_exceptions=True, - ) - - sections = [] - total = 0 - - # Datasets - if isinstance(results[0], Exception): - sections.append(f"## Datasets\nError: {results[0]}") - else: - datasets = results[0].json() - total += len(datasets) - sections.append(_format_datasets_compact(datasets[:per_cat])) - - # Models - if isinstance(results[1], Exception): - sections.append(f"## Models\nError: {results[1]}") - else: - models = results[1].json() - total += len(models) - sections.append(_format_models_compact(models[:per_cat])) - - # Collections - if isinstance(results[2], Exception): - sections.append(f"## Collections\nError: {results[2]}") - else: - collections = results[2].json() - total += len(collections) - sections.append(_format_collections_compact(collections[:per_cat])) - - header = f"# Resources linked to paper {arxiv_id}\nhttps://huggingface.co/papers/{arxiv_id}\n" - formatted = header + "\n\n".join(sections) - return {"formatted": formatted, "totalResults": total, "resultsShared": total} - - -# --------------------------------------------------------------------------- -# Snippet search (Semantic Scholar) -# --------------------------------------------------------------------------- - - -def _format_snippets(snippets: list[dict], query: str) -> str: - lines = [f"# Snippet Search: '{query}'"] - lines.append(f"Found {len(snippets)} matching passage(s)\n") - - for i, item in enumerate(snippets, 1): - paper = item.get("paper") or {} - ptitle = paper.get("title") or "(untitled)" - year = paper.get("year") or "?" - cites = paper.get("citationCount", 0) - ext_ids = paper.get("externalIds") or {} - aid = ext_ids.get("ArXiv", "") - - snippet = item.get("snippet") or {} - text = snippet.get("text", "") - section = snippet.get("section") or "" - - lines.append(f"### {i}. {ptitle} ({year}, {cites} cites)") - if aid: - lines.append(f"arxiv:{aid}") - if section: - lines.append(f"Section: {section}") - if text: - lines.append(f"> {_truncate(text, 400)}") - lines.append("") - - lines.append( - "Use paper_details or read_paper with arxiv_id to explore a paper further." - ) - return "\n".join(lines) - - -async def _op_snippet_search(args: dict[str, Any], limit: int) -> ToolResult: - query = args.get("query") - if not query: - return _error("'query' is required for snippet_search.") - - params: dict[str, Any] = { - "query": query, - "limit": limit, - "fields": "title,externalIds,year,citationCount", - } - - # Optional filters (same as search) - date_from = args.get("date_from", "") - date_to = args.get("date_to", "") - if date_from or date_to: - params["publicationDateOrYear"] = f"{date_from}:{date_to}" - if args.get("categories"): - params["fieldsOfStudy"] = args["categories"] - if args.get("min_citations"): - params["minCitationCount"] = str(args["min_citations"]) - - async with httpx.AsyncClient(timeout=15) as client: - resp = await _s2_request( - client, "GET", "/graph/v1/snippet/search", params=params - ) - if not resp or resp.status_code != 200: - return _error("Snippet search failed. Semantic Scholar may be unavailable.") - data = resp.json() - - snippets = data.get("data") or [] - if not snippets: - return { - "formatted": f"No snippets found for '{query}'.", - "totalResults": 0, - "resultsShared": 0, - } - - return { - "formatted": _format_snippets(snippets, query), - "totalResults": len(snippets), - "resultsShared": len(snippets), - } - - -# --------------------------------------------------------------------------- -# Recommendations (Semantic Scholar) -# --------------------------------------------------------------------------- - - -async def _op_recommend(args: dict[str, Any], limit: int) -> ToolResult: - positive_ids = args.get("positive_ids") - arxiv_id = _validate_arxiv_id(args) - - if not arxiv_id and not positive_ids: - return _error("'arxiv_id' or 'positive_ids' is required for recommend.") - - fields = "title,externalIds,year,citationCount,tldr,venue" - - async with httpx.AsyncClient(timeout=15) as client: - if positive_ids and not arxiv_id: - # Multi-paper recommendations (POST, not cached) - pos = [ - _s2_paper_id(pid.strip()) - for pid in positive_ids.split(",") - if pid.strip() - ] - neg_raw = args.get("negative_ids", "") - neg = ( - [_s2_paper_id(pid.strip()) for pid in neg_raw.split(",") if pid.strip()] - if neg_raw - else [] - ) - resp = await _s2_request( - client, - "POST", - "/recommendations/v1/papers/", - json={"positivePaperIds": pos, "negativePaperIds": neg}, - params={"fields": fields, "limit": limit}, - ) - if not resp or resp.status_code != 200: - return _error( - "Recommendation request failed. Semantic Scholar may be unavailable." - ) - data = resp.json() - else: - # Single-paper recommendations (cached) - data = await _s2_get_json( - client, - f"/recommendations/v1/papers/forpaper/{_s2_paper_id(arxiv_id)}", - {"fields": fields, "limit": limit, "from": "recent"}, - ) - if not data: - return _error( - "Recommendation request failed. Semantic Scholar may be unavailable." - ) - - papers = data.get("recommendedPapers") or [] - if not papers: - return { - "formatted": "No recommendations found.", - "totalResults": 0, - "resultsShared": 0, - } - - title = f"Recommended papers based on {arxiv_id or positive_ids}" - return { - "formatted": _format_s2_paper_list(papers[:limit], title), - "totalResults": len(papers), - "resultsShared": min(limit, len(papers)), - } - - -# --------------------------------------------------------------------------- -# Operation dispatch -# --------------------------------------------------------------------------- - -_OPERATIONS = { - "trending": _op_trending, - "search": _op_search, - "paper_details": _op_paper_details, - "read_paper": _op_read_paper, - "citation_graph": _op_citation_graph, - "snippet_search": _op_snippet_search, - "recommend": _op_recommend, - "find_datasets": _op_find_datasets, - "find_models": _op_find_models, - "find_collections": _op_find_collections, - "find_all_resources": _op_find_all_resources, -} - - -# --------------------------------------------------------------------------- -# Tool spec + handler -# --------------------------------------------------------------------------- - -HF_PAPERS_TOOL_SPEC = { - "name": "hf_papers", - "description": ( - "Discover ML research papers, analyze citations, search paper contents, and find linked resources.\n\n" - "Combines HuggingFace Hub, arXiv, and Semantic Scholar. Use for exploring research areas, " - "finding datasets for a task, tracing citation chains, or implementing a paper's approach.\n\n" - "Typical flows:\n" - " search β†’ read_paper β†’ find_all_resources β†’ hf_inspect_dataset\n" - " search β†’ paper_details β†’ citation_graph β†’ read_paper (trace influence)\n" - " snippet_search β†’ paper_details β†’ read_paper (find specific claims)\n\n" - "Operations:\n" - "- trending: Get trending daily papers, optionally filter by topic keyword\n" - "- search: Search papers. Uses HF by default (ML-tuned). Add date_from/min_citations/categories to use Semantic Scholar with filters\n" - "- paper_details: Metadata, abstract, AI summary, github link\n" - "- read_paper: Read paper contents β€” without section: abstract + TOC; with section: full text\n" - "- citation_graph: Get references and citations for a paper with influence flags and citation intents\n" - "- snippet_search: Semantic search over full-text passages from 12M+ papers\n" - "- recommend: Find similar papers (single paper or positive/negative examples)\n" - "- find_datasets: Find datasets linked to a paper\n" - "- find_models: Find models linked to a paper\n" - "- find_collections: Find collections that include a paper\n" - "- find_all_resources: Parallel fetch of datasets + models + collections for a paper" - ), - "parameters": { - "type": "object", - "properties": { - "operation": { - "type": "string", - "enum": list(_OPERATIONS.keys()), - "description": "Operation to execute.", - }, - "query": { - "type": "string", - "description": ( - "Search query. Required for: search, snippet_search. " - "Optional for: trending (filters by keyword). " - "Supports boolean syntax for Semantic Scholar: '\"exact phrase\" term1 | term2'." - ), - }, - "arxiv_id": { - "type": "string", - "description": ( - "ArXiv paper ID (e.g. '2305.18290'). " - "Required for: paper_details, read_paper, citation_graph, find_datasets, find_models, find_collections, find_all_resources. " - "Optional for: recommend (single-paper recs). Get IDs from search results first." - ), - }, - "section": { - "type": "string", - "description": ( - "Section name or number to read (e.g. '3', 'Experiments', '4.2'). " - "Optional for: read_paper. Without this, returns abstract + TOC." - ), - }, - "direction": { - "type": "string", - "enum": ["citations", "references", "both"], - "description": "Direction for citation_graph. Default: both.", - }, - "date": { - "type": "string", - "description": "Date in YYYY-MM-DD format. Optional for: trending (defaults to recent papers).", - }, - "date_from": { - "type": "string", - "description": "Start date (YYYY-MM-DD). Triggers Semantic Scholar search. For: search, snippet_search.", - }, - "date_to": { - "type": "string", - "description": "End date (YYYY-MM-DD). Triggers Semantic Scholar search. For: search, snippet_search.", - }, - "categories": { - "type": "string", - "description": "Field of study filter (e.g. 'Computer Science'). Triggers Semantic Scholar search.", - }, - "min_citations": { - "type": "integer", - "description": "Minimum citation count filter. Triggers Semantic Scholar search.", - }, - "sort_by": { - "type": "string", - "enum": ["relevance", "citationCount", "publicationDate"], - "description": "Sort order for Semantic Scholar search. Default: relevance.", - }, - "positive_ids": { - "type": "string", - "description": "Comma-separated arxiv IDs for multi-paper recommendations. For: recommend.", - }, - "negative_ids": { - "type": "string", - "description": "Comma-separated arxiv IDs as negative examples. For: recommend.", - }, - "sort": { - "type": "string", - "enum": ["downloads", "likes", "trending"], - "description": ( - "Sort order for find_datasets and find_models. Default: downloads." - ), - }, - "limit": { - "type": "integer", - "description": "Maximum results to return (default: 10, max: 50).", - }, - }, - "required": ["operation"], - }, -} - - -async def hf_papers_handler(arguments: dict[str, Any]) -> tuple[str, bool]: - """Handler for agent tool router.""" - operation = arguments.get("operation") - if not operation: - return "'operation' parameter is required.", False - - handler = _OPERATIONS.get(operation) - if not handler: - valid = ", ".join(_OPERATIONS.keys()) - return f"Unknown operation: '{operation}'. Valid: {valid}", False - - limit = min(arguments.get("limit", DEFAULT_LIMIT), MAX_LIMIT) - - try: - result = await handler(arguments, limit) - return result["formatted"], not result.get("isError", False) - except httpx.HTTPStatusError as e: - return f"API error: {e.response.status_code} β€” {e.response.text[:200]}", False - except httpx.RequestError as e: - return f"Request error: {e}", False - except Exception as e: - return f"Error in {operation}: {e}", False diff --git a/agent/tools/plan_tool.py b/agent/tools/plan_tool.py index a923d53c27068fe81d5fe5dd1e774255c4339601..25ba5f87201ff45d874b94abc8975857f10b40d1 100644 --- a/agent/tools/plan_tool.py +++ b/agent/tools/plan_tool.py @@ -85,11 +85,18 @@ def get_current_plan() -> List[Dict[str, str]]: PLAN_TOOL_SPEC = { "name": "plan_tool", "description": ( - "Track progress on multi-step tasks with a todo list (pending/in_progress/completed).\n\n" - "Use for tasks with 3+ steps. Each call replaces the entire plan (send full list).\n\n" - "Rules: exactly ONE task in_progress at a time. Mark completed immediately after finishing. " - "Only mark completed when the task fully succeeded β€” keep in_progress if there are errors. " - "Update frequently so the user sees progress." + "Manage task planning and progress tracking with todo list (pending/in_progress/completed statuses). " + "⚠️ CRITICAL: ALWAYS use for multi-step tasks (3+ steps) and MUST update frequently to show progress. " + "**Use when:** (1) User provides multiple tasks, (2) Complex workflows (training, evaluation, data processing), " + "(3) Tasks requiring multiple tool calls, (4) Need to communicate progress clearly to user, " + "(5) Breaking down ambiguous requests into concrete steps. " + "**Pattern:** Create plan at start β†’ Mark in_progress when starting task β†’ Mark completed immediately after finishing β†’ User sees clear progress. " + "Each call replaces entire plan (full list required). " + "**Critical for reliability:** Exactly ONE task in_progress at a time (not zero, not multiple). " + "Mark tasks completed IMMEDIATELY after finishing - don't batch completions. " + "**For long-running tasks:** Update plan after each major step to keep user informed. " + "**Only mark completed when:** Task fully accomplished, no errors, all requirements met. " + "Keep tasks pending if blocked/errors occur - create new task to resolve blockers." ), "parameters": { "type": "object", diff --git a/agent/tools/research_tool.py b/agent/tools/research_tool.py deleted file mode 100644 index f5815be8332ef371d3e863652bfc6cdd5127bbc2..0000000000000000000000000000000000000000 --- a/agent/tools/research_tool.py +++ /dev/null @@ -1,543 +0,0 @@ -""" -Research subagent tool β€” spawns a cheap LLM call with a focused -research task and returns a summary. The subagent gets its own -independent context (not the main conversation), so research -work doesn't pollute the main agent's context window. - -Inspired by claude-code's code-explorer agent pattern. -""" - -import json -import logging -import time -from typing import Any - -from litellm import Message, acompletion - -from agent.core import telemetry -from agent.core.doom_loop import check_for_doom_loop -from agent.core.llm_params import _resolve_llm_params -from agent.core.prompt_caching import with_prompt_caching -from agent.core.session import Event - -logger = logging.getLogger(__name__) - -# Context budget for the research subagent (tokens). -# When usage exceeds WARN threshold, the subagent is told to wrap up. -# At MAX, the loop is force-stopped and whatever content exists is returned. -_RESEARCH_CONTEXT_WARN = 170_000 # 85% of 200k -_RESEARCH_CONTEXT_MAX = 190_000 - -# Tools the research agent can use (read-only subset) -RESEARCH_TOOL_NAMES = { - "read", - "bash", - "explore_hf_docs", - "fetch_hf_docs", - "find_hf_api", - "hf_papers", - "github_find_examples", - "github_list_repos", - "github_read_file", - "web_search", - "hf_inspect_dataset", - "hf_repo_files", -} - -RESEARCH_SYSTEM_PROMPT = """\ -You are a research sub-agent for an ML engineering assistant. -Your primary job: mine the literature to find the best training recipes β€” -then back them up with working code and up to date documantation. The main agent will use -your findings to implement the actual solution. - -# Start from the literature - -Your default approach is a deep literature crawl. Do not start from docs or -example scripts β€” start from papers. Papers contain the results, and results -tell you what actually works. - -## The crawl - -1. **Find anchor papers**: Search for the task/domain. Identify the landmark paper(s) β€” high citations, recent, or both. -2. **Crawl the citation graph**: Use `citation_graph` on the anchor paper(s). Look DOWNSTREAM (papers that cite it) β€” these are the ones that built on it, improved it, or applied it to new domains. Prioritize recent papers and papers with many citations. -3. **Read methodology sections**: For the most promising papers (strong results, recent, relevant), use `read_paper` with section parameter to read sections 3, 4, 5 (Methodology, Experiments, Results β€” not the abstract). Extract: - - The exact dataset(s) used (name, source, size, any filtering/preprocessing) - - The training method and configuration (optimizer, lr, schedule, epochs, batch size) - - The results those choices produced (benchmark scores, metrics, comparisons) -4. **Attribute results to recipes**: This is the critical step. Every finding must link a RESULT to the RECIPE that produced it. "Dataset X + method Y + lr Z β†’ score W on benchmark V" is useful. "They used SFT" is not. -5. **Validate datasets**: For the most promising datasets, check if they exist on HF Hub with `hf_inspect_dataset`. Verify format matches the training method. Report if doesnt. -6. **Find code**: Now find working implementation code via `github_find_examples` and `github_read_file`. Use docs (`explore_hf_docs`, `fetch_hf_docs`) to fill in API details. - -## When to go deeper - -- If the anchor paper is old (>1 year), its citation graph is your main source β€” the downstream papers will have better methods. -- If a downstream paper reports significantly better results, crawl ITS citation graph too. -- Use `snippet_search` to find specific claims across papers (e.g., "does dataset X consistently outperform Y for this task?"). -- Use `recommend` to find related papers the citation graph might miss. - -# How to use your tools - -## Papers & citations (USE FIRST) -- `hf_papers(operation="search", query=...)`: Search papers (HF-tuned for ML) -- `hf_papers(operation="search", query=..., min_citations=50, sort_by="citationCount")`: Find highly-cited papers via Semantic Scholar -- `hf_papers(operation="search", query=..., date_from="2024-01-01")`: Search with date filter -- `hf_papers(operation="paper_details", arxiv_id=...)`: Metadata, citations, TL;DR -- `hf_papers(operation="citation_graph", arxiv_id=...)`: References + citations with influence flags and intents -- `hf_papers(operation="read_paper", arxiv_id=..., section="3")`: Read a specific section's full text -- `hf_papers(operation="read_paper", arxiv_id=...)`: Get TOC (abstract + section list) β€” use this to find which section numbers contain methodology/experiments -- `hf_papers(operation="snippet_search", query=...)`: Semantic search across 12M+ full-text paper passages -- `hf_papers(operation="recommend", arxiv_id=...)`: Find related papers -- `hf_papers(operation="find_datasets", arxiv_id=...)`: Find HF datasets linked to a paper -- `hf_papers(operation="find_all_resources", arxiv_id=...)`: Datasets + models + collections for a paper - -## Dataset inspection -- `hf_inspect_dataset`: Check dataset schema, splits, sample rows - CRITICAL for training: verify column format matches training method: - - SFT: needs "messages", "text", or "prompt"/"completion" - - DPO: needs "prompt", "chosen", "rejected" - - GRPO: needs "prompt" only - -## GitHub code research -- `github_find_examples`: Find working example scripts in HF repos (trl, transformers, etc.) -- `github_read_file`: Read the actual implementation code. Use line_start/line_end for large files. - -## Documentation -- `explore_hf_docs(endpoint)`: Search docs for a library. Endpoints: trl, transformers, datasets, peft, accelerate, trackio, vllm, inference-endpoints, etc. -- `fetch_hf_docs(url)`: Fetch full page content from explore results -- `find_hf_api(query=..., tag=...)`: Find REST API endpoints -- `web_search(query=..., allowed_domains=[...], blocked_domains=[...])`: - Search the current web when papers/docs/GitHub are not enough. - -## Hub repo inspection -- `hf_repo_files`: List/read files in any HF repo (model, dataset, space) - -# Correct research pattern - -``` -# 1. Find anchor paper(s) for the task -hf_papers({"operation": "search", "query": "GPQA graduate questions", "sort_by": "citationCount"}) - -# 2. Crawl citation graph β€” look downstream -hf_papers({"operation": "citation_graph", "arxiv_id": "2311.12022", "direction": "citations"}) - -# 3. Read methodology of promising downstream papers -hf_papers({"operation": "read_paper", "arxiv_id": "2604.01348"}) # TOC first -hf_papers({"operation": "read_paper", "arxiv_id": "2604.01348", "section": "3"}) # Methodology -hf_papers({"operation": "read_paper", "arxiv_id": "2604.01348", "section": "4"}) # Experiments - -# 4. Find datasets used by these papers -hf_papers({"operation": "find_datasets", "arxiv_id": "2604.01348"}) -hf_papers({"operation": "find_all_resources", "arxiv_id": "2604.01348"}) - -# 5. Validate datasets exist and have correct format -hf_inspect_dataset({"dataset": "org/dataset-name", "split": "train", "sample_rows": 3}) - -# 6. Now get working code for the training method -github_find_examples({"repo": "trl", "keyword": "sft"}) -github_read_file({"repo": "huggingface/trl", "path": "examples/scripts/sft.py"}) -explore_hf_docs("trl") -``` - -# Output format - - - -Your output MUST be structured as a ranked list of training recipes, each attributed to published results: - -## Recipe table (REQUIRED) -For each promising approach found, report: -- **Paper**: title, arxiv_id, date, venue -- **Result**: exact benchmark scores and what they were measured on -- **Dataset(s)**: name, size, source, HF Hub availability, format verified (yes/no) -- **Method**: training approach, key hyperparameters (lr, epochs, batch size, optimizer, schedule) -- **What made it work**: the specific insight or trick that drove the result (data curation, curriculum, loss function, etc.) - -Rank recipes by result quality. The main agent will pick the best one that's feasible. - -## Code patterns -- Key imports, configurations, and usage patterns from working examples -- Specific file paths, URLs, function names from docs - -## Recommendations -- Which recipe to implement first and why -- What datasets to use (with HF Hub paths, verified) -- Any gaps: datasets that need preprocessing, methods that need adaptation - -Additionally include: -- **SOTA landscape**: Current best models, datasets, and methods for the task (from recent papers). Flag anything outdated. -- **Essential references**: Specific file paths, URLs, function names, doc sections, code snippets - that the main agent should use directly -- **Code patterns**: Key imports, configurations, and usage patterns from working examples - -Be concise. Your output goes into another agent's context β€” every token counts. -Aim for 500-1500 words max. Include actual code snippets from examples you read, -not paraphrased descriptions. -""" - -RESEARCH_TOOL_SPEC = { - "name": "research", - "description": ( - "Spawn a research sub-agent to explore documentation, codebases, " - "or repos WITHOUT polluting the main conversation context. " - "The sub-agent gets its own independent context window with read-only " - "research tools and returns a concise summary of findings.\n\n" - "Use this for:\n" - "- Researching current API usage before implementing ML tasks " - "(find examples + read docs)\n" - "- Exploring HF docs, reading papers, analyzing GitHub repos\n" - "- Any research where raw tool outputs would be too verbose\n\n" - "The sub-agent knows how to use github_find_examples, github_read_file, " - "explore_hf_docs, fetch_hf_docs, hf_inspect_dataset, hf_papers, etc. " - "Just describe what you need researched." - ), - "parameters": { - "type": "object", - "properties": { - "task": { - "type": "string", - "description": ( - "Detailed description of what to research. Be specific: " - "include library names, trainer types, dataset names, " - "repo names, or doc pages to explore. Example: " - "'Research current TRL SFTTrainer usage: find working " - "example scripts, read the SFT documentation, and check " - "SFTConfig parameters. Also validate that dataset " - "HuggingFaceH4/ultrachat_200k has the right format for SFT.'" - ), - }, - "context": { - "type": "string", - "description": ( - "Optional context from the current conversation that the " - "research agent needs (e.g., what the user wants to build, " - "constraints, what's been tried)." - ), - }, - }, - "required": ["task"], - }, -} - - -def _get_research_model(main_model: str) -> str: - """Pick a cheaper model for research based on the main model.""" - if main_model.startswith("anthropic/"): - return "anthropic/claude-sonnet-4-6" - if main_model.startswith("bedrock/") and "anthropic" in main_model: - return "bedrock/us.anthropic.claude-sonnet-4-6" - # For non-Anthropic models (HF router etc.), use the same model - return main_model - - -async def research_handler( - arguments: dict[str, Any], session=None, tool_call_id: str | None = None, **_kw -) -> tuple[str, bool]: - """Execute a research sub-agent with its own context.""" - task = arguments.get("task", "") - context = arguments.get("context", "") - if not task: - return "No research task provided.", False - - if not session: - return "No session available for research agent.", False - - # Build the sub-agent's messages (independent context) - messages: list[Message] = [ - Message(role="system", content=RESEARCH_SYSTEM_PROMPT), - ] - - user_content = f"Research task: {task}" - if context: - user_content = f"Context: {context}\n\n{user_content}" - messages.append(Message(role="user", content=user_content)) - - # Use a cheaper/faster model for research - main_model = session.config.model_name - research_model = _get_research_model(main_model) - # Research is a cheap sub-call β€” cap the main session's effort at "high" - # so a user preference of ``max``/``xhigh`` (valid for Opus 4.6/4.7) doesn't - # propagate to a Sonnet research model that may not accept those levels. - # We also haven't probed this sub-model so we don't know its ceiling. - _pref = getattr(session.config, "reasoning_effort", None) - _capped = "high" if _pref in ("max", "xhigh") else _pref - llm_params = _resolve_llm_params( - research_model, - getattr(session, "hf_token", None), - reasoning_effort=_capped, - ) - - # Get read-only tool specs from the session's tool router - tool_specs = [ - spec - for spec in session.tool_router.get_tool_specs_for_llm() - if spec["function"]["name"] in RESEARCH_TOOL_NAMES - ] - - # Unique ID + short label so parallel agents show separate status lines. - # Use the tool_call_id when available β€” it's unique per invocation and lets - # the frontend match a research tool card to its agent state. Fall back to - # uuid for offline/test paths. Previously used md5(task), which collided - # when the same task string was researched in parallel. - if tool_call_id: - _agent_id = tool_call_id - else: - import uuid - - _agent_id = uuid.uuid4().hex[:8] - _agent_label = "research: " + (task[:50] + "…" if len(task) > 50 else task) - - async def _log(text: str) -> None: - """Send a progress event to the UI so it doesn't look frozen.""" - try: - await session.send_event( - Event( - event_type="tool_log", - data={ - "tool": "research", - "log": text, - "agent_id": _agent_id, - "label": _agent_label, - }, - ) - ) - except Exception: - pass - - _tool_uses = 0 - _total_tokens = 0 - _warned_context = False - - await _log("Starting research sub-agent...") - - # Run the research loop β€” context budget is the real limiter - max_iterations = 60 - for _iteration in range(max_iterations): - # ── Doom-loop detection ── - doom_prompt = check_for_doom_loop(messages) - if doom_prompt: - logger.warning( - "Research sub-agent repetition guard activated at iteration %d", - _iteration, - ) - messages.append(Message(role="user", content=doom_prompt)) - - # ── Context budget: warn at 75%, hard-stop at 95% ── - if _total_tokens >= _RESEARCH_CONTEXT_MAX: - logger.warning( - "Research sub-agent hit context max (%d tokens) β€” forcing summary", - _total_tokens, - ) - await _log( - f"Context limit reached ({_total_tokens} tokens) β€” forcing wrap-up" - ) - # Ask for a final summary with no tools - messages.append( - Message( - role="user", - content=( - "[SYSTEM: CONTEXT LIMIT REACHED] You have used all available context. " - "Summarize your findings NOW. Do NOT call any more tools." - ), - ) - ) - try: - _msgs, _ = with_prompt_caching(messages, None, llm_params.get("model")) - _t0 = time.monotonic() - response = await acompletion( - messages=_msgs, - tools=None, # no tools β€” force text response - stream=False, - timeout=120, - **llm_params, - ) - # Telemetry is best-effort; a logging blip must never mask a - # valid LLM response (the surrounding except would convert it - # to "summary call failed"). - try: - await telemetry.record_llm_call( - session, - model=research_model, - response=response, - latency_ms=int((time.monotonic() - _t0) * 1000), - finish_reason=response.choices[0].finish_reason - if response.choices - else None, - kind="research", - ) - except Exception as _telem_err: - logger.debug("research telemetry failed: %s", _telem_err) - content = response.choices[0].message.content or "" - return ( - content or "Research context exhausted β€” no summary produced.", - bool(content), - ) - except Exception: - return "Research context exhausted and summary call failed.", False - - if not _warned_context and _total_tokens >= _RESEARCH_CONTEXT_WARN: - _warned_context = True - await _log(f"Context at {_total_tokens} tokens β€” nudging to wrap up") - messages.append( - Message( - role="user", - content=( - "[SYSTEM: You have used 75% of your context budget. " - "Start wrapping up: finish any critical lookups, then " - "produce your final summary within the next 1-2 iterations.]" - ), - ) - ) - - try: - _msgs, _tools = with_prompt_caching( - messages, tool_specs if tool_specs else None, llm_params.get("model") - ) - _t0 = time.monotonic() - response = await acompletion( - messages=_msgs, - tools=_tools, - tool_choice="auto", - stream=False, - timeout=120, - **llm_params, - ) - try: - await telemetry.record_llm_call( - session, - model=research_model, - response=response, - latency_ms=int((time.monotonic() - _t0) * 1000), - finish_reason=response.choices[0].finish_reason - if response.choices - else None, - kind="research", - ) - except Exception as _telem_err: - logger.debug("research telemetry failed: %s", _telem_err) - except Exception as e: - logger.error("Research sub-agent LLM error: %s", e) - return f"Research agent LLM error: {e}", False - - # Track tokens - if response.usage: - _total_tokens = response.usage.total_tokens - await _log(f"tokens:{_total_tokens}") - - choice = response.choices[0] - msg = choice.message - - # If no tool calls, we have our final answer - if not msg.tool_calls: - await _log("Research complete.") - content = msg.content or "Research completed but no summary generated." - return content, True - - # Execute tool calls and add results. - # Rebuild the assistant message with only the wire-safe fields β€” - # LiteLLM's raw Message carries `provider_specific_fields` and - # `reasoning_content`, which the HF router's OpenAI schema rejects - # if we echo them back in the next request. - messages.append( - Message( - role="assistant", - content=msg.content, - tool_calls=msg.tool_calls, - ) - ) - for tc in msg.tool_calls: - try: - tool_args = json.loads(tc.function.arguments) - except (json.JSONDecodeError, TypeError): - messages.append( - Message( - role="tool", - content="Invalid tool arguments.", - tool_call_id=tc.id, - name=tc.function.name, - ) - ) - continue - - tool_name = tc.function.name - if tool_name not in RESEARCH_TOOL_NAMES: - messages.append( - Message( - role="tool", - content=f"Tool '{tool_name}' not available for research.", - tool_call_id=tc.id, - name=tool_name, - ) - ) - continue - - try: - import json as _json - - args_str = _json.dumps(tool_args)[:80] - await _log(f"β–Έ {tool_name} {args_str}") - - output, _success = await session.tool_router.call_tool( - tool_name, tool_args, session=session, tool_call_id=tc.id - ) - _tool_uses += 1 - await _log(f"tools:{_tool_uses}") - # Truncate tool output for the research context - if len(output) > 8000: - output = output[:4800] + "\n...(truncated)...\n" + output[-3200:] - except Exception as e: - output = f"Tool error: {e}" - - messages.append( - Message( - role="tool", - content=output, - tool_call_id=tc.id, - name=tool_name, - ) - ) - - # ── Iteration limit: try to salvage findings ── - await _log("Iteration limit reached β€” extracting summary") - messages.append( - Message( - role="user", - content=( - "[SYSTEM: ITERATION LIMIT] You have reached the maximum number of research " - "iterations. Summarize ALL findings so far. Do NOT call any more tools." - ), - ) - ) - try: - _msgs, _ = with_prompt_caching(messages, None, llm_params.get("model")) - _t0 = time.monotonic() - response = await acompletion( - messages=_msgs, - tools=None, - stream=False, - timeout=120, - **llm_params, - ) - try: - await telemetry.record_llm_call( - session, - model=research_model, - response=response, - latency_ms=int((time.monotonic() - _t0) * 1000), - finish_reason=response.choices[0].finish_reason - if response.choices - else None, - kind="research", - ) - except Exception as _telem_err: - logger.debug("research telemetry failed: %s", _telem_err) - content = response.choices[0].message.content or "" - if content: - return content, True - except Exception as e: - logger.error("Research summary call failed: %s", e) - - return ( - "Research agent hit iteration limit (60). " - "Partial findings may be incomplete β€” try a more focused task.", - False, - ) diff --git a/agent/tools/sandbox_client.py b/agent/tools/sandbox_client.py deleted file mode 100644 index 1871d8fce3f1bfcdcb994064f5f83f954d84944c..0000000000000000000000000000000000000000 --- a/agent/tools/sandbox_client.py +++ /dev/null @@ -1,1160 +0,0 @@ -#!/usr/bin/env python3 -# /// script -# requires-python = ">=3.10" -# dependencies = ["huggingface_hub>=0.20.0", "httpx>=0.27.0"] -# /// -""" -Sandbox Tools β€” Agent-native primitives for HF Space dev-mode sandboxes. - -Architecture: - - Creates a sandbox by duplicating a template Space (runs sandbox_server.py) - - Waits for it to come online - - Communicates via HTTPS to the Space's API - - Optionally deletes the Space when done - -Lifecycle: - sb = Sandbox.create(owner="burtenshaw") # duplicate private Space, wait, connect - sb = Sandbox.create(owner="burtenshaw", # with options - hardware="t4-small", - private=True, - sleep_time=3600) - sb = Sandbox.connect("burtenshaw/my-sandbox-abc") # attach to existing - - sb.bash("uv run train.py") - sb.read("/app/train.py") - sb.edit("/app/train.py", old_str="lr=1e-3", new_str="lr=1e-4") - - sb.delete() # tear down when done - - # Or use as a context manager for automatic cleanup - with Sandbox.create(owner="burtenshaw") as sb: - sb.bash("python train.py") - # Space deleted on exit - -Tools: bash, read, write, edit, upload -""" - -from __future__ import annotations - -import io -import secrets as secrets_lib -import sys -import time -import uuid -from dataclasses import dataclass, field -from typing import Any, Callable - -import httpx -from huggingface_hub import CommitOperationAdd, HfApi - -TEMPLATE_SPACE = "burtenshaw/sandbox" -HARDWARE_OPTIONS = [ - "cpu-basic", - "cpu-upgrade", - "t4-small", - "t4-medium", - "a10g-small", - "a10g-large", - "a100-large", -] -OUTPUT_LIMIT = 25000 -LINE_LIMIT = 4000 -DEFAULT_READ_LIMIT = 2000 -DEFAULT_TIMEOUT = 240 -MAX_TIMEOUT = 1200 -WAIT_TIMEOUT = 600 -WAIT_INTERVAL = 5 -API_WAIT_TIMEOUT = 180 -CPU_BASIC_HARDWARE = "cpu-basic" - - -def _is_transient_space_visibility_error(error: Exception) -> bool: - """Return True when a newly duplicated Space is not queryable yet.""" - response = getattr(error, "response", None) - if getattr(response, "status_code", None) == 404: - return True - message = str(error) - return "Repository Not Found" in message or "404 Client Error" in message - - -_DOCKERFILE = """\ -FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim - -RUN apt-get update && \\ - apt-get install -y \\ - bash git git-lfs wget curl procps \\ - htop vim nano jq tmux \\ - build-essential && \\ - rm -rf /var/lib/apt/lists/* - -RUN uv pip install --system fastapi uvicorn python-multipart - -RUN useradd -m -u 1000 user -USER user - -ENV HOME=/home/user \\ - PATH=/home/user/.local/bin:$PATH \\ - PIP_USER=1 \\ - HF_HUB_DISABLE_PROGRESS_BARS=1 \\ - TQDM_DISABLE=1 \\ - HF_HUB_ENABLE_HF_TRANSFER=1 \\ - UV_NO_PROGRESS=1 \\ - PYTHONWARNINGS=ignore::DeprecationWarning - -WORKDIR /app -COPY --chown=user . /app - -EXPOSE 7860 - -CMD ["python", "sandbox_server.py"] -""" - -_SANDBOX_SERVER = '''\ -"""Minimal FastAPI server for sandbox operations.""" -import hmac, os, subprocess, pathlib, signal, threading, re, tempfile -from fastapi import Depends, FastAPI, HTTPException, Request -from pydantic import BaseModel -from typing import Optional -import uvicorn - -_ANSI_RE = re.compile(r'\\x1b\\[[0-9;]*[a-zA-Z]|\\x1b\\].*?\\x07') - -def _strip_ansi(text: str) -> str: - return _ANSI_RE.sub('', text) - -def _truncate_output(output: str, max_chars: int = 25000, head_ratio: float = 0.25) -> str: - if len(output) <= max_chars: - return output - # Write full output to temp file so LLM can read specific sections - spill_path = None - try: - with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', prefix='bash_output_', dir='/tmp', delete=False) as f: - f.write(output) - spill_path = f.name - except Exception: - pass - head_budget = int(max_chars * head_ratio) - tail_budget = max_chars - head_budget - head = output[:head_budget] - tail = output[-tail_budget:] - total = len(output) - omitted = total - max_chars - meta = f"\\n\\n... ({omitted:,} of {total:,} chars omitted, showing first {head_budget:,} + last {tail_budget:,}) ...\\n" - if spill_path: - meta += f"Full output saved to {spill_path} β€” use the read tool with offset/limit to inspect specific sections.\\n" - return head + meta + tail - -def _atomic_write(path: pathlib.Path, content: str): - """Write atomically: temp file + fsync + os.replace.""" - path.parent.mkdir(parents=True, exist_ok=True) - fd = None - tmp_path = None - try: - fd, tmp_path = tempfile.mkstemp(dir=str(path.parent), suffix=".tmp") - os.write(fd, content.encode("utf-8")) - os.fsync(fd) - os.close(fd) - fd = None - os.replace(tmp_path, str(path)) - tmp_path = None - finally: - if fd is not None: - os.close(fd) - if tmp_path is not None: - try: - os.unlink(tmp_path) - except OSError: - pass - -app = FastAPI() - -def _bearer_token(header: str) -> str: - scheme, _, supplied = header.partition(" ") - if scheme.lower() != "bearer" or not supplied: - return "" - return supplied - -def _require_auth(request: Request) -> None: - sandbox_token = os.environ.get("SANDBOX_API_TOKEN") or "" - if not sandbox_token: - raise HTTPException(status_code=503, detail="Sandbox API token not configured") - supplied = _bearer_token(request.headers.get("x-sandbox-authorization", "")) - if not supplied: - raise HTTPException(status_code=401, detail="Missing bearer token") - if not hmac.compare_digest(supplied, sandbox_token): - raise HTTPException(status_code=401, detail="Invalid bearer token") - -_AUTH = [Depends(_require_auth)] - -# Track active bash processes so they can be killed on cancel -_active_procs = {} # pid -> subprocess.Popen -_proc_lock = threading.Lock() - -class BashReq(BaseModel): - command: str - work_dir: str = "/app" - timeout: int = 120 - -class ReadReq(BaseModel): - path: str - offset: Optional[int] = None - limit: Optional[int] = 2000 - -class WriteReq(BaseModel): - path: str - content: str - -class EditReq(BaseModel): - path: str - old_str: str - new_str: str - replace_all: bool = False - mode: str = "replace" - -class ExistsReq(BaseModel): - path: str - -# ── Fuzzy matching & edit utilities (embedded) ── - -UNICODE_MAP = { - "\\u2013": "-", "\\u2014": "-", "\\u2212": "-", - "\\u2018": "'", "\\u2019": "'", - "\\u201c": \'"\', "\\u201d": \'"\', - "\\u00a0": " ", "\\u2003": " ", "\\u2002": " ", - "\\u200b": "", "\\ufeff": "", -} - -def _normalize_unicode(s): - return "".join(UNICODE_MAP.get(c, c) for c in s) - -def _fuzzy_find_original(content, pattern): - """Find the original text in content that matches pattern fuzzily.""" - if pattern in content: - return pattern, None - # Pass 2: right-trim - c_lines = content.split("\\n") - c_rt = "\\n".join(l.rstrip() for l in c_lines) - p_rt = "\\n".join(l.rstrip() for l in pattern.split("\\n")) - if p_rt in c_rt: - idx = c_rt.index(p_rt) - start_line = c_rt[:idx].count("\\n") - n_lines = p_rt.count("\\n") + 1 - matched = "\\n".join(c_lines[start_line:start_line + n_lines]) - return matched, "(matched after trimming trailing whitespace)" - # Pass 3: both-sides trim - c_st = "\\n".join(l.strip() for l in c_lines) - p_st = "\\n".join(l.strip() for l in pattern.split("\\n")) - if p_st in c_st: - idx = c_st.index(p_st) - start_line = c_st[:idx].count("\\n") - n_lines = p_st.count("\\n") + 1 - matched = "\\n".join(c_lines[start_line:start_line + n_lines]) - return matched, "(matched after trimming whitespace)" - # Pass 4: unicode normalization - c_norm = _normalize_unicode(c_st) - p_norm = _normalize_unicode(p_st) - if p_norm in c_norm: - idx = c_norm.index(p_norm) - start_line = c_norm[:idx].count("\\n") - n_lines = p_norm.count("\\n") + 1 - matched = "\\n".join(c_lines[start_line:start_line + n_lines]) - return matched, "(matched after unicode normalization)" - return None, None - -def _apply_edit(content, old_str, new_str, mode="replace", replace_all=False): - """Apply edit. Returns (new_content, count, fuzzy_note) or raises ValueError.""" - if mode == "replace_all": - replace_all = True - mode = "replace" - fuzzy_note = None - if old_str not in content: - matched, fuzzy_note = _fuzzy_find_original(content, old_str) - if matched is None: - raise ValueError("old_str not found in file.") - old_str = matched - count = content.count(old_str) - if mode == "replace": - if count > 1 and not replace_all: - raise ValueError(f"old_str appears {count} times. Use replace_all=true or provide more context.") - if replace_all: - return content.replace(old_str, new_str), count, fuzzy_note - return content.replace(old_str, new_str, 1), 1, fuzzy_note - elif mode == "append_after": - if replace_all: - return content.replace(old_str, old_str + new_str), count, fuzzy_note - idx = content.index(old_str) + len(old_str) - return content[:idx] + new_str + content[idx:], 1, fuzzy_note - elif mode == "prepend_before": - if replace_all: - return content.replace(old_str, new_str + old_str), count, fuzzy_note - idx = content.index(old_str) - return content[:idx] + new_str + content[idx:], 1, fuzzy_note - raise ValueError(f"Unknown mode: {mode}") - -def _validate_python(content, path=""): - """Validate Python: syntax, kwargs against real installed signatures, training heuristics. - - Runs inside the sandbox where packages are pip-installed, so we can actually - import classes and inspect their __init__ signatures to catch kwarg mismatches - before runtime. - """ - import ast as _ast, inspect as _inspect, importlib as _il - warnings = [] - - # 1. Syntax check - try: - tree = _ast.parse(content) - except SyntaxError as e: - warnings.append(f"Python syntax error at line {e.lineno}: {e.msg}") - return warnings - - # 2. Build import map: name -> module path (from the script's own imports) - import_map = {} - for node in _ast.walk(tree): - if isinstance(node, _ast.ImportFrom) and node.module: - for alias in (node.names or []): - local_name = alias.asname or alias.name - import_map[local_name] = (node.module, alias.name) - elif isinstance(node, _ast.Import): - for alias in (node.names or []): - local_name = alias.asname or alias.name - import_map[local_name] = (alias.name, None) - - # 3. For each Call node, resolve the callable and check kwargs against signature - for node in _ast.walk(tree): - if not isinstance(node, _ast.Call): - continue - # Skip calls with **kwargs unpacking β€” we can't statically know those keys - if any(kw.arg is None for kw in node.keywords): - continue - call_kwargs = [kw.arg for kw in node.keywords if kw.arg] - if not call_kwargs: - continue - - # Resolve the callable name - func_name = None - if isinstance(node.func, _ast.Name): - func_name = node.func.id - elif isinstance(node.func, _ast.Attribute): - func_name = node.func.attr - if not func_name or func_name not in import_map: - continue - - # Try to import and inspect the real callable - module_path, attr_name = import_map[func_name] - try: - mod = _il.import_module(module_path) - obj = getattr(mod, attr_name, None) if attr_name else mod - if obj is None: - continue - sig = _inspect.signature(obj) - params = sig.parameters - # If **kwargs is in the signature, any kwarg is valid - if any(p.kind == _inspect.Parameter.VAR_KEYWORD for p in params.values()): - continue - valid_names = set(params.keys()) - for kw_name in call_kwargs: - if kw_name not in valid_names: - warnings.append( - f"Invalid kwarg: {func_name}({kw_name}=...) at line {node.lineno} " - f"-- not accepted by {module_path}.{attr_name or func_name}()" - ) - except Exception: - pass # can't import/inspect β€” skip silently - - # 4. Training script heuristics - if any(kw in content for kw in ("TrainingArguments", "SFTConfig", "DPOConfig", "GRPOConfig")): - if "push_to_hub" not in content: - warnings.append("Training script warning: no \'push_to_hub\' found") - if "hub_model_id" not in content: - warnings.append("Training script warning: no \'hub_model_id\' found") - return warnings - -@app.get("/api/health") -def health(): - return {"status": "ok"} - -@app.post("/api/bash", dependencies=_AUTH) -def bash(req: BashReq): - try: - proc = subprocess.Popen( - req.command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - text=True, cwd=req.work_dir, start_new_session=True, - ) - with _proc_lock: - _active_procs[proc.pid] = proc - try: - stdout, stderr = proc.communicate(timeout=req.timeout) - output = _strip_ansi(stdout + stderr) - output = _truncate_output(output) - return {"success": proc.returncode == 0, "output": output, "error": "" if proc.returncode == 0 else f"Exit code {proc.returncode}"} - except subprocess.TimeoutExpired: - try: - os.killpg(os.getpgid(proc.pid), signal.SIGKILL) - except OSError: - proc.kill() - proc.wait() - return {"success": False, "output": "", "error": f"Timeout after {req.timeout}s"} - finally: - with _proc_lock: - _active_procs.pop(proc.pid, None) - except Exception as e: - return {"success": False, "output": "", "error": str(e)} - -@app.post("/api/kill", dependencies=_AUTH) -def kill_all(): - """Kill all active bash processes. Called when user cancels.""" - with _proc_lock: - pids = list(_active_procs.keys()) - killed = [] - for pid in pids: - try: - os.killpg(os.getpgid(pid), signal.SIGTERM) - killed.append(pid) - except OSError: - try: - os.kill(pid, signal.SIGKILL) - killed.append(pid) - except OSError: - pass - return {"success": True, "output": f"Killed {len(killed)} process(es): {killed}", "error": ""} - -@app.post("/api/read", dependencies=_AUTH) -def read(req: ReadReq): - try: - p = pathlib.Path(req.path) - if not p.exists(): - return {"success": False, "output": "", "error": f"File not found: {req.path}"} - if p.is_dir(): - return {"success": False, "output": "", "error": f"Is a directory: {req.path}"} - lines = p.read_text().splitlines() - start = (req.offset or 1) - 1 - end = start + (req.limit or len(lines)) - selected = lines[start:end] - numbered = "\\n".join(f"{start + i + 1}\\t{line}" for i, line in enumerate(selected)) - return {"success": True, "output": numbered, "error": ""} - except Exception as e: - return {"success": False, "output": "", "error": str(e)} - -@app.post("/api/write", dependencies=_AUTH) -def write(req: WriteReq): - try: - p = pathlib.Path(req.path) - _atomic_write(p, req.content) - msg = f"Wrote {len(req.content)} bytes to {req.path}" - if p.suffix == ".py": - warnings = _validate_python(req.content, req.path) - if warnings: - msg += "\\n\\nValidation warnings:\\n" + "\\n".join(f" ! {w}" for w in warnings) - return {"success": True, "output": msg, "error": ""} - except Exception as e: - return {"success": False, "output": "", "error": str(e)} - -@app.post("/api/edit", dependencies=_AUTH) -def edit(req: EditReq): - try: - p = pathlib.Path(req.path) - if not p.exists(): - return {"success": False, "output": "", "error": f"File not found: {req.path}"} - content = p.read_text() - if req.old_str == req.new_str: - return {"success": False, "output": "", "error": "old_str and new_str must differ."} - try: - new_content, count, fuzzy_note = _apply_edit( - content, req.old_str, req.new_str, mode=req.mode, replace_all=req.replace_all - ) - except ValueError as e: - return {"success": False, "output": "", "error": str(e)} - _atomic_write(p, new_content) - msg = f"Edited {req.path} ({count} replacement{'s' if count > 1 else ''})" - if fuzzy_note: - msg += f" {fuzzy_note}" - if p.suffix == ".py": - warnings = _validate_python(new_content, req.path) - if warnings: - msg += "\\n\\nValidation warnings:\\n" + "\\n".join(f" ! {w}" for w in warnings) - return {"success": True, "output": msg, "error": ""} - except Exception as e: - return {"success": False, "output": "", "error": str(e)} - -@app.post("/api/exists", dependencies=_AUTH) -def exists(req: ExistsReq): - return {"success": True, "output": str(pathlib.Path(req.path).exists()).lower(), "error": ""} - -if __name__ == "__main__": - uvicorn.run(app, host="0.0.0.0", port=7860) -''' - - -@dataclass -class ToolResult: - success: bool - output: str = "" - error: str = "" - - def __str__(self): - if self.success: - return self.output or "(no output)" - return f"ERROR: {self.error}" - - def to_dict(self) -> dict: - return {"success": self.success, "output": self.output, "error": self.error} - - -@dataclass -class Sandbox: - """ - A handle to an HF Space sandbox. - - Use Sandbox.create() to spin up a new one, or Sandbox.connect() to - attach to an existing running Space. - """ - - space_id: str - token: str | None = None - api_token: str | None = field(default=None, repr=False) - work_dir: str = "/app" - timeout: int = DEFAULT_TIMEOUT - _owns_space: bool = field(default=False, repr=False) - _base_url: str = field(init=False, repr=False) - _client: httpx.Client = field(init=False, repr=False) - _hf_api: HfApi = field(init=False, repr=False) - _files_read: set = field(init=False, repr=False, default_factory=set) - - def __post_init__(self): - slug = self.space_id.replace("/", "-") - # Trailing slash is critical: httpx resolves relative paths against base_url. - # Without it, client.get("health") resolves to /health instead of /api/health. - self._base_url = f"https://{slug}.hf.space/api/" - self._client = httpx.Client( - base_url=self._base_url, - headers=self._auth_headers(), - timeout=httpx.Timeout(MAX_TIMEOUT, connect=30), - follow_redirects=True, - ) - self._hf_api = HfApi(token=self.token) - - def _auth_headers(self) -> dict[str, str]: - """Return headers for private HF Space access plus sandbox API auth. - - Private Spaces require the HF token in ``Authorization`` at the Hub - edge. The sandbox server requires its control-plane token in the - dedicated ``X-Sandbox-Authorization`` header. - """ - headers: dict[str, str] = {} - if self.token: - headers["Authorization"] = f"Bearer {self.token}" - if self.api_token: - headers["X-Sandbox-Authorization"] = f"Bearer {self.api_token}" - return headers - - # ── Lifecycle ───────────────────────────────────────────────── - - class Cancelled(Exception): - """Raised when sandbox creation is cancelled by the user.""" - - @classmethod - def create( - cls, - owner: str, - *, - name: str | None = None, - template: str = TEMPLATE_SPACE, - hardware: str = CPU_BASIC_HARDWARE, - private: bool = True, - sleep_time: int | None = None, - token: str | None = None, - secrets: dict[str, str] | None = None, - wait_timeout: int = WAIT_TIMEOUT, - log: "Callable[[str], object] | None" = None, - cancel_event: "Any | None" = None, - ) -> Sandbox: - """ - Create a new sandbox by duplicating the template Space. - - Generates a unique space name, duplicates the template, waits for it - to come online, then returns a connected Sandbox. - - Args: - owner: HF username or org (e.g. "burtenshaw"). - name: Base name for the space. Defaults to "sandbox". - A unique suffix is always appended. - template: Source Space to duplicate (default: burtenshaw/sandbox). - hardware: Hardware tier (cpu-basic, t4-small, etc.). - private: Whether the Space should be private. Defaults to True. - sleep_time: Auto-sleep after N seconds of inactivity. - token: HF API token (from user's OAuth session). - wait_timeout: Max seconds to wait for Space to start (default: 300). - cancel_event: A threading.Event (or compatible) checked during - polling loops. When set, the Space is deleted and - Sandbox.Cancelled is raised. - - Returns: - A Sandbox instance connected to the running Space. - """ - _log = log or print - api = HfApi(token=token) - - def _check_cancel(): - if cancel_event and cancel_event.is_set(): - _log("Sandbox creation cancelled by user, cleaning up...") - try: - api.delete_repo(space_id, repo_type="space") - _log(f"Deleted Space {space_id}") - except Exception: - pass - raise cls.Cancelled(f"Sandbox creation cancelled: {space_id}") - - base = name or "sandbox" - suffix = uuid.uuid4().hex[:8] - space_id = f"{owner}/{base}-{suffix}" - sandbox_api_token = secrets_lib.token_urlsafe(32) - - _log(f"Creating sandbox: {space_id} (from {template})...") - - kwargs = { - "from_id": template, - "to_id": space_id, - "private": private, - "hardware": hardware, - } - if sleep_time is not None: - kwargs["sleep_time"] = sleep_time - - api.duplicate_space(**kwargs) - _log(f"Space created: https://huggingface.co/spaces/{space_id}") - - _check_cancel() - - # ``duplicate_space`` sends hardware and sleepTimeSeconds in the - # initial create request. Avoid a second /hardware call: deployed HF - # OAuth tokens can 401 on that endpoint for a just-created private - # Space even though duplication itself succeeded. We rely on the - # duplicate endpoint to honor sleepTimeSeconds for upgraded hardware; - # cpu-basic auto-sleep is fixed by the Hub. - _log(f"Using duplicated Space hardware: {hardware}") - if sleep_time is not None: - if hardware == CPU_BASIC_HARDWARE: - _log( - f"Requested duplicated Space sleep time: {sleep_time}s " - "(cpu-basic auto-sleep is fixed by the Hub)" - ) - else: - _log(f"Using duplicated Space sleep time: {sleep_time}s") - - # Inject secrets BEFORE uploading server files (which triggers rebuild). - # Secrets added after a Space is running aren't available until restart, - # so they must be set before the build/start cycle. - sandbox_secrets = {**(secrets or {}), "SANDBOX_API_TOKEN": sandbox_api_token} - if sandbox_secrets: - for key, val in sandbox_secrets.items(): - api.add_space_secret(space_id, key, val) - - # Upload sandbox server and Dockerfile (triggers rebuild) - cls._setup_server(space_id, api, log=_log) - - _check_cancel() - - # Wait for it to come online (rebuild + start) - _log(f"Waiting for Space to start (timeout: {wait_timeout}s)...") - deadline = time.time() + wait_timeout - while time.time() < deadline: - _check_cancel() - try: - runtime = api.get_space_runtime(space_id) - except Exception as e: - if _is_transient_space_visibility_error(e): - _log(" Space runtime not visible yet...") - time.sleep(WAIT_INTERVAL) - continue - raise - if runtime.stage == "RUNNING": - current_hardware = runtime.hardware or getattr( - runtime, "requested_hardware", None - ) - if current_hardware != hardware: - _log(f" RUNNING on {current_hardware}; waiting for {hardware}...") - time.sleep(WAIT_INTERVAL) - continue - _log(f"Space is running (hardware: {runtime.hardware})") - break - if runtime.stage in ("RUNTIME_ERROR", "BUILD_ERROR"): - raise RuntimeError( - f"Space failed to start: {runtime.stage}. " - f"Check https://huggingface.co/spaces/{space_id}" - ) - _log(f" {runtime.stage}...") - time.sleep(WAIT_INTERVAL) - else: - raise TimeoutError( - f"Space did not start within {wait_timeout}s. " - f"Check https://huggingface.co/spaces/{space_id}" - ) - - _check_cancel() - - # Wait for the API server to be responsive (non-fatal) - sb = cls( - space_id=space_id, - token=token, - api_token=sandbox_api_token, - _owns_space=True, - ) - try: - sb._wait_for_api(timeout=API_WAIT_TIMEOUT, log=_log) - except TimeoutError as e: - _log( - f"Warning: API health check timed out ({e}), but Space is RUNNING. Continuing." - ) - return sb - - @staticmethod - def _setup_server( - space_id: str, api: HfApi, *, log: Callable[[str], object] = print - ) -> None: - """Upload embedded sandbox server + Dockerfile to the Space (single commit).""" - log(f"Uploading sandbox server to {space_id}...") - api.create_commit( - repo_id=space_id, - repo_type="space", - operations=[ - CommitOperationAdd( - path_in_repo="sandbox_server.py", - path_or_fileobj=io.BytesIO(_SANDBOX_SERVER.encode()), - ), - CommitOperationAdd( - path_in_repo="Dockerfile", - path_or_fileobj=io.BytesIO(_DOCKERFILE.encode()), - ), - ], - commit_message="Setup sandbox server", - ) - log("Server files uploaded, rebuild triggered.") - - @classmethod - def connect( - cls, - space_id: str, - *, - token: str | None = None, - api_token: str | None = None, - ) -> Sandbox: - """ - Connect to an existing running Space. - - Does a health check to verify the Space is reachable. - """ - sb = cls( - space_id=space_id, - token=token, - api_token=api_token, - _owns_space=False, - ) - sb._wait_for_api(timeout=60) - return sb - - def _wait_for_api( - self, timeout: int = API_WAIT_TIMEOUT, log: Callable[[str], object] = print - ): - """Poll the health endpoint until the server responds.""" - deadline = time.time() + timeout - last_err = None - last_status = None - while time.time() < deadline: - try: - resp = self._client.get("health", timeout=10) - last_status = resp.status_code - if resp.status_code == 200: - log(f"API is responsive at {self._base_url}") - return - except Exception as e: - last_err = e - time.sleep(3) - raise TimeoutError( - f"Sandbox API at {self._base_url} not responding after {timeout}s. " - f"Last status: {last_status}, last error: {last_err}" - ) - - def delete(self): - """Delete the Space. Only works if this Sandbox created it.""" - if not self._owns_space: - raise RuntimeError( - f"This Sandbox did not create {self.space_id}. " - f"Use self._hf_api.delete_repo() directly if you're sure." - ) - print(f"Deleting sandbox: {self.space_id}...") - self._hf_api.delete_repo(self.space_id, repo_type="space") - # Clear ownership so a second cleanup call (e.g. delete_session + - # _run_session.finally both fire) early-returns instead of retrying - # a 404 delete and emitting a spurious ERROR log. - self._owns_space = False - self._client.close() - print("Deleted.") - - def pause(self): - """Pause the Space (stops billing, preserves state).""" - self._hf_api.pause_space(self.space_id) - - def restart(self): - """Restart the Space.""" - self._hf_api.restart_space(self.space_id) - self._wait_for_api() - - @property - def url(self) -> str: - """Public URL of the Space.""" - return f"https://huggingface.co/spaces/{self.space_id}" - - @property - def status(self) -> str: - """Current Space stage (RUNNING, BUILDING, PAUSED, etc.).""" - return self._hf_api.get_space_runtime(self.space_id).stage - - def __enter__(self) -> Sandbox: - return self - - def __exit__(self, *exc): - if self._owns_space: - try: - self.delete() - except Exception as e: - print(f"Warning: failed to delete sandbox: {e}", file=sys.stderr) - self._client.close() - - # ── HTTP plumbing ───────────────────────────────────────────── - - def _call( - self, endpoint: str, payload: dict, timeout: float | None = None - ) -> ToolResult: - # Strip leading slash for correct httpx base_url resolution - endpoint = endpoint.lstrip("/") - effective_timeout = timeout or self.timeout - last_error = "" - - # Retry up to 3 times for transient failures (sandbox waking from - # sleep returns empty / non-JSON responses while it starts up). - for attempt in range(3): - try: - resp = self._client.post( - endpoint, - json=payload, - timeout=effective_timeout, - ) - try: - data = resp.json() - except (ValueError, UnicodeDecodeError): - # Non-JSON response β€” sandbox is likely still starting up. - body_preview = resp.text[:200] if resp.text else "(empty)" - last_error = ( - f"Sandbox returned non-JSON response (HTTP {resp.status_code}): " - f"{body_preview}" - ) - if attempt < 2: - time.sleep(3 * (attempt + 1)) - continue - return ToolResult(success=False, error=last_error) - - if resp.status_code == 200: - return ToolResult( - success=data.get("success", True), - output=data.get("output", ""), - error=data.get("error", ""), - ) - return ToolResult( - success=False, - error=data.get("error", f"HTTP {resp.status_code}"), - ) - except httpx.TimeoutException: - return ToolResult( - success=False, error=f"Timeout after {effective_timeout}s" - ) - except httpx.ConnectError: - last_error = ( - f"Cannot connect to sandbox. Is {self.space_id} running? " - f"Status: {self.status}" - ) - if attempt < 2: - time.sleep(3 * (attempt + 1)) - continue - return ToolResult(success=False, error=last_error) - except Exception as e: - return ToolResult(success=False, error=str(e)) - - return ToolResult(success=False, error=last_error or "Unknown error") - - # ── Tools ───────────────────────────────────────────────────── - - def bash( - self, - command: str, - *, - work_dir: str | None = None, - timeout: int | None = None, - description: str | None = None, - ) -> ToolResult: - return self._call( - "bash", - { - "command": command, - "work_dir": work_dir or self.work_dir, - "timeout": min(timeout or self.timeout, MAX_TIMEOUT), - }, - timeout=timeout, - ) - - def read( - self, path: str, *, offset: int | None = None, limit: int | None = None - ) -> ToolResult: - self._files_read.add(path) - return self._call( - "read", - { - "path": path, - "offset": offset, - "limit": limit or (DEFAULT_READ_LIMIT if offset is None else None), - }, - ) - - def write(self, path: str, content: str) -> ToolResult: - if path not in self._files_read: - check = self._call("exists", {"path": path}) - if check.success and check.output == "true": - return ToolResult( - success=False, - error=( - f"File {path} exists but has not been read this session. " - f"Read it first, or use sandbox_edit for targeted changes." - ), - ) - result = self._call("write", {"path": path, "content": content}) - if result.success: - self._files_read.add(path) - return result - - def edit( - self, - path: str, - old_str: str, - new_str: str, - *, - replace_all: bool = False, - mode: str = "replace", - ) -> ToolResult: - if old_str == new_str: - return ToolResult(success=False, error="old_str and new_str are identical.") - if path not in self._files_read: - return ToolResult( - success=False, - error=f"File {path} has not been read this session. Read it first.", - ) - return self._call( - "edit", - { - "path": path, - "old_str": old_str, - "new_str": new_str, - "replace_all": replace_all, - "mode": mode, - }, - ) - - def kill_all(self) -> ToolResult: - """Kill all active bash processes on the sandbox. Used on cancellation.""" - return self._call("kill", {}) - - # ── Tool schemas & dispatch ─────────────────────────────────── - - TOOLS = { - "bash": { - "description": ( - "Run a shell command in the remote sandbox and return stdout/stderr.\n" - "\n" - "IMPORTANT: Do NOT use bash for file operations β€” use the dedicated tools instead:\n" - "- To read files: use read (not cat/head/tail)\n" - "- To edit files: use edit (not sed/awk)\n" - "- To write files: use write (not echo/cat < > /app/output.log 2>&1 & echo $!\n" - "Then check status:\n" - " kill -0 2>/dev/null && echo 'running' || echo 'done'\n" - " tail -n 50 /app/output.log\n" - "\n" - "Timeout default 240s, max 1200s." - ), - "parameters": { - "type": "object", - "required": ["command"], - "additionalProperties": False, - "properties": { - "command": { - "type": "string", - "description": "The shell command to execute.", - }, - "description": { - "type": "string", - "description": "Short description (5-10 words, active voice).", - }, - "work_dir": { - "type": "string", - "description": "Working directory (default: /app).", - }, - "timeout": { - "type": "integer", - "description": "Optional timeout in seconds (default: 240, max: 1200).", - }, - }, - }, - }, - "read": { - "description": ( - "Reads a file from the sandbox filesystem. Returns contents with line " - "numbers (cat -n format).\n" - "\n" - "Usage:\n" - "- By default, reads up to 2000 lines from the beginning of the file.\n" - "- You can optionally specify offset and limit for large files, but prefer " - "reading the whole file first.\n" - "- Lines longer than 4000 chars are truncated.\n" - "- Cannot read directories β€” use bash with 'ls' instead.\n" - "- You should read multiple potentially useful files in parallel when possible.\n" - "- IMPORTANT: Always read a file before editing or overwriting it. The edit and " - "write tools will reject operations on files you haven't read." - ), - "parameters": { - "type": "object", - "required": ["path"], - "additionalProperties": False, - "properties": { - "path": { - "type": "string", - "description": "Absolute path to the file to read.", - }, - "offset": { - "type": "integer", - "description": "The line number to start reading from (1-based). Only provide if the file is too large to read at once.", - }, - "limit": { - "type": "integer", - "description": "The number of lines to read. Only provide if the file is too large to read at once.", - }, - }, - }, - }, - "write": { - "description": ( - "Writes a file to the sandbox filesystem. Overwrites the existing file if " - "one exists at the path.\n" - "\n" - "- If this is an existing file, you MUST use the read tool first. This tool " - "will fail if you did not read the file first.\n" - "- ALWAYS prefer editing existing files with the edit tool over overwriting " - "with write.\n" - "- Creates parent directories as needed." - ), - "parameters": { - "type": "object", - "required": ["path", "content"], - "additionalProperties": False, - "properties": { - "path": { - "type": "string", - "description": "Absolute path to the file to write.", - }, - "content": { - "type": "string", - "description": "The complete file content to write.", - }, - }, - }, - }, - "edit": { - "description": ( - "Performs string replacements in files. Supports exact matching with " - "fuzzy fallback.\n" - "\n" - "Usage:\n" - "- You must read the file at least once before editing. This tool will " - "error if you attempt an edit without reading the file.\n" - "- The edit will FAIL if old_str is not unique in the file. Either provide " - "a larger string with more surrounding context to make it unique, or set " - "replace_all to true.\n" - "- old_str and new_str must differ.\n" - "- Preserve indentation exactly as it appears in the file.\n" - "- Do NOT include line number prefixes from read output in old_str or new_str.\n" - "- To delete code, set new_str to empty string.\n" - "- Use replace_all for renaming variables or strings across the file.\n" - "\n" - "Modes:\n" - "- replace (default): replace first occurrence of old_str with new_str.\n" - "- append_after: insert new_str immediately after old_str (old_str is kept).\n" - "- prepend_before: insert new_str immediately before old_str (old_str is kept)." - ), - "parameters": { - "type": "object", - "required": ["path", "old_str", "new_str"], - "additionalProperties": False, - "properties": { - "path": { - "type": "string", - "description": "Absolute path to the file to edit.", - }, - "old_str": { - "type": "string", - "description": "The text to find in the file. Must match exactly (fuzzy matching is used as fallback).", - }, - "new_str": { - "type": "string", - "description": "The replacement text. For append_after/prepend_before modes, the text to insert.", - }, - "replace_all": { - "type": "boolean", - "description": "Replace all occurrences of old_str (default: false).", - "default": False, - }, - "mode": { - "type": "string", - "enum": ["replace", "append_after", "prepend_before"], - "description": "Edit mode (default: replace).", - "default": "replace", - }, - }, - }, - }, - } - - @classmethod - def tool_definitions(cls) -> list[dict]: - return [{"name": name, **spec} for name, spec in cls.TOOLS.items()] - - def call_tool(self, name: str, arguments: dict[str, Any]) -> ToolResult: - dispatch = { - "bash": lambda a: self.bash( - a["command"], - work_dir=a.get("work_dir"), - timeout=a.get("timeout"), - description=a.get("description"), - ), - "read": lambda a: self.read( - a["path"], - offset=a.get("offset"), - limit=a.get("limit"), - ), - "write": lambda a: self.write(a["path"], a["content"]), - "edit": lambda a: self.edit( - a["path"], - a["old_str"], - a["new_str"], - replace_all=a.get("replace_all", False), - mode=a.get("mode", "replace"), - ), - } - fn = dispatch.get(name) - if not fn: - return ToolResult(success=False, error=f"Unknown tool: {name}") - return fn(arguments) diff --git a/agent/tools/sandbox_tool.py b/agent/tools/sandbox_tool.py deleted file mode 100644 index fbc6a41f9fd9edf05b1565d5782983bde167fa3c..0000000000000000000000000000000000000000 --- a/agent/tools/sandbox_tool.py +++ /dev/null @@ -1,778 +0,0 @@ -""" -Sandbox tools β€” expose the Sandbox client as agent tools. - -5 tools total: - sandbox_create β€” create/replace sandbox for non-default hardware - bash, read, write, edit β€” operations on the active sandbox - -A cpu-basic sandbox is preloaded for each session. Operation tools wait for it -if startup is still in progress. -""" - -from __future__ import annotations - -import asyncio -import logging -import re -import threading -import weakref -from datetime import datetime, timedelta, timezone -from typing import Any - -from huggingface_hub import HfApi, SpaceHardware - -from agent.core.hub_artifacts import wrap_shell_command_with_hub_artifact_bootstrap -from agent.core.session import Event -from agent.tools.sandbox_client import Sandbox -from agent.tools.trackio_seed import ensure_trackio_dashboard - -logger = logging.getLogger(__name__) - -DEFAULT_CPU_SANDBOX_HARDWARE = "cpu-basic" - -# Match the exact suffix pattern Sandbox.create produces: "sandbox-<8 hex>". -# Used to identify orphan sandboxes from prior sessions safely (won't match -# user-renamed lookalikes). -SANDBOX_SPACE_NAME_RE = re.compile(r"^sandbox-[a-f0-9]{8}$") - -# How stale a sandbox must be before we treat it as definitely orphan. -# Anything more recent could be tied to a still-live session in another tab, -# so we leave it alone. -_ORPHAN_STALE_AFTER = timedelta(hours=1) - -# HF Space duplication/build APIs can behave poorly when multiple private -# sandboxes are created concurrently for the same namespace. Keep session -# creation non-blocking, but serialize the actual Hub create path per owner. -_SANDBOX_CREATE_LOCKS: weakref.WeakKeyDictionary[ - asyncio.AbstractEventLoop, dict[str, asyncio.Lock] -] = weakref.WeakKeyDictionary() - - -def _get_sandbox_create_lock(owner: str) -> asyncio.Lock: - loop = asyncio.get_running_loop() - locks = _SANDBOX_CREATE_LOCKS.setdefault(loop, {}) - lock = locks.get(owner) - if lock is None: - lock = asyncio.Lock() - locks[owner] = lock - return lock - - -def _looks_like_path(script: str) -> bool: - """Return True if the script string looks like a file path (not inline code).""" - return ( - isinstance(script, str) - and script.strip() == script - and not any(c in script for c in "\r\n\0") - and ( - script.startswith("/") - or script.startswith("./") - or script.startswith("../") - ) - ) - - -async def resolve_sandbox_script( - sandbox: Any, script: str -) -> tuple[str | None, str | None]: - """Read a file from the sandbox if *script* looks like a path. - - Returns: - (content, error) β€” content is the file text on success, - error is a message on failure. Both None means *script* - is not a path (caller should use it as-is). - """ - if not sandbox or not _looks_like_path(script): - return None, None - try: - # Use the read endpoint instead of bash("cat ...") which truncates at 25KB. - result = await asyncio.to_thread(sandbox.read, script, limit=100_000) - if result.success and result.output: - # Strip line number prefixes (read returns "N\tcontent" format) - lines = [] - for line in result.output.split("\n"): - parts = line.split("\t", 1) - lines.append(parts[1] if len(parts) == 2 else line) - return "\n".join(lines), None - return None, f"Failed to read {script} from sandbox: {result.error}" - except Exception as e: - return None, f"Failed to read {script} from sandbox: {e}" - - -async def _seed_trackio_dashboard_safe(session: Any, space_id: str) -> None: - """Idempotently seed *space_id* with trackio dashboard files using the - session's HF token. Logs progress, swallows errors β€” a failed seed should - not block sandbox creation.""" - if not session or not getattr(session, "hf_token", None): - return - loop = asyncio.get_running_loop() - - def _log(msg: str) -> None: - loop.call_soon_threadsafe( - session.event_queue.put_nowait, - Event(event_type="tool_log", data={"tool": "sandbox_create", "log": msg}), - ) - - try: - await asyncio.to_thread( - ensure_trackio_dashboard, space_id, session.hf_token, _log - ) - except Exception as e: - _log(f"trackio dashboard seed failed: {e}") - - -async def _update_persisted_sandbox_fields(session: Any, **fields: Any) -> None: - """Best-effort update of sandbox metadata on the durable session record.""" - store = getattr(session, "persistence_store", None) - session_id = getattr(session, "session_id", None) - if not (store and session_id and hasattr(store, "update_session_fields")): - return - try: - await store.update_session_fields(session_id, **fields) - except Exception as e: - logger.warning("Failed to persist sandbox metadata for %s: %s", session_id, e) - - -async def _persist_active_sandbox( - session: Any, - sandbox: Sandbox, - *, - hardware: str, -) -> None: - space_id = getattr(sandbox, "space_id", None) - if not space_id: - return - owner = space_id.split("/", 1)[0] if "/" in space_id else None - await _update_persisted_sandbox_fields( - session, - sandbox_space_id=space_id, - sandbox_hardware=hardware, - sandbox_owner=owner, - sandbox_created_at=datetime.now(timezone.utc), - sandbox_status="active", - ) - - -async def _clear_persisted_sandbox(session: Any) -> None: - await _update_persisted_sandbox_fields( - session, - sandbox_space_id=None, - sandbox_hardware=None, - sandbox_owner=None, - sandbox_created_at=None, - sandbox_status="destroyed", - ) - - -# ── Tool name mapping (short agent names β†’ Sandbox client names) ────── - - -def _cleanup_user_orphan_sandboxes( - api: HfApi, - owner: str, - log: Any, -) -> int: - """Delete stale ``sandbox-<8hex>`` Spaces in ``owner``'s account. - - "Stale" = not modified in the last hour. The naming pattern + staleness - filter together make this safe: - - * Naming: only matches ``sandbox-``, the - pattern Sandbox.create produces. Won't touch user-renamed Spaces. - * Staleness: anything modified in the last hour might still be tied - to a live session in another tab/replica, so we leave it alone. - - Runs blocking β€” call via ``asyncio.to_thread``. Best-effort: failures - are logged but never raised, so a flaky HF API never blocks creation. - """ - cutoff = datetime.now(timezone.utc) - _ORPHAN_STALE_AFTER - deleted = 0 - try: - spaces = list(api.list_spaces(author=owner, limit=200, full=True)) - except Exception as e: - log(f"orphan sweep: list_spaces failed: {e}") - return 0 - - for space in spaces: - space_name = space.id.rsplit("/", 1)[-1] - if not SANDBOX_SPACE_NAME_RE.match(space_name): - continue - - last_mod = getattr(space, "lastModified", None) or getattr( - space, "last_modified", None - ) - if isinstance(last_mod, str): - try: - last_mod = datetime.fromisoformat(last_mod.replace("Z", "+00:00")) - except ValueError: - last_mod = None - if last_mod is None: - log(f"orphan sweep: skipping {space.id}; missing lastModified") - continue - if last_mod and last_mod > cutoff: - # Recent β€” could be a concurrent live session. Skip. - continue - - try: - api.delete_repo(repo_id=space.id, repo_type="space") - deleted += 1 - log(f"orphan sweep: deleted {space.id}") - except Exception as e: - log(f"orphan sweep: failed to delete {space.id}: {e}") - - if deleted: - log(f"orphan sweep: cleaned up {deleted} stale sandbox(es) before create") - return deleted - - -async def _ensure_sandbox( - session: Any, - hardware: str = DEFAULT_CPU_SANDBOX_HARDWARE, - extra_secrets: dict[str, str] | None = None, - cancel_event: threading.Event | None = None, - **create_kwargs, -) -> tuple[Sandbox | None, str | None]: - """ - Ensure a sandbox exists on the session. Auto-creates with given hardware if needed. - - Returns: - (sandbox, error_message) β€” one will be None. - """ - if session and getattr(session, "sandbox", None): - return session.sandbox, None - - if not session: - return None, "No session available." - - token = session.hf_token - if not token: - return None, "No HF token available. Cannot create sandbox." - - api = HfApi(token=token) - user_info = api.whoami() - owner = user_info.get("name", user_info.get("user", "")) - if not owner: - return None, "Could not determine HF username from token." - - create_lock = _get_sandbox_create_lock(owner) - if create_lock.locked(): - await session.send_event( - Event( - event_type="tool_log", - data={ - "tool": "sandbox", - "log": "Waiting for sandbox creation slot...", - }, - ) - ) - - async with create_lock: - if getattr(session, "sandbox", None): - return session.sandbox, None - - return await _create_sandbox_locked( - session, - api=api, - owner=owner, - hardware=hardware, - extra_secrets=extra_secrets, - cancel_event=cancel_event, - **create_kwargs, - ) - - -async def _create_sandbox_locked( - session: Any, - *, - api: HfApi, - owner: str, - hardware: str, - extra_secrets: dict[str, str] | None = None, - cancel_event: threading.Event | None = None, - **create_kwargs, -) -> tuple[Sandbox | None, str | None]: - """Create the Space while the per-owner sandbox creation lock is held.""" - token = session.hf_token - await session.send_event( - Event( - event_type="tool_log", - data={ - "tool": "sandbox", - "log": f"Auto-creating sandbox for {owner} ({hardware})...", - }, - ) - ) - - # Thread-safe log callback: posts tool_log events from the worker thread - loop = asyncio.get_running_loop() - - def _log(msg: str) -> None: - loop.call_soon_threadsafe( - session.event_queue.put_nowait, - Event(event_type="tool_log", data={"tool": "sandbox", "log": msg}), - ) - - # Bridge asyncio cancel event to a threading.Event for the blocking create call. - # We poll session._cancelled from the main loop in a background task and set - # a threading.Event that Sandbox.create checks during its polling loops. - cancel_flag = cancel_event or threading.Event() - - async def _watch_cancel(): - await session._cancelled.wait() - cancel_flag.set() - - watcher_task = asyncio.create_task(_watch_cancel()) - - secrets: dict[str, str] = {"HF_TOKEN": token} - if extra_secrets: - secrets.update({k: v for k, v in extra_secrets.items() if v}) - - create_kwargs["private"] = True # enforce: overrides any caller-supplied value - kwargs = { - "owner": owner, - "hardware": hardware, - "token": token, - "secrets": secrets, - "log": _log, - "cancel_event": cancel_flag, - **create_kwargs, - } - if hardware != DEFAULT_CPU_SANDBOX_HARDWARE: - kwargs["sleep_time"] = 2700 - import time as _t - - _t_start = _t.monotonic() - try: - sb = await asyncio.to_thread(Sandbox.create, **kwargs) - except Sandbox.Cancelled: - return None, "Sandbox creation cancelled by user." - finally: - watcher_task.cancel() - - if cancel_flag.is_set(): - if getattr(sb, "_owns_space", False): - try: - await asyncio.to_thread(sb.delete) - except Exception as e: - logger.warning( - "Failed to delete cancelled sandbox %s: %s", sb.space_id, e - ) - return None, "Sandbox creation cancelled by user." - - session.sandbox = sb - session.sandbox_hardware = hardware - session.sandbox_preload_error = None - await _persist_active_sandbox(session, sb, hardware=hardware) - - # Telemetry: sandbox creation (infra consumption signal) - from agent.core import telemetry - - await telemetry.record_sandbox_create( - session, - sb, - hardware=hardware, - create_latency_s=int(_t.monotonic() - _t_start), - ) - - await session.send_event( - Event( - event_type="tool_log", - data={"tool": "sandbox", "log": f"Sandbox ready: {sb.space_id} ({sb.url})"}, - ) - ) - - return sb, None - - -def start_cpu_sandbox_preload(session: Any) -> asyncio.Task | None: - """Start a background ``cpu-basic`` sandbox for this session.""" - if not session or getattr(session, "sandbox", None): - return None - - existing_task = getattr(session, "sandbox_preload_task", None) - if existing_task and not existing_task.done(): - return existing_task - - cancel_event = threading.Event() - session.sandbox_preload_cancel_event = cancel_event - session.sandbox_preload_error = None - - async def _preload() -> Sandbox | None: - try: - sb, error = await _ensure_sandbox( - session, - hardware=DEFAULT_CPU_SANDBOX_HARDWARE, - cancel_event=cancel_event, - ) - if error: - session.sandbox_preload_error = error - return None - return sb - except asyncio.CancelledError: - cancel_event.set() - session.sandbox_preload_error = "Sandbox creation cancelled by user." - raise - except Exception as e: - session.sandbox_preload_error = f"Failed to create sandbox: {e}" - logger.warning("CPU sandbox preload failed: %s", e) - return None - - task = asyncio.create_task(_preload()) - session.sandbox_preload_task = task - return task - - -async def cancel_sandbox_preload(session: Any) -> None: - """Best-effort cancellation for an in-flight CPU sandbox preload.""" - cancel_event = getattr(session, "sandbox_preload_cancel_event", None) - if cancel_event is not None: - cancel_event.set() - - task = getattr(session, "sandbox_preload_task", None) - if not task or task.done(): - return - - current_task = asyncio.current_task() - if task is current_task: - return - - try: - await asyncio.wait_for(asyncio.shield(task), timeout=30) - except asyncio.TimeoutError: - logger.warning( - "Timed out waiting for CPU sandbox preload cancellation; " - "task is still live, cancelling asyncio wrapper" - ) - task.cancel() - except asyncio.CancelledError: - raise - except Exception: - pass - - -async def get_active_or_preloaded_sandbox( - session: Any, -) -> tuple[Sandbox | None, str | None]: - """Return the active sandbox, waiting for the startup preload if needed.""" - if not session: - return None, "No session available." - if getattr(session, "sandbox", None): - return session.sandbox, None - - task = getattr(session, "sandbox_preload_task", None) - if task: - try: - await asyncio.shield(task) - except asyncio.CancelledError: - raise - except Exception as e: - session.sandbox_preload_error = f"Failed to create sandbox: {e}" - - if getattr(session, "sandbox", None): - return session.sandbox, None - - preload_error = getattr(session, "sandbox_preload_error", None) - if preload_error: - return None, preload_error - - return None, "Sandbox is still starting. Please retry shortly." - - -async def teardown_session_sandbox(session: Any) -> None: - """Cancel sandbox preload and delete the active owned sandbox, if present.""" - if not session: - return - - await cancel_sandbox_preload(session) - - sandbox = getattr(session, "sandbox", None) - session.sandbox = None - session.sandbox_hardware = None - - if not sandbox: - return - - try: - if not getattr(sandbox, "_owns_space", False): - return - - space_id = getattr(sandbox, "space_id", None) - last_err: Exception | None = None - for attempt in range(3): - try: - logger.info( - "Deleting sandbox %s (attempt %s/3)...", - space_id, - attempt + 1, - ) - await asyncio.to_thread(sandbox.delete) - from agent.core import telemetry - - await telemetry.record_sandbox_destroy(session, sandbox) - return - except Exception as e: - last_err = e - if attempt < 2: - await asyncio.sleep(2**attempt) - logger.error( - "Failed to delete sandbox %s after 3 attempts: %s. " - "Orphan β€” sweep script will pick it up.", - space_id, - last_err, - ) - finally: - await _clear_persisted_sandbox(session) - - -# ── sandbox_create tool ────────────────────────────────────────────── - -SANDBOX_CREATE_TOOL_SPEC = { - "name": "sandbox_create", - "description": ( - "Create or replace the session sandbox when non-default hardware is needed.\n\n" - "A private cpu-basic sandbox is already started automatically for each session. " - "For normal CPU code execution, call bash/read/write/edit directly; do NOT call sandbox_create first.\n\n" - "Use sandbox_create when: you need GPU hardware, cpu-upgrade, or Trackio secrets before running code. " - "The active sandbox persists across tool calls within the session. pip install works out of the box. " - "Sandboxes are always created as private HF Spaces.\n\n" - "For ML code that uses CUDA, bf16, or model loading: use GPU hardware (t4-small minimum). " - "CPU sandboxes cannot run GPU code paths β€” your test will not catch GPU-related errors.\n\n" - "Before choosing hardware, estimate your VRAM needs (models you run, training data size). Rule of thumb: bf16/fp16 β‰ˆ 2 bytes/param, " - "fp32 β‰ˆ 4 bytes/param, plus ~20% overhead for optimizer states during training.\n" - "Common picks: t4-small (16GB VRAM, fits ≀1-3B), a10g-small (24GB, ≀7B), a100-large (80GB, ≀30B). " - "If the model won't fit, pick larger hardware upfront β€” OOM on a sandbox wastes time.\n\n" - "If you intend to run a training script in this sandbox that uses report_to='trackio', " - "pass `trackio_space_id` (e.g. '/mlintern-<8char>') and `trackio_project` so they " - "are set as TRACKIO_SPACE_ID/TRACKIO_PROJECT secrets in the sandbox and the UI can embed the live dashboard.\n\n" - "Hardware: " + ", ".join([e.value for e in SpaceHardware]) + ".\n" - ), - "parameters": { - "type": "object", - "required": [], - "additionalProperties": False, - "properties": { - "hardware": { - "type": "string", - "enum": [e.value for e in SpaceHardware], - "description": ( - "Hardware tier for the sandbox. Omit for the existing auto-started " - "cpu-basic sandbox; choose GPU/cpu-upgrade only when needed." - ), - }, - "trackio_space_id": { - "type": "string", - "description": ( - "Optional. The HF Space hosting the trackio dashboard for runs in this sandbox " - "(e.g. '/mlintern-<8char>', under YOUR HF namespace). Injected as " - "TRACKIO_SPACE_ID secret and surfaced to the UI. The Space is auto-created and " - "seeded with the trackio dashboard β€” DO NOT pre-create it via hf_repo_git, " - "that produces an empty Space that breaks the embed." - ), - }, - "trackio_project": { - "type": "string", - "description": ( - "Optional. The trackio project name. Injected as TRACKIO_PROJECT secret and " - "used by the UI to filter the embedded dashboard to this project." - ), - }, - }, - }, -} - - -async def sandbox_create_handler( - args: dict[str, Any], session: Any = None, tool_call_id: str | None = None -) -> tuple[str, bool]: - """Handle sandbox_create tool calls.""" - hardware = args.get("hardware", DEFAULT_CPU_SANDBOX_HARDWARE) - trackio_space_id = args.get("trackio_space_id") or None - trackio_project = args.get("trackio_project") or None - - async def _emit_trackio_state(sb: Sandbox) -> None: - """Tell the frontend which trackio dashboard to embed for this sandbox.""" - if not (session and tool_call_id and trackio_space_id): - return - data: dict[str, Any] = { - "tool_call_id": tool_call_id, - "tool": "sandbox_create", - "state": "running", - "trackioSpaceId": trackio_space_id, - } - if trackio_project: - data["trackioProject"] = trackio_project - await session.send_event(Event(event_type="tool_state_change", data=data)) - - preload_task = getattr(session, "sandbox_preload_task", None) - if ( - session - and not getattr(session, "sandbox", None) - and preload_task - and not preload_task.done() - and hardware == DEFAULT_CPU_SANDBOX_HARDWARE - ): - sb, error = await get_active_or_preloaded_sandbox(session) - if error: - return error, False - if sb: - await _emit_trackio_state(sb) - return ( - f"Sandbox already active: {sb.space_id}\n" - f"URL: {sb.url}\n" - f"Hardware: {DEFAULT_CPU_SANDBOX_HARDWARE}\n" - f"Use bash/read/write/edit to interact with it." - ), True - - if ( - session - and not getattr(session, "sandbox", None) - and preload_task - and not preload_task.done() - and hardware != DEFAULT_CPU_SANDBOX_HARDWARE - ): - await cancel_sandbox_preload(session) - - # If sandbox already exists, return its info or replace the auto CPU sandbox - if session and getattr(session, "sandbox", None): - sb = session.sandbox - active_hardware = getattr(session, "sandbox_hardware", None) - if active_hardware == hardware: - await _emit_trackio_state(sb) - return ( - f"Sandbox already active: {sb.space_id}\n" - f"URL: {sb.url}\n" - f"Hardware: {active_hardware}\n" - f"Use bash/read/write/edit to interact with it." - ), True - - requested_hardware = args.get("hardware") - lockout_note = "" - if ( - active_hardware == DEFAULT_CPU_SANDBOX_HARDWARE - and hardware != DEFAULT_CPU_SANDBOX_HARDWARE - ): - await teardown_session_sandbox(session) - elif requested_hardware: - lockout_note = ( - f"\nRequested hardware: {requested_hardware}\n" - "Hardware cannot be changed by calling sandbox_create again. " - "Delete the existing sandbox first if you need a different tier." - ) - await _emit_trackio_state(sb) - return ( - f"Sandbox already active: {sb.space_id}\n" - f"URL: {sb.url}\n" - f"{lockout_note}\n" - f"Use bash/read/write/edit to interact with it." - ), True - else: - await _emit_trackio_state(sb) - return ( - f"Sandbox already active: {sb.space_id}\n" - f"URL: {sb.url}\n" - f"Hardware: {active_hardware or 'unknown'}\n" - f"Use bash/read/write/edit to interact with it." - ), True - - create_kwargs: dict[str, Any] = {} - - extra_secrets: dict[str, str] = {} - if trackio_space_id: - extra_secrets["TRACKIO_SPACE_ID"] = trackio_space_id - await _seed_trackio_dashboard_safe(session, trackio_space_id) - if trackio_project: - extra_secrets["TRACKIO_PROJECT"] = trackio_project - - try: - sb, error = await _ensure_sandbox( - session, - hardware=hardware, - extra_secrets=extra_secrets or None, - **create_kwargs, - ) - except Exception as e: - return f"Failed to create sandbox: {e}", False - - if error: - return error, False - - await _emit_trackio_state(sb) - - return ( - f"Sandbox created: {sb.space_id}\n" - f"URL: {sb.url}\n" - f"Hardware: {hardware}\n" - "Visibility: private\n" - f"Use bash/read/write/edit to interact with it." - ), True - - -def _make_tool_handler(sandbox_tool_name: str): - """Factory: create a handler for a sandbox operation tool.""" - - async def handler(args: dict[str, Any], session: Any = None) -> tuple[str, bool]: - sb, error = await get_active_or_preloaded_sandbox(session) - if error: - return error, False - if not sb: - return "Sandbox is still starting. Please retry shortly.", False - - try: - if sandbox_tool_name == "bash" and args.get("command"): - args = { - **args, - "command": wrap_shell_command_with_hub_artifact_bootstrap( - args["command"], - session, - ), - } - result = await asyncio.to_thread(sb.call_tool, sandbox_tool_name, args) - if result.success: - output = result.output or "(no output)" - return output, True - else: - error_msg = result.error or "Unknown error" - output = result.output - if output: - return f"{output}\n\nERROR: {error_msg}", False - return f"ERROR: {error_msg}", False - except Exception as e: - return f"Sandbox operation failed: {e}", False - - return handler - - -def get_sandbox_tools(): - """Return all 5 sandbox ToolSpecs (sandbox_create + 4 operation tools).""" - from agent.core.tools import ToolSpec - - tools = [] - - # sandbox_create (for GPU or other non-default hardware) - tools.append( - ToolSpec( - name=SANDBOX_CREATE_TOOL_SPEC["name"], - description=SANDBOX_CREATE_TOOL_SPEC["description"], - parameters=SANDBOX_CREATE_TOOL_SPEC["parameters"], - handler=sandbox_create_handler, - ) - ) - - # Operation tools (auto-execute, no approval needed) - for name in Sandbox.TOOLS.keys(): - spec = Sandbox.TOOLS[name] - description = ( - "Uses the session's active sandbox. A private cpu-basic sandbox is " - "started automatically for normal CPU work; call sandbox_create only " - "for GPU or other non-default hardware.\n\n" + spec["description"] - ) - tools.append( - ToolSpec( - name=name, - description=description, - parameters=spec["parameters"], - handler=_make_tool_handler(name), - ) - ) - - return tools diff --git a/agent/tools/trackio_seed.py b/agent/tools/trackio_seed.py deleted file mode 100644 index 1062e1b5eda2701833aad7c1c895727d7fbd191e..0000000000000000000000000000000000000000 --- a/agent/tools/trackio_seed.py +++ /dev/null @@ -1,205 +0,0 @@ -"""Seed an HF Space with the trackio dashboard. - -Background: when the agent creates a Space via `hf_repo_git create_repo` (or -the user pre-creates one), it ships with no app.py β€” so the iframe shows the -default Gradio "Get started" template instead of charts. Trackio's `init()` -detects the existing Space but does NOT auto-bootstrap dashboard files into it, -so the dashboard never materializes. - -This helper writes the three files trackio's runtime expects (README.md, -requirements.txt, app.py) into the Space, idempotently, BEFORE the job that -will call `trackio.init()` runs. We deliberately omit `hf_oauth: true` from -the README so the embedded iframe in ml-intern renders without a login click β€” -per-user privacy is enforced by namespace ownership instead. - -Beyond the dashboard files, the helper also creates the metrics bucket and -mounts it on the Space at `/data` (with `TRACKIO_DIR` / `TRACKIO_BUCKET_ID` -Space variables). Without this, the running job writes metrics into a bucket -that the dashboard Space can't read, and the iframe shows "No projects". -""" - -from __future__ import annotations - -import io -from typing import Callable, Optional - -from huggingface_hub import ( - HfApi, - Volume, - add_space_variable, - create_bucket, - create_repo, -) -from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError - - -_README = """--- -title: Trackio Dashboard -emoji: πŸ“Š -colorFrom: pink -colorTo: gray -sdk: gradio -app_file: app.py -pinned: false -tags: - - trackio ---- - -Embedded trackio dashboard for ml-intern runs. -""" - -_REQUIREMENTS = "trackio\n" -_APP_PY = "import trackio\ntrackio.show()\n" - -# ml-intern brand mark surfaced inside the trackio dashboard. Trackio reads -# `TRACKIO_LOGO_LIGHT_URL` / `TRACKIO_LOGO_DARK_URL` from Space variables and -# renders them in place of its own logo. We point at the publicly-resolvable -# copy on the smolagents/ml-intern Space repo so any seeded dashboard inherits -# the ml-intern branding without each user having to host the asset. -_LOGO_URL = ( - "https://huggingface.co/spaces/smolagents/ml-intern/" - "resolve/main/frontend/public/smolagents.webp" -) - -_FILES = { - "README.md": _README, - "requirements.txt": _REQUIREMENTS, - "app.py": _APP_PY, -} - - -def _already_seeded(api: HfApi, space_id: str) -> bool: - """Cheap check: does the Space already have a trackio dashboard app.py? - - Avoids re-uploading the same three files on every job submission. We look - for the literal `trackio.show` call which is the load-bearing line β€” any - other app.py shape (the default gradio shell, a stale custom one) means - we should re-seed. - """ - try: - path = api.hf_hub_download( - repo_id=space_id, repo_type="space", filename="app.py" - ) - except (EntryNotFoundError, RepositoryNotFoundError, OSError): - return False - try: - with open(path, "r", encoding="utf-8") as f: - return "trackio.show" in f.read() - except OSError: - return False - - -def _get_space_volumes(api: HfApi, space_id: str) -> list: - """Return mounted volumes for a Space. - - `get_space_runtime()` doesn't always populate `volumes` even when the - mount exists; mirror trackio's fallback to `space_info().runtime.volumes`. - """ - runtime = api.get_space_runtime(space_id) - if getattr(runtime, "volumes", None): - return list(runtime.volumes) - info = api.space_info(space_id) - if info.runtime and getattr(info.runtime, "volumes", None): - return list(info.runtime.volumes) - return [] - - -def _ensure_bucket_mounted( - api: HfApi, - space_id: str, - bucket_id: str, - hf_token: str, - log: Optional[Callable[[str], None]] = None, -) -> None: - """Create the bucket if missing, mount it at `/data` on the Space, and - set the `TRACKIO_DIR` / `TRACKIO_BUCKET_ID` Space variables. Idempotent β€” - skips work that has already been done. - """ - create_bucket(bucket_id, private=True, exist_ok=True, token=hf_token) - - existing = _get_space_volumes(api, space_id) - already_mounted = any( - getattr(v, "type", None) == "bucket" - and getattr(v, "source", None) == bucket_id - and getattr(v, "mount_path", None) == "/data" - for v in existing - ) - if not already_mounted: - preserved = [ - v - for v in existing - if not ( - getattr(v, "type", None) == "bucket" - and ( - getattr(v, "source", None) == bucket_id - or getattr(v, "mount_path", None) == "/data" - ) - ) - ] - api.set_space_volumes( - space_id, - preserved + [Volume(type="bucket", source=bucket_id, mount_path="/data")], - ) - if log: - log(f"mounted bucket {bucket_id} at /data on {space_id}") - - variables = api.get_space_variables(space_id) - desired = { - "TRACKIO_DIR": "/data/trackio", - "TRACKIO_BUCKET_ID": bucket_id, - "TRACKIO_LOGO_LIGHT_URL": _LOGO_URL, - "TRACKIO_LOGO_DARK_URL": _LOGO_URL, - } - for key, value in desired.items(): - if getattr(variables.get(key), "value", None) != value: - add_space_variable(space_id, key, value, token=hf_token) - - -def ensure_trackio_dashboard( - space_id: str, - hf_token: str, - log: Optional[Callable[[str], None]] = None, -) -> bool: - """Make sure *space_id* is fully wired for trackio: - 1. Space exists with our dashboard files (README without `hf_oauth`, - `requirements.txt`, `app.py` calling `trackio.show`). - 2. Bucket `-bucket` exists, is mounted at `/data`, and the - Space has `TRACKIO_DIR` / `TRACKIO_BUCKET_ID` variables set. - - Idempotent β€” re-running is cheap. Returns True if any seeding happened - in step (1), False if the dashboard files were already in place. Bucket - mount is always re-checked. - """ - api = HfApi(token=hf_token) - - create_repo( - repo_id=space_id, - repo_type="space", - space_sdk="gradio", - exist_ok=True, - token=hf_token, - ) - - seeded_files = False - if _already_seeded(api, space_id): - if log: - log(f"trackio dashboard already seeded on {space_id}") - else: - if log: - log(f"seeding trackio dashboard files into {space_id}") - for path_in_repo, content in _FILES.items(): - api.upload_file( - path_or_fileobj=io.BytesIO(content.encode("utf-8")), - path_in_repo=path_in_repo, - repo_id=space_id, - repo_type="space", - commit_message=f"ml-intern: seed trackio dashboard ({path_in_repo})", - ) - seeded_files = True - - bucket_id = f"{space_id}-bucket" - _ensure_bucket_mounted(api, space_id, bucket_id, hf_token, log) - - if log: - log(f"trackio dashboard ready: https://huggingface.co/spaces/{space_id}") - return seeded_files diff --git a/agent/tools/web_search_tool.py b/agent/tools/web_search_tool.py deleted file mode 100644 index 5c18410855bebdee305997d90de4c9e56f942461..0000000000000000000000000000000000000000 --- a/agent/tools/web_search_tool.py +++ /dev/null @@ -1,276 +0,0 @@ -"""DuckDuckGo HTML web search tool. - -This mirrors Claw Code's Rust WebSearch behavior: fetch DuckDuckGo's HTML -endpoint, extract result links, optionally filter domains, and return a -JSON payload the model can cite. -""" - -from __future__ import annotations - -import asyncio -import html -import json -import os -import time -from dataclasses import dataclass -from html.parser import HTMLParser -from typing import Any -from urllib.parse import parse_qsl, parse_qs, urlencode, urlparse, urlunparse - -import requests - -DEFAULT_SEARCH_URL = "https://html.duckduckgo.com/html/" -WEB_SEARCH_BASE_URL_ENV = "CLAWD_WEB_SEARCH_BASE_URL" -USER_AGENT = "clawd-rust-tools/0.1" -REQUEST_TIMEOUT_SECONDS = 20 -MAX_RESULTS = 8 - - -@dataclass(frozen=True) -class SearchHit: - title: str - url: str - - def as_json(self) -> dict[str, str]: - return {"title": self.title, "url": self.url} - - -class _AnchorParser(HTMLParser): - def __init__(self, *, require_result_class: bool) -> None: - super().__init__(convert_charrefs=True) - self.require_result_class = require_result_class - self.hits: list[tuple[str, str]] = [] - self._active_href: str | None = None - self._active_text: list[str] = [] - - def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None: - if tag.lower() != "a": - return - attr_map = {key.lower(): value or "" for key, value in attrs} - href = attr_map.get("href") - if not href: - return - if self.require_result_class and "result__a" not in attr_map.get("class", ""): - return - self._active_href = href - self._active_text = [] - - def handle_data(self, data: str) -> None: - if self._active_href is not None: - self._active_text.append(data) - - def handle_entityref(self, name: str) -> None: - if self._active_href is not None: - self._active_text.append(f"&{name};") - - def handle_charref(self, name: str) -> None: - if self._active_href is not None: - self._active_text.append(f"&#{name};") - - def handle_endtag(self, tag: str) -> None: - if tag.lower() != "a" or self._active_href is None: - return - title = collapse_whitespace(html.unescape("".join(self._active_text))).strip() - self.hits.append((self._active_href, title)) - self._active_href = None - self._active_text = [] - - -def build_search_url(query: str) -> str: - base = os.environ.get(WEB_SEARCH_BASE_URL_ENV, DEFAULT_SEARCH_URL) - parsed = urlparse(base) - if parsed.scheme not in {"http", "https"} or not parsed.netloc: - raise ValueError(f"invalid search base URL: {base}") - - query_pairs = parse_qsl(parsed.query, keep_blank_values=True) - query_pairs.append(("q", query)) - return urlunparse(parsed._replace(query=urlencode(query_pairs))) - - -def collapse_whitespace(value: str) -> str: - return " ".join(value.split()) - - -def decode_duckduckgo_redirect(url: str) -> str | None: - if url.startswith("http://") or url.startswith("https://"): - return html.unescape(url) - if url.startswith("//"): - joined = f"https:{url}" - elif url.startswith("/"): - joined = f"https://duckduckgo.com{url}" - else: - return None - - parsed = urlparse(joined) - if parsed.path in {"/l", "/l/"}: - uddg = parse_qs(parsed.query).get("uddg", []) - if uddg: - return html.unescape(uddg[0]) - return joined - - -def _extract_links(search_html: str, *, require_result_class: bool) -> list[SearchHit]: - parser = _AnchorParser(require_result_class=require_result_class) - parser.feed(search_html) - - hits: list[SearchHit] = [] - for raw_url, title in parser.hits: - if not title: - continue - decoded_url = decode_duckduckgo_redirect(raw_url) - if decoded_url and ( - decoded_url.startswith("http://") or decoded_url.startswith("https://") - ): - hits.append(SearchHit(title=title, url=decoded_url)) - return hits - - -def extract_search_hits(search_html: str) -> list[SearchHit]: - return _extract_links(search_html, require_result_class=True) - - -def extract_search_hits_from_generic_links(search_html: str) -> list[SearchHit]: - return _extract_links(search_html, require_result_class=False) - - -def normalize_domain_filter(domain: str) -> str: - trimmed = domain.strip() - parsed = urlparse(trimmed) - candidate = parsed.hostname if parsed.scheme and parsed.hostname else trimmed - return candidate.strip().lstrip(".").rstrip("/").lower() - - -def host_matches_list(url: str, domains: list[str]) -> bool: - host = urlparse(url).hostname - if not host: - return False - normalized_host = host.lower() - for domain in domains: - normalized = normalize_domain_filter(domain) - if normalized and ( - normalized_host == normalized or normalized_host.endswith(f".{normalized}") - ): - return True - return False - - -def dedupe_hits(hits: list[SearchHit]) -> list[SearchHit]: - seen: set[str] = set() - deduped: list[SearchHit] = [] - for hit in hits: - if hit.url in seen: - continue - seen.add(hit.url) - deduped.append(hit) - return deduped - - -def execute_web_search( - query: str, - allowed_domains: list[str] | None = None, - blocked_domains: list[str] | None = None, - tool_use_id: str = "web_search_1", -) -> dict[str, Any]: - started = time.monotonic() - search_url = build_search_url(query) - response = requests.get( - search_url, - headers={"User-Agent": USER_AGENT}, - timeout=REQUEST_TIMEOUT_SECONDS, - allow_redirects=True, - ) - - hits = extract_search_hits(response.text) - if not hits and urlparse(response.url or search_url).hostname: - hits = extract_search_hits_from_generic_links(response.text) - - if allowed_domains is not None: - hits = [hit for hit in hits if host_matches_list(hit.url, allowed_domains)] - if blocked_domains is not None: - hits = [hit for hit in hits if not host_matches_list(hit.url, blocked_domains)] - - hits = dedupe_hits(hits)[:MAX_RESULTS] - rendered_hits = "\n".join(f"- [{hit.title}]({hit.url})" for hit in hits) - if hits: - summary = ( - f"Search results for {query!r}. Include a Sources section in the final answer.\n" - f"{rendered_hits}" - ) - else: - summary = f"No web search results matched the query {query!r}." - - return { - "query": query, - "results": [ - summary, - { - "tool_use_id": tool_use_id, - "content": [hit.as_json() for hit in hits], - }, - ], - "durationSeconds": time.monotonic() - started, - } - - -WEB_SEARCH_TOOL_SPEC = { - "name": "web_search", - "description": "Search the web for current information and return cited results.", - "parameters": { - "type": "object", - "properties": { - "query": {"type": "string", "minLength": 2}, - "allowed_domains": { - "type": "array", - "items": {"type": "string"}, - "description": "Optional allowlist of domains or URLs. Subdomains match.", - }, - "blocked_domains": { - "type": "array", - "items": {"type": "string"}, - "description": "Optional blocklist of domains or URLs. Subdomains match.", - }, - }, - "required": ["query"], - "additionalProperties": False, - }, -} - - -def _optional_string_list(arguments: dict[str, Any], key: str) -> list[str] | None: - value = arguments.get(key) - if value is None: - return None - if not isinstance(value, list) or not all(isinstance(item, str) for item in value): - raise ValueError(f"{key} must be an array of strings") - return value - - -async def web_search_handler( - arguments: dict[str, Any], - session: Any = None, - tool_call_id: str | None = None, - **_kw: Any, -) -> tuple[str, bool]: - query_value = arguments.get("query", "") - if not isinstance(query_value, str): - return ( - "Error: web_search requires a query string with at least 2 characters.", - False, - ) - - query = query_value.strip() - if len(query) < 2: - return "Error: web_search requires a query with at least 2 characters.", False - - try: - output = await asyncio.to_thread( - execute_web_search, - query=query, - allowed_domains=_optional_string_list(arguments, "allowed_domains"), - blocked_domains=_optional_string_list(arguments, "blocked_domains"), - tool_use_id=tool_call_id or "web_search_1", - ) - except Exception as exc: - return f"Error executing web search: {exc}", False - - return json.dumps(output, indent=2), True diff --git a/agent/utils/boot_timing.py b/agent/utils/boot_timing.py deleted file mode 100644 index 0c0884d03380f07a05a26c059247fa1b393552e9..0000000000000000000000000000000000000000 --- a/agent/utils/boot_timing.py +++ /dev/null @@ -1,15 +0,0 @@ -"""Shared timing and color helpers for startup visual effects.""" - -import math - - -def settle_curve(progress: float, sharpness: float = 3.0) -> float: - """Return noise amount in range 1..0 for normalized progress 0..1.""" - t = max(0.0, min(1.0, progress)) - return math.exp(-sharpness * t) - - -def warm_gold_from_white(progress: float) -> tuple[int, int, int]: - """Interpolate from white to warm gold for progress 0..1.""" - t = max(0.0, min(1.0, progress)) - return 255, int(255 - 55 * t), int(255 - 175 * t) diff --git a/agent/utils/braille.py b/agent/utils/braille.py deleted file mode 100644 index 4621b735b7cff25d453afbc93f443f2bae4e7e4b..0000000000000000000000000000000000000000 --- a/agent/utils/braille.py +++ /dev/null @@ -1,121 +0,0 @@ -"""Braille-character canvas for high-resolution terminal graphics. - -Each terminal cell maps to a 2x4 dot grid using Unicode braille characters -(U+2800–U+28FF), giving 2Γ— horizontal and 4Γ— vertical resolution. -""" - -# Braille dot positions: (0,0) (1,0) dots 1,4 -# (0,1) (1,1) dots 2,5 -# (0,2) (1,2) dots 3,6 -# (0,3) (1,3) dots 7,8 -_DOT_MAP = ( - (0x01, 0x08), - (0x02, 0x10), - (0x04, 0x20), - (0x40, 0x80), -) - - -class BrailleCanvas: - """A pixel canvas that renders to braille characters.""" - - def __init__(self, term_width: int, term_height: int): - self.term_width = term_width - self.term_height = term_height - self.pixel_width = term_width * 2 - self.pixel_height = term_height * 4 - self._buf = bytearray(term_width * term_height) - - def clear(self) -> None: - for i in range(len(self._buf)): - self._buf[i] = 0 - - def set_pixel(self, x: int, y: int) -> None: - if 0 <= x < self.pixel_width and 0 <= y < self.pixel_height: - cx, rx = divmod(x, 2) - cy, ry = divmod(y, 4) - self._buf[cy * self.term_width + cx] |= _DOT_MAP[ry][rx] - - def render(self) -> list[str]: - lines = [] - for row in range(self.term_height): - offset = row * self.term_width - line = "".join( - chr(0x2800 + self._buf[offset + col]) for col in range(self.term_width) - ) - lines.append(line) - return lines - - -# ── Bitmap font (5Γ—7 uppercase + digits) ────────────────────────────── - -_FONT: dict[str, list[str]] = {} - - -def _define_font() -> None: - """Define a simple 5Γ—7 bitmap font for uppercase ASCII.""" - glyphs = { - "A": [" ## ", "# #", "# #", "####", "# #", "# #", "# #"], - "B": ["### ", "# #", "# #", "### ", "# #", "# #", "### "], - "C": [" ## ", "# #", "# ", "# ", "# ", "# #", " ## "], - "D": ["### ", "# #", "# #", "# #", "# #", "# #", "### "], - "E": ["####", "# ", "# ", "### ", "# ", "# ", "####"], - "F": ["####", "# ", "# ", "### ", "# ", "# ", "# "], - "G": [" ## ", "# #", "# ", "# ##", "# #", "# #", " ###"], - "H": ["# #", "# #", "# #", "####", "# #", "# #", "# #"], - "I": ["###", " # ", " # ", " # ", " # ", " # ", "###"], - "J": [" ##", " # ", " # ", " # ", " # ", "# # ", " # "], - "K": ["# #", "# # ", "## ", "## ", "# # ", "# #", "# #"], - "L": ["# ", "# ", "# ", "# ", "# ", "# ", "####"], - "M": ["# #", "## ##", "# # #", "# # #", "# #", "# #", "# #"], - "N": ["# #", "## #", "## #", "# ##", "# ##", "# #", "# #"], - "O": [" ## ", "# #", "# #", "# #", "# #", "# #", " ## "], - "P": ["### ", "# #", "# #", "### ", "# ", "# ", "# "], - "Q": [" ## ", "# #", "# #", "# #", "# ##", "# #", " ## "], - "R": ["### ", "# #", "# #", "### ", "# # ", "# #", "# #"], - "S": [" ## ", "# #", "# ", " ## ", " #", "# #", " ## "], - "T": ["#####", " # ", " # ", " # ", " # ", " # ", " # "], - "U": ["# #", "# #", "# #", "# #", "# #", "# #", " ## "], - "V": ["# #", "# #", "# #", " # # ", " # # ", " # ", " # "], - "W": ["# #", "# #", "# #", "# # #", "# # #", "## ##", "# #"], - "X": ["# #", "# #", " ## ", " ## ", " ## ", "# #", "# #"], - "Y": ["# #", "# #", " # # ", " # ", " # ", " # ", " # "], - "Z": ["####", " #", " # ", " # ", "# ", "# ", "####"], - " ": [" ", " ", " ", " ", " ", " ", " "], - "0": [" ## ", "# #", "# #", "# #", "# #", "# #", " ## "], - "1": [" # ", "## ", " # ", " # ", " # ", " # ", "###"], - "2": [" ## ", "# #", " #", " # ", " # ", "# ", "####"], - "3": [" ## ", "# #", " #", " ## ", " #", "# #", " ## "], - "4": ["# #", "# #", "# #", "####", " #", " #", " #"], - "5": ["####", "# ", "### ", " #", " #", "# #", " ## "], - "6": [" ## ", "# ", "### ", "# #", "# #", "# #", " ## "], - "7": ["####", " #", " # ", " # ", " # ", " # ", " # "], - "8": [" ## ", "# #", "# #", " ## ", "# #", "# #", " ## "], - "9": [" ## ", "# #", "# #", " ###", " #", " #", " ## "], - } - _FONT.update(glyphs) - - -_define_font() - - -def text_to_pixels(text: str, scale: int = 1) -> list[tuple[int, int]]: - """Convert text string to a list of (x, y) pixel positions using bitmap font.""" - pixels = [] - cursor_x = 0 - for ch in text.upper(): - glyph = _FONT.get(ch) - if glyph is None: - cursor_x += 4 * scale - continue - for row_idx, row in enumerate(glyph): - for col_idx, cell in enumerate(row): - if cell == "#": - for sy in range(scale): - for sx in range(scale): - pixels.append( - (cursor_x + col_idx * scale + sx, row_idx * scale + sy) - ) - glyph_width = max(len(r) for r in glyph) - cursor_x += (glyph_width + 1) * scale - return pixels diff --git a/agent/utils/crt_boot.py b/agent/utils/crt_boot.py deleted file mode 100644 index da0867188961ff08952005c7d098879dfd2a4279..0000000000000000000000000000000000000000 --- a/agent/utils/crt_boot.py +++ /dev/null @@ -1,116 +0,0 @@ -"""CRT / glitch boot sequence effect for CLI startup. - -Simulates an old CRT terminal booting up: text appearing character by character -with noise artifacts, then settling into a clean display. -""" - -import random -import time - -from rich.console import Console -from rich.text import Text -from rich.live import Live - -from agent.utils.boot_timing import settle_curve - - -def _glitch_text(text: str, intensity: float, rng: random.Random) -> str: - """Add random glitch characters to text.""" - glitch_chars = "β–ˆβ–“β–’β–‘β”ƒβ”«β”£β•‹β•β•Žβ”€β”β”…β”„" - result = list(text) - for i in range(len(result)): - if rng.random() < intensity: - result[i] = rng.choice(glitch_chars) - return "".join(result) - - -def run_boot_sequence(console: Console, boot_lines: list[tuple[str, str]]) -> None: - """Run the CRT boot sequence effect. - - Args: - console: Rich console instance. - boot_lines: List of (text, rich_style) tuples to display. - """ - term_height = min(console.height - 2, 40) - rng = random.Random(42) - - with Live(console=console, refresh_per_second=30, transient=True) as live: - displayed_lines: list[tuple[str, str]] = [] - - for line_text, line_style in boot_lines: - if not line_text: - displayed_lines.append(("", "")) - continue - - line_len = max(1, len(line_text)) - # Type out each character - for char_idx in range(len(line_text) + 1): - result = Text() - progress = char_idx / line_len - noise = settle_curve(progress) - prev_glitch_chance = 0.01 + 0.06 * noise - prev_glitch_intensity = 0.02 + 0.12 * noise - scanline_chance = 0.005 + 0.03 * noise - - # Render previously completed lines - for prev_text, prev_style in displayed_lines: - if rng.random() < prev_glitch_chance: - result.append( - _glitch_text(prev_text, prev_glitch_intensity, rng), - style=prev_style, - ) - else: - result.append(prev_text, style=prev_style) - result.append("\n") - - # Current line being typed - typed = line_text[:char_idx] - cursor = "β–ˆ" if char_idx < len(line_text) else "" - - # Noise after cursor - noise_tail = "" - if char_idx < len(line_text): - noise_len = rng.randint(0, int(1 + 5 * noise)) - noise_tail = "".join(rng.choice("β–‘β–’β–“") for _ in range(noise_len)) - - result.append(typed, style=line_style) - result.append(cursor, style="bold rgb(255,200,80)") - result.append(noise_tail, style="dim rgb(180,140,40)") - result.append("\n") - - # Faint scanlines in remaining space - remaining = term_height - len(displayed_lines) - 2 - for _ in range(max(0, remaining)): - if rng.random() < scanline_chance: - scan_len = rng.randint(5, 30) - result.append("─" * scan_len, style="dim rgb(180,140,40)") - result.append("\n") - - live.update(result) - - # Variable typing speed - if line_text[char_idx - 1 : char_idx] in " .": - time.sleep(0.025) - else: - time.sleep(0.010) - - displayed_lines.append((line_text, line_style)) - time.sleep(0.06) - - # Hold with blinking cursor - for frame in range(20): - result = Text() - for prev_text, prev_style in displayed_lines: - result.append(prev_text, style=prev_style) - result.append("\n") - if frame % 8 < 4: - result.append("β–ˆ", style="rgb(255,200,80)") - live.update(result) - time.sleep(0.05) - - # Print final clean frame - final = Text() - for prev_text, prev_style in displayed_lines: - final.append(prev_text, style=prev_style) - final.append("\n") - console.print(final) diff --git a/agent/utils/particle_logo.py b/agent/utils/particle_logo.py deleted file mode 100644 index 9c3338152a8b2fd29031c4eadaa19e9078f6da2b..0000000000000000000000000000000000000000 --- a/agent/utils/particle_logo.py +++ /dev/null @@ -1,230 +0,0 @@ -"""Particle coalesce effect for the HUGGING FACE ML INTERN logo. - -Random particles swirl in from the edges, converge to form the text -"HUGGING FACE / ML INTERN", hold briefly, then the final frame is printed. -Rendered with braille characters for high detail. - -Based on Leandro's particle_coalesce.py demo. -""" - -import math -import random -import time - -from rich.console import Console -from rich.text import Text -from rich.align import Align -from rich.live import Live - -from agent.utils.braille import BrailleCanvas, text_to_pixels -from agent.utils.boot_timing import settle_curve, warm_gold_from_white - - -class Particle: - __slots__ = ("x", "y", "target_x", "target_y", "vx", "vy", "phase", "delay") - - def __init__( - self, x: float, y: float, target_x: float, target_y: float, delay: float = 0 - ): - self.x = x - self.y = y - self.target_x = target_x - self.target_y = target_y - self.vx = 0.0 - self.vy = 0.0 - self.phase = random.uniform(0, math.pi * 2) - self.delay = delay - - def update_converge(self, t: float, strength: float = 0.08, damping: float = 0.92): - """Move toward target with spring-like physics.""" - if t < self.delay: - # Still in swirl phase - self.x += self.vx - self.y += self.vy - self.vx *= 0.99 - self.vy *= 0.99 - # Gentle spiral - angle = self.phase + t * 2 - self.vx += math.cos(angle) * 0.3 - self.vy += math.sin(angle) * 0.3 - return - - # Spring toward target - dx = self.target_x - self.x - dy = self.target_y - self.y - self.vx += dx * strength - self.vy += dy * strength - self.vx *= damping - self.vy *= damping - self.x += self.vx - self.y += self.vy - - @property - def at_target(self) -> bool: - return abs(self.x - self.target_x) < 1.5 and abs(self.y - self.target_y) < 1.5 - - -def run_particle_logo(console: Console, hold_seconds: float = 1.5) -> None: - """Run the particle coalesce effect.""" - term_width = min(console.width, 120) - term_height = min(console.height - 4, 35) - - canvas = BrailleCanvas(term_width, term_height) - - # Get target positions from text - text_pixels_line1 = text_to_pixels("HUGGING FACE", scale=2) - text_pixels_line2 = text_to_pixels("ML INTERN", scale=2) - - # Calculate dimensions for centering - def get_bounds(pixels): - if not pixels: - return 0, 0, 0, 0 - xs = [p[0] for p in pixels] - ys = [p[1] for p in pixels] - return min(xs), max(xs), min(ys), max(ys) - - min_x1, max_x1, min_y1, max_y1 = get_bounds(text_pixels_line1) - min_x2, max_x2, min_y2, max_y2 = get_bounds(text_pixels_line2) - - w1, h1 = max_x1 - min_x1 + 1, max_y1 - min_y1 + 1 - w2, h2 = max_x2 - min_x2 + 1, max_y2 - min_y2 + 1 - - total_h = h1 + 6 + h2 # gap between lines - start_y = (canvas.pixel_height - total_h) // 2 - - # Center line 1 - offset_x1 = (canvas.pixel_width - w1) // 2 - min_x1 - offset_y1 = start_y - min_y1 - targets_1 = [(p[0] + offset_x1, p[1] + offset_y1) for p in text_pixels_line1] - - # Center line 2 - offset_x2 = (canvas.pixel_width - w2) // 2 - min_x2 - offset_y2 = start_y + h1 + 6 - min_y2 - targets_2 = [(p[0] + offset_x2, p[1] + offset_y2) for p in text_pixels_line2] - - all_targets = targets_1 + targets_2 - - # Subsample for performance β€” take every Nth pixel - step = max(1, len(all_targets) // 1500) - sampled_targets = all_targets[::step] - - # Create particles at random edge positions - rng = random.Random(42) - particles = [] - pw, ph = canvas.pixel_width, canvas.pixel_height - - for i, (tx, ty) in enumerate(sampled_targets): - # Spawn from random edge - side = rng.choice(["top", "bottom", "left", "right"]) - if side == "top": - sx, sy = rng.uniform(0, pw), rng.uniform(-20, -5) - elif side == "bottom": - sx, sy = rng.uniform(0, pw), rng.uniform(ph + 5, ph + 20) - elif side == "left": - sx, sy = rng.uniform(-20, -5), rng.uniform(0, ph) - else: - sx, sy = rng.uniform(pw + 5, pw + 20), rng.uniform(0, ph) - - delay = rng.uniform(0, 0.4) # staggered start - p = Particle(sx, sy, tx, ty, delay=delay) - # Initial velocity β€” gentle swirl - angle = math.atan2(ph / 2 - sy, pw / 2 - sx) + rng.gauss(0, 0.8) - speed = rng.uniform(1.0, 2.5) - p.vx = math.cos(angle) * speed - p.vy = math.sin(angle) * speed - particles.append(p) - - # Also add some extra ambient particles that never converge - ambient = [] - for _ in range(200): - ax = rng.uniform(0, pw) - ay = rng.uniform(0, ph) - ap = Particle(ax, ay, ax, ay) - ap.vx = rng.gauss(0, 1) - ap.vy = rng.gauss(0, 1) - ambient.append(ap) - - # Timing: 1s converge + 2s hold = 3s total - fps = 24 - converge_frames = int(fps * 0.9) - hold_frames = int(fps * hold_seconds) - total_frames = converge_frames + hold_frames - - with Live(console=console, refresh_per_second=fps, transient=True) as live: - for frame in range(total_frames): - canvas.clear() - t = frame * 0.03 - - # Update ambient particles (always drifting) - for ap in ambient: - ap.x += ap.vx + math.sin(t + ap.phase) * 0.5 - ap.y += ap.vy + math.cos(t + ap.phase * 1.3) * 0.5 - # Wrap around - ap.x = ap.x % pw - ap.y = ap.y % ph - - # Fade out ambient during hold phase - if frame < converge_frames: - alpha = 0.3 + 0.2 * math.sin(t * 2 + ap.phase) - else: - fade = (frame - converge_frames) / hold_frames - alpha = (0.3 + 0.2 * math.sin(t * 2 + ap.phase)) * (1 - fade) - if alpha > 0.25: - canvas.set_pixel(int(ap.x), int(ap.y)) - - if frame < converge_frames: - # Converge phase - progress = frame / converge_frames - noise = settle_curve(progress) - for p in particles: - p.update_converge(t, strength=0.06, damping=0.90) - canvas.set_pixel(int(p.x), int(p.y)) - - # Trail effect - trail_scale = 0.2 + 0.5 * noise - trail_x = int(p.x - p.vx * trail_scale) - trail_y = int(p.y - p.vy * trail_scale) - canvas.set_pixel(trail_x, trail_y) - - # Color transitions from white to warm gold - r, g, b = warm_gold_from_white(progress) - else: - # Hold phase β€” settle into solid logo - settle_t = (frame - converge_frames) / hold_frames - for p in particles: - # Jitter decays to zero - jitter = (1 - settle_t) * 0.7 - jx = p.target_x + math.sin(t * 3 + p.phase) * jitter - jy = p.target_y + math.cos(t * 3 + p.phase * 1.5) * jitter - canvas.set_pixel(int(jx), int(jy)) - canvas.set_pixel(int(p.target_x), int(p.target_y)) - - r, g, b = 255, 200, 80 - - # Render with color - lines = canvas.render() - result = Text() - for line in lines: - for ch in line: - if ch == chr(0x2800): - result.append(ch) - else: - result.append(ch, style=f"rgb({r},{g},{b})") - result.append("\n") - - live.update(Align.center(result)) - time.sleep(1.0 / fps) - - # Print final settled frame - canvas.clear() - for p in particles: - canvas.set_pixel(int(p.target_x), int(p.target_y)) - final = Text() - for line in canvas.render(): - for ch in line: - if ch == chr(0x2800): - final.append(ch) - else: - final.append(ch, style="rgb(255,200,80)") - final.append("\n") - console.print(Align.center(final)) diff --git a/agent/utils/reliability_checks.py b/agent/utils/reliability_checks.py index 3ed76d72b3517c077144d2c659add85f7caf547e..80dc8eaa5f422c866b8f3943b9457895af923308 100644 --- a/agent/utils/reliability_checks.py +++ b/agent/utils/reliability_checks.py @@ -1,5 +1,7 @@ """Reliability checks for job submissions and other operations""" +from agent.utils.terminal_display import Colors + def check_training_script_save_pattern(script: str) -> str | None: """Check if a training script properly saves models.""" @@ -7,8 +9,8 @@ def check_training_script_save_pattern(script: str) -> str | None: has_push_to_hub = "push_to_hub" in script if has_from_pretrained and not has_push_to_hub: - return "\n\033[91mWARNING: No model save detected in this script. Ensure this is intentional.\033[0m" + return f"\n{Colors.RED}WARNING: We've detected that no model will be saved at the end of this training script. Please ensure this is what you want.{Colors.RESET}" elif has_from_pretrained and has_push_to_hub: - return "\n\033[92mModel will be pushed to hub after training.\033[0m" + return f"\n{Colors.GREEN}We've detected that a model will be pushed to hub at the end of this training.{Colors.RESET}" return None diff --git a/agent/utils/terminal_display.py b/agent/utils/terminal_display.py index a10ac33f7db402dab0fa6b10b04cd974aaf85509..84d47465c5608ea62b3363dd81f82d489552aca3 100644 --- a/agent/utils/terminal_display.py +++ b/agent/utils/terminal_display.py @@ -1,533 +1,155 @@ """ -Terminal display utilities β€” rich-powered CLI formatting. +Terminal display utilities with colors and formatting """ -import asyncio -import re - -from rich.console import Console -from rich.markdown import Heading, Markdown -from rich.panel import Panel -from rich.theme import Theme - - -class _LeftHeading(Heading): - """Rich's default Markdown renders h1/h2 centered via Align.center. - Yield the styled text directly so headings stay left-aligned.""" - - def __rich_console__(self, console, options): - self.text.justify = "left" - yield self.text - - -Markdown.elements["heading_open"] = _LeftHeading +# ANSI color codes +class Colors: + RED = "\033[91m" + GREEN = "\033[92m" + YELLOW = "\033[93m" + BLUE = "\033[94m" + MAGENTA = "\033[95m" + CYAN = "\033[96m" + BOLD = "\033[1m" + UNDERLINE = "\033[4m" + RESET = "\033[0m" -_ANSI_RE = re.compile(r"\x1b\[[0-9;]*[a-zA-Z]") - -def _clip_to_width(s: str, width: int) -> str: - """Truncate a string to `width` visible columns, preserving ANSI styles. - - Needed for the sub-agent live redraw: cursor-up-and-erase assumes one - logical line == one terminal row. If a line wraps, cursor-up undershoots - and the next redraw corrupts the display. Truncating prevents wrap. - """ - if width <= 0: - return s - out: list[str] = [] - visible = 0 - i = 0 - # Reserve 1 char for the trailing ellipsis - limit = width - 1 - truncated = False - while i < len(s): - m = _ANSI_RE.match(s, i) - if m: - out.append(m.group()) - i = m.end() - continue - if visible >= limit: - truncated = True - break - out.append(s[i]) - visible += 1 - i += 1 - if truncated: - # Strip styles (so ellipsis isn't left hanging inside a style run) - out.append("\033[0m…") - return "".join(out) - - -_THEME = Theme( - { - "tool.name": "bold rgb(255,200,80)", - "tool.args": "dim", - "tool.ok": "dim green", - "tool.fail": "dim red", - "info": "dim", - "muted": "dim", - # Markdown emphasis colors - "markdown.strong": "bold rgb(255,200,80)", - "markdown.emphasis": "italic rgb(180,140,40)", - "markdown.code": "rgb(120,220,255)", - "markdown.code_block": "rgb(120,220,255)", - "markdown.link": "underline rgb(90,180,255)", - "markdown.h1": "bold rgb(255,200,80)", - "markdown.h2": "bold rgb(240,180,95)", - "markdown.h3": "bold rgb(220,165,100)", - } -) - -_console = Console(theme=_THEME, highlight=False) - -# Indent prefix for all agent output (aligns under the `>` prompt) -_I = " " - - -def get_console() -> Console: - return _console - - -# ── Banner ───────────────────────────────────────────────────────────── - - -def print_banner(model: str | None = None, hf_user: str | None = None) -> None: - """Print particle logo then CRT boot sequence with system info.""" - from agent.utils.particle_logo import run_particle_logo - from agent.utils.crt_boot import run_boot_sequence - - # Particle coalesce logo β€” 1.5s converge, 2s hold - run_particle_logo(_console, hold_seconds=2.0) - - # Clear screen for CRT boot β€” starts from top - _console.file.write("\033[2J\033[H") - _console.file.flush() - - model_label = model or "unknown" - user_label = hf_user or "not logged in" - - # Warm gold palette matching the shimmer highlight (255, 200, 80) - gold = "rgb(255,200,80)" - dim_gold = "rgb(180,140,40)" - - boot_lines = [ - (f"{_I}Initializing agent runtime...", gold), - (f"{_I} User: {user_label}", dim_gold), - (f"{_I} Model: {model_label}", dim_gold), - (f"{_I} Tools: loading...", dim_gold), - ("", ""), - (f"{_I}/help for commands Β· /model to switch Β· /quit to exit", gold), - ] - - run_boot_sequence(_console, boot_lines) - - -# ── Init progress ────────────────────────────────────────────────────── - - -def print_init_done(tool_count: int = 0) -> None: - import time - - f = _console.file - # Overwrite the "Tools: loading..." line with actual count - f.write( - "\033[A\033[A\033[A\033[K" - ) # Move up 3 lines (blank + help + blank) then up to tools line - f.write("\033[A\033[K") - gold = "\033[38;2;180;140;40m" - reset = "\033[0m" - tool_text = f"{_I} Tools: {tool_count} loaded" - for ch in tool_text: - f.write(f"{gold}{ch}{reset}") - f.flush() - time.sleep(0.012) - f.write("\n\n") - # Reprint the help line - f.write( - f"{_I}\033[38;2;255;200;80m/help for commands Β· /model to switch Β· /quit to exit{reset}\n\n" - ) - # Ready message β€” minimal padding - f.write( - f"{_I}\033[38;2;255;200;80mReady. Let's build something impressive.{reset}\n" - ) - f.flush() - - -# ── Tool calls ───────────────────────────────────────────────────────── - - -def print_tool_call(tool_name: str, args_preview: str) -> None: - import time - - f = _console.file - # CRT-style: type out tool name in HF yellow - gold = "\033[38;2;255;200;80m" - reset = "\033[0m" - f.write(f"{_I}{gold}β–Έ ") - for ch in tool_name: - f.write(ch) - f.flush() - time.sleep(0.015) - f.write(f"{reset} \033[2m{args_preview}{reset}\n") - f.flush() - - -def print_tool_output(output: str, success: bool, truncate: bool = True) -> None: - if truncate: - output = _truncate(output, max_lines=10) - style = "tool.ok" if success else "tool.fail" - # Indent each line of tool output - indented = "\n".join(f"{_I} {line}" for line in output.split("\n")) - _console.print(f"[{style}]{indented}[/{style}]") - - -class SubAgentDisplayManager: - """Manages multiple concurrent sub-agent displays. - - Each agent gets its own stats and rolling tool-call log. - All agents are rendered together so terminal escape-code - erase/redraw stays consistent. - """ - - _MAX_VISIBLE = 4 # tool-call lines shown per agent - - def __init__(self): - self._agents: dict[str, dict] = {} # agent_id -> state dict - self._lines_on_screen = 0 - - def start(self, agent_id: str, label: str = "research") -> None: - import time - - self._agents[agent_id] = { - "label": label, - "calls": [], - "tool_count": 0, - "token_count": 0, - "start_time": time.monotonic(), - } - self._redraw() - - def set_tokens(self, agent_id: str, tokens: int) -> None: - if agent_id in self._agents: - self._agents[agent_id]["token_count"] = tokens - - def set_tool_count(self, agent_id: str, count: int) -> None: - if agent_id in self._agents: - self._agents[agent_id]["tool_count"] = count - - def add_call(self, agent_id: str, tool_desc: str) -> None: - if agent_id in self._agents: - self._agents[agent_id]["calls"].append(tool_desc) - self._redraw() - - def clear(self, agent_id: str) -> None: - # On completion: erase the live region, freeze a single-line summary - # for this agent ("βœ“ research: … (stats)") above the live region so - # the user sees each sub-agent finish cleanly without the tool-call - # noise, then redraw remaining live agents. - agent = self._agents.pop(agent_id, None) - self._erase() - if agent is not None: - width = max(10, _console.width) - line = _clip_to_width(self._render_completion_line(agent), width) - _console.file.write(line + "\n") - _console.file.flush() - self._lines_on_screen = 0 - if self._agents: - self._redraw() - - @staticmethod - def _render_completion_line(agent: dict) -> str: - stats = SubAgentDisplayManager._format_stats(agent) - label = agent["label"] - # dim green check + dim label; stats in parens - line = f"{_I}\033[38;2;120;200;140mβœ“\033[0m \033[2m{label}\033[0m" - if stats: - line += f" \033[2m({stats})\033[0m" - return line - - @staticmethod - def _format_stats(agent: dict) -> str: - import time - - start = agent["start_time"] - if start is None: - return "" - elapsed = time.monotonic() - start - if elapsed < 60: - time_str = f"{elapsed:.0f}s" - else: - time_str = f"{elapsed / 60:.0f}m {elapsed % 60:.0f}s" - tok = agent["token_count"] - tok_str = f"{tok / 1000:.1f}k" if tok >= 1000 else str(tok) - return f"{agent['tool_count']} tool uses Β· {tok_str} tokens Β· {time_str}" - - def _erase(self) -> None: - if self._lines_on_screen > 0: - f = _console.file - for _ in range(self._lines_on_screen): - f.write("\033[A\033[K") - f.flush() - - def _render_agent_lines(self, agent: dict, compact: bool = False) -> list[str]: - """Render one agent's block. - - compact=True β†’ single line (label + stats + most-recent tool name); - compact=False β†’ header + up to _MAX_VISIBLE rolling tool-call lines. - We use compact mode when multiple agents are live so the total live - region stays small enough to fit on one screen. Otherwise cursor-up - can't reach lines that have scrolled into scrollback, and every - redraw pollutes history with a stale copy. - """ - stats = self._format_stats(agent) - label = agent["label"] - header = f"{_I}\033[38;2;255;200;80mβ–Έ {label}\033[0m" - if stats: - header += f" \033[2m({stats})\033[0m" - if compact: - latest = agent["calls"][-1] if agent["calls"] else "" - if latest: - # Strip long json tails for the inline view - short = latest.split(" ")[0] if " " in latest else latest - header += f" \033[2mΒ·\033[0m \033[2m{short}\033[0m" - return [header] - lines = [header] - visible = agent["calls"][-self._MAX_VISIBLE :] - for desc in visible: - lines.append(f"{_I} \033[2m{desc}\033[0m") - return lines - - def _redraw(self) -> None: - f = _console.file - self._erase() - compact = len(self._agents) > 1 - width = max(10, _console.width) - lines: list[str] = [] - for agent in self._agents.values(): - for ln in self._render_agent_lines(agent, compact=compact): - lines.append(_clip_to_width(ln, width)) - for line in lines: - f.write(line + "\n") - f.flush() - self._lines_on_screen = len(lines) - - -_subagent_display = SubAgentDisplayManager() - - -def print_tool_log(tool: str, log: str, agent_id: str = "", label: str = "") -> None: - """Handle tool log events β€” sub-agent calls get the rolling display.""" - if tool == "research": - aid = agent_id or "research" - if log == "Starting research sub-agent...": - _subagent_display.start(aid, label or "research") - elif log == "Research complete.": - _subagent_display.clear(aid) - elif log.startswith("tokens:"): - _subagent_display.set_tokens(aid, int(log[7:])) - elif log.startswith("tools:"): - _subagent_display.set_tool_count(aid, int(log[6:])) - else: - _subagent_display.add_call(aid, log) - else: - _console.print(f"{_I}[dim]{tool}: {log}[/dim]") - - -# ── Messages ─────────────────────────────────────────────────────────── - - -async def print_markdown( - text: str, - cancel_event: "asyncio.Event | None" = None, - instant: bool = False, -) -> None: - import io - import random - from rich.padding import Padding - - _console.print() - - # Render markdown to a string buffer so we can type it out - buf = io.StringIO() - # Important: StringIO is not a TTY, so Rich would normally strip styles. - # Force terminal rendering so ANSI style codes are preserved for typewriter output. - buf_console = Console( - file=buf, - width=_console.width, - highlight=False, - theme=_THEME, - force_terminal=True, - color_system=_console.color_system or "truecolor", - ) - buf_console.print(Padding(Markdown(text), (0, 0, 0, 2))) - rendered = buf.getvalue() - - # Strip trailing whitespace from each line so we don't type across the full width - lines = rendered.split("\n") - rendered = "\n".join(line.rstrip() for line in lines) - - f = _console.file - - # Headless / non-interactive: dump the rendered markdown in one write. - if instant: - f.write(rendered) - f.write("\n") - f.flush() - return - - # CRT typewriter effect β€” async so the event loop can service signal - # handlers (Ctrl+C during streaming) between characters. If cancelled - # mid-type, stop cleanly: write an ANSI reset so half-open color state - # doesn't bleed onto the "interrupted" line, and return. - rng = random.Random(42) - cancelled = False - for ch in rendered: - if cancel_event is not None and cancel_event.is_set(): - cancelled = True - break - f.write(ch) - f.flush() - if ch == "\n": - await asyncio.sleep(0.002) - elif ch == " ": - await asyncio.sleep(0.002) - elif rng.random() < 0.03: - await asyncio.sleep(0.015) - else: - await asyncio.sleep(0.004) - f.write("\033[0m\n" if cancelled else "\n") - f.flush() - - -def print_error(message: str) -> None: - _console.print(f"\n{_I}[bold red]Error:[/bold red] {message}") - - -def print_turn_complete() -> None: - pass # no separator β€” clean output - - -def print_interrupted() -> None: - _console.print(f"\n{_I}[dim italic]interrupted[/dim italic]") - - -def print_compacted(old_tokens: int, new_tokens: int) -> None: - _console.print( - f"{_I}[dim]context compacted: {old_tokens:,} β†’ {new_tokens:,} tokens[/dim]" +def truncate_to_lines(text: str, max_lines: int = 6) -> str: + """Truncate text to max_lines, adding '...' if truncated""" + lines = text.split("\n") + if len(lines) <= max_lines: + return text + return ( + "\n".join(lines[:max_lines]) + + f"\n{Colors.CYAN}... ({len(lines) - max_lines} more lines){Colors.RESET}" ) -# ── Approval ─────────────────────────────────────────────────────────── +def format_header(text: str, emoji: str = "") -> str: + """Format a header with bold""" + full_text = f"{emoji} {text}" if emoji else text + return f"{Colors.BOLD}{full_text}{Colors.RESET}" -def print_approval_header(count: int) -> None: - label = f"Approval required β€” {count} item{'s' if count != 1 else ''}" - _console.print() - _console.print( - f"{_I}", - Panel( - f"[bold yellow]{label}[/bold yellow]", border_style="yellow", expand=False - ), - ) +def format_plan_display() -> str: + """Format the current plan for display (no colors, full visibility)""" + from agent.tools.plan_tool import get_current_plan + plan = get_current_plan() + if not plan: + return "" -def print_approval_item(index: int, total: int, tool_name: str, operation: str) -> None: - _console.print( - f"\n{_I}[bold]\\[{index}/{total}][/bold] [tool.name]{tool_name}[/tool.name] {operation}" - ) + lines = ["\n" + "=" * 60] + lines.append("CURRENT PLAN") + lines.append("=" * 60 + "\n") + # Group by status + completed = [t for t in plan if t["status"] == "completed"] + in_progress = [t for t in plan if t["status"] == "in_progress"] + pending = [t for t in plan if t["status"] == "pending"] -def print_yolo_approve(count: int) -> None: - _console.print( - f"{_I}[bold yellow]yolo β†’[/bold yellow] auto-approved {count} item(s)" + if completed: + lines.append("Completed:") + for todo in completed: + lines.append(f" [x] {todo['id']}. {todo['content']}") + lines.append("") + + if in_progress: + lines.append("In Progress:") + for todo in in_progress: + lines.append(f" [~] {todo['id']}. {todo['content']}") + lines.append("") + + if pending: + lines.append("Pending:") + for todo in pending: + lines.append(f" [ ] {todo['id']}. {todo['content']}") + lines.append("") + + lines.append( + f"Total: {len(plan)} todos ({len(completed)} completed, {len(in_progress)} in progress, {len(pending)} pending)" ) + lines.append("=" * 60 + "\n") + return "\n".join(lines) -# ── Help ─────────────────────────────────────────────────────────────── - -HELP_TEXT = f"""\ -{_I}[bold]Commands[/bold] -{_I} [cyan]/help[/cyan] Show this help -{_I} [cyan]/undo[/cyan] Undo last turn -{_I} [cyan]/compact[/cyan] Compact context window -{_I} [cyan]/resume[/cyan] [index|id|path] Pick up from a log in ./session_logs -{_I} [cyan]/model[/cyan] [id] Show available models or switch -{_I} [cyan]/effort[/cyan] [level] Reasoning effort (minimal|low|medium|high|xhigh|max|off) -{_I} [cyan]/yolo[/cyan] Toggle auto-approve mode -{_I} [cyan]/status[/cyan] Current model & turn count -{_I} [cyan]/share-traces[/cyan] [public|private] Show/flip visibility of your HF trace dataset -{_I} [cyan]/quit[/cyan] Exit""" +def format_error(message: str) -> str: + """Format an error message in red""" + return f"{Colors.RED}ERROR: {message}{Colors.RESET}" -def print_help() -> None: - _console.print() - _console.print(HELP_TEXT) - _console.print() +def format_success(message: str, emoji: str = "") -> str: + """Format a success message in green""" + prefix = f"{emoji} " if emoji else "" + return f"{Colors.GREEN}{prefix}{message}{Colors.RESET}" -# ── Plan display ─────────────────────────────────────────────────────── +def format_tool_call(tool_name: str, arguments: str) -> str: + """Format a tool call message""" + return f"{Colors.YELLOW}Calling tool: {Colors.BOLD}{tool_name}{Colors.RESET}{Colors.YELLOW} with arguments: {arguments}{Colors.RESET}" -def format_plan_display() -> str: - """Format the current plan for display.""" - from agent.tools.plan_tool import get_current_plan - - plan = get_current_plan() - if not plan: - return "" - - completed = [t for t in plan if t["status"] == "completed"] - in_progress = [t for t in plan if t["status"] == "in_progress"] - pending = [t for t in plan if t["status"] == "pending"] - lines = [] - for t in completed: - lines.append(f"{_I}[green]βœ“[/green] [dim]{t['content']}[/dim]") - for t in in_progress: - lines.append(f"{_I}[yellow]β–Έ[/yellow] {t['content']}") - for t in pending: - lines.append(f"{_I}[dim]β—‹ {t['content']}[/dim]") +def format_tool_output(output: str, success: bool, truncate: bool = True) -> str: + """Format tool output with color and optional truncation""" + original_length = len(output) + if truncate: + output = truncate_to_lines(output, max_lines=6) - summary = f"[dim]{len(completed)}/{len(plan)} done[/dim]" - lines.append(f"{_I}{summary}") - return "\n".join(lines) + if success: + return ( + f"{Colors.YELLOW}Tool output ({original_length} tkns): {Colors.RESET}\n{output}" + ) + else: + return ( + f"{Colors.RED}Tool output ({original_length} tokens): {Colors.RESET}\n{output}" + ) -def print_plan() -> None: - plan_str = format_plan_display() - if plan_str: - _console.print(plan_str) +def format_turn_complete() -> str: + """Format turn complete message in green with hugging face emoji""" + return f"{Colors.GREEN}{Colors.BOLD}\U0001f917 Turn complete{Colors.RESET}\n" -# ── Formatting for plan_tool output (used by plan_tool handler) ──────── +def format_separator(char: str = "=", length: int = 60) -> str: + """Format a separator line""" + return char * length def format_plan_tool_output(todos: list) -> str: + """Format the plan tool output (no colors, full visibility)""" if not todos: return "Plan is empty." - lines = ["Plan updated:", ""] + lines = ["Plan updated successfully", ""] + + # Group by status completed = [t for t in todos if t["status"] == "completed"] in_progress = [t for t in todos if t["status"] == "in_progress"] pending = [t for t in todos if t["status"] == "pending"] - for t in completed: - lines.append(f" [x] {t['id']}. {t['content']}") - for t in in_progress: - lines.append(f" [~] {t['id']}. {t['content']}") - for t in pending: - lines.append(f" [ ] {t['id']}. {t['content']}") + if completed: + lines.append("Completed:") + for todo in completed: + lines.append(f" [x] {todo['id']}. {todo['content']}") + lines.append("") + + if in_progress: + lines.append("In Progress:") + for todo in in_progress: + lines.append(f" [~] {todo['id']}. {todo['content']}") + lines.append("") + + if pending: + lines.append("Pending:") + for todo in pending: + lines.append(f" [ ] {todo['id']}. {todo['content']}") + lines.append("") + + lines.append( + f"Total: {len(todos)} todos ({len(completed)} completed, {len(in_progress)} in progress, {len(pending)} pending)" + ) - lines.append(f"\n{len(completed)}/{len(todos)} done") return "\n".join(lines) - - -# ── Internal helpers ─────────────────────────────────────────────────── - - -def _truncate(text: str, max_lines: int = 6) -> str: - lines = text.split("\n") - if len(lines) <= max_lines: - return text - return "\n".join(lines[:max_lines]) + f"\n... ({len(lines) - max_lines} more lines)" diff --git a/backend/dependencies.py b/backend/dependencies.py index 58e02b7ee9a108e586496516bbd076e85b735fa2..03a1bb284507b6b78ad2a7534492934e416f6bed 100644 --- a/backend/dependencies.py +++ b/backend/dependencies.py @@ -1,5 +1,6 @@ """Authentication dependencies for FastAPI routes. +Provides auth validation for both REST and WebSocket endpoints. - In dev mode (OAUTH_CLIENT_ID not set): auth is bypassed, returns a default "dev" user. - In production: validates Bearer tokens or cookies against HF OAuth. """ @@ -7,91 +8,26 @@ import logging import os import time -from collections.abc import Iterable -from hashlib import sha256 from typing import Any import httpx -from fastapi import HTTPException, Request, status - -from agent.core.hf_tokens import bearer_token_from_header, clean_hf_token - -from agent.core.hf_access import fetch_whoami_v2 +from fastapi import HTTPException, Request, WebSocket, status logger = logging.getLogger(__name__) OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co") AUTH_ENABLED = bool(os.environ.get("OAUTH_CLIENT_ID", "")) -HF_EMPLOYEE_ORG = os.environ.get("HF_EMPLOYEE_ORG", "huggingface") # Simple in-memory token cache: token -> (user_info, expiry_time) _token_cache: dict[str, tuple[dict[str, Any], float]] = {} TOKEN_CACHE_TTL = 300 # 5 minutes -# Org membership cache: key -> expiry_time (only caches positive results) -_org_member_cache: dict[str, float] = {} - DEV_USER: dict[str, Any] = { "user_id": "dev", "username": "dev", "authenticated": True, - "plan": "pro", # Dev runs at the Pro quota tier so local testing isn't capped. } -INTERNAL_HF_TOKEN_KEY = "_hf_token" -OAUTH_SCOPE_COOKIE = "hf_oauth_scope_hash" -REQUIRED_OAUTH_SCOPES: tuple[str, ...] = ( - "openid", - "profile", - "read-repos", - "write-repos", - "contribute-repos", - "manage-repos", - "write-collections", - "inference-api", - "jobs", - "write-discussions", -) - -# Log the whoami-v2 shape once at DEBUG so we can confirm the production Pro -# signal without hammering the HF API. -_WHOAMI_SHAPE_LOGGED = False - - -def normalize_oauth_scopes(scopes: Iterable[str]) -> tuple[str, ...]: - """Return stable, de-duplicated OAuth scopes preserving declaration order.""" - seen: set[str] = set() - normalized: list[str] = [] - for scope in scopes: - value = str(scope).strip() - if not value or value in seen: - continue - seen.add(value) - normalized.append(value) - return tuple(normalized) - - -def configured_oauth_scopes() -> tuple[str, ...]: - """Return the scopes this backend should request from HF OAuth. - - Spaces expose README ``hf_oauth_scopes`` through ``OAUTH_SCOPES``. Unioning - that value with the app-required scopes keeps the local request and Space - metadata in sync while ensuring new required scopes are never omitted. - """ - env_scopes = os.environ.get("OAUTH_SCOPES", "").split() - return normalize_oauth_scopes((*env_scopes, *REQUIRED_OAUTH_SCOPES)) - - -def oauth_scope_fingerprint(scopes: Iterable[str] | None = None) -> str: - """Return a non-secret fingerprint for the current OAuth scope contract.""" - scope_list = configured_oauth_scopes() if scopes is None else scopes - payload = " ".join(sorted(normalize_oauth_scopes(scope_list))) - return sha256(payload.encode("utf-8")).hexdigest()[:16] - - -def _cookie_has_current_oauth_scope_marker(request: Request) -> bool: - return request.cookies.get(OAUTH_SCOPE_COOKIE) == oauth_scope_fingerprint() - async def _validate_token(token: str) -> dict[str, Any] | None: """Validate a token against HF OAuth userinfo endpoint. @@ -136,109 +72,12 @@ def _user_from_info(user_info: dict[str, Any]) -> dict[str, Any]: } -def _normalize_user_plan(whoami: Any) -> str: - """Normalize a whoami-v2 payload to the app's personal quota tiers.""" - if not isinstance(whoami, dict): - return "free" - - if whoami.get("isPro") is True: - return "pro" - - return "free" - - -async def _fetch_user_plan(token: str) -> str: - """Look up the user's HF plan via /api/whoami-v2. - - Returns 'free' | 'pro'. Non-200, network errors, or an unknown - payload shape all collapse to 'free' β€” safe default; we'd rather under- - grant the Pro cap than over-grant it on bad data. - """ - global _WHOAMI_SHAPE_LOGGED - whoami = await fetch_whoami_v2(token) - if whoami is None: - return "free" - - if not _WHOAMI_SHAPE_LOGGED: - _WHOAMI_SHAPE_LOGGED = True - logger.debug( - "whoami-v2 payload keys: %s (sample values: isPro=%r)", - sorted(whoami.keys()) - if isinstance(whoami, dict) - else type(whoami).__name__, - whoami.get("isPro") if isinstance(whoami, dict) else None, - ) - - return _normalize_user_plan(whoami) - - async def _extract_user_from_token(token: str) -> dict[str, Any] | None: """Validate a token and return a user dict, or None.""" user_info = await _validate_token(token) - if user_info is None: - return None - user = _user_from_info(user_info) - user["plan"] = await _fetch_user_plan(token) - user[INTERNAL_HF_TOKEN_KEY] = clean_hf_token(token) - return user - - -async def _dev_user_from_env() -> dict[str, Any]: - """Use HF_TOKEN as the dev identity when available. - - Local dev often runs without OAuth, but session trace uploads still need a - real HF namespace. Deriving the dev user from HF_TOKEN keeps local uploads - pointed at the token owner's dataset instead of dev/ml-intern-sessions. - """ - token = clean_hf_token(os.environ.get("HF_TOKEN", "")) - if not token: - return dict(DEV_USER) - - whoami = await fetch_whoami_v2(token) - if not isinstance(whoami, dict): - return dict(DEV_USER) - - username = None - for key in ("name", "user", "preferred_username"): - value = whoami.get(key) - if isinstance(value, str) and value: - username = value - break - if not username: - return dict(DEV_USER) - - return { - "user_id": username, - "username": username, - "authenticated": True, - "plan": await _fetch_user_plan(token), - INTERNAL_HF_TOKEN_KEY: token, - } - - -async def check_org_membership(token: str, org_name: str) -> bool: - """Check if the token owner belongs to an HF org. Only caches positive results.""" - now = time.time() - key = token + org_name - cached = _org_member_cache.get(key) - if cached and cached > now: - return True - - async with httpx.AsyncClient(timeout=10.0) as client: - try: - resp = await client.get( - f"{OPENID_PROVIDER_URL}/api/whoami-v2", - headers={"Authorization": f"Bearer {token}"}, - ) - if resp.status_code != 200: - return False - orgs = {o.get("name") for o in resp.json().get("orgs", [])} - if org_name in orgs: - _org_member_cache[key] = now + TOKEN_CACHE_TTL - return True - return False - except httpx.HTTPError: - return False + if user_info: + return _user_from_info(user_info) + return None async def get_current_user(request: Request) -> dict[str, Any]: @@ -248,15 +87,15 @@ async def get_current_user(request: Request) -> dict[str, Any]: 1. Authorization: Bearer header 2. hf_access_token cookie - In dev mode (AUTH_ENABLED=False), uses HF_TOKEN as the user when possible. + In dev mode (AUTH_ENABLED=False), returns a default dev user. """ if not AUTH_ENABLED: - return await _dev_user_from_env() + return DEV_USER - # Bearer callers manage token lifecycle themselves; only browser cookie - # auth is forced through the scope-freshness marker below. - token = bearer_token_from_header(request.headers.get("Authorization", "")) - if token: + # Try Authorization header + auth_header = request.headers.get("Authorization", "") + if auth_header.startswith("Bearer "): + token = auth_header[7:] user = await _extract_user_from_token(token) if user: return user @@ -264,15 +103,6 @@ async def get_current_user(request: Request) -> dict[str, Any]: # Try cookie token = request.cookies.get("hf_access_token") if token: - if not _cookie_has_current_oauth_scope_marker(request): - logger.info( - "Rejecting stale HF OAuth cookie; current scopes require refresh." - ) - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Authentication scopes changed. Please log in again.", - headers={"WWW-Authenticate": "Bearer"}, - ) user = await _extract_user_from_token(token) if user: return user @@ -284,27 +114,31 @@ async def get_current_user(request: Request) -> dict[str, Any]: ) -def _extract_token(request: Request) -> str | None: - """Pull the HF access token from the Authorization header or cookie. +async def get_ws_user(websocket: WebSocket) -> dict[str, Any] | None: + """Extract and validate user from WebSocket connection. + + WebSocket doesn't support custom headers from browser, so we check: + 1. ?token= query parameter + 2. hf_access_token cookie (sent automatically for same-origin) - Mirrors the lookup order used by ``get_current_user``. + Returns user dict or None if not authenticated. + In dev mode, returns the default dev user. """ - token = bearer_token_from_header(request.headers.get("Authorization", "")) - if token: - return token - return request.cookies.get("hf_access_token") + if not AUTH_ENABLED: + return DEV_USER + # Try query param + token = websocket.query_params.get("token") + if token: + user = await _extract_user_from_token(token) + if user: + return user -async def require_huggingface_org_member(request: Request) -> bool: - """Return True if the caller is a member of the ``huggingface`` org. + # Try cookie (works for same-origin WebSocket) + token = websocket.cookies.get("hf_access_token") + if token: + user = await _extract_user_from_token(token) + if user: + return user - Used to gate endpoints that can push a session onto an Anthropic model - billed to the Space's ``ANTHROPIC_API_KEY``. Returns True unconditionally - in dev mode so local testing isn't blocked. - """ - if not AUTH_ENABLED: - return True - token = _extract_token(request) - if not token: - return False - return await check_org_membership(token, HF_EMPLOYEE_ORG) + return None diff --git a/backend/kpis_scheduler.py b/backend/kpis_scheduler.py deleted file mode 100644 index 9b2199c69151118762ed2cfaddde579fb5a694d3..0000000000000000000000000000000000000000 --- a/backend/kpis_scheduler.py +++ /dev/null @@ -1,148 +0,0 @@ -"""In-process hourly KPI rollup, owned by the backend Space lifespan. - -Replaces an external GitHub Actions cron so the rollup lives next to the data -and reuses the Space's existing HF token β€” no production secrets on the -public source repo. See ``scripts/build_kpis.py`` for the data-flow diagram -and metric definitions. - -Behaviour:: - - lifespan startup β†’ start APScheduler with cron("5 * * * *", UTC) - β†’ fire a best-effort 6-hour backfill (fire-and-forget) - each :05 β†’ run ``build_kpis.run_for_hour`` for the just-completed hour - lifespan shutdown β†’ scheduler.shutdown(wait=False) - -Environment:: - - HF_KPI_WRITE_TOKEN | HF_SESSION_UPLOAD_TOKEN | HF_TOKEN | HF_ADMIN_TOKEN - First one found is used. Least-privilege first. - KPI_SOURCE_REPO default smolagents/ml-intern-sessions - KPI_TARGET_REPO default smolagents/ml-intern-kpis - ML_INTERN_KPIS_DISABLED if truthy, the scheduler is not started -""" - -from __future__ import annotations - -import asyncio -import importlib.util -import logging -import os -from datetime import datetime, timedelta, timezone -from pathlib import Path -from typing import Optional - -logger = logging.getLogger(__name__) - -_PROJECT_ROOT = Path(__file__).resolve().parent.parent - -# Hold strong refs to backfill tasks so asyncio doesn't GC them mid-run. -_background_tasks: set[asyncio.Task] = set() - -_scheduler = None # AsyncIOScheduler instance (lazy import) - - -def _resolve_token() -> Optional[str]: - """Pick the first available HF token. Least-privilege first.""" - for var in ( - "HF_KPI_WRITE_TOKEN", - "HF_SESSION_UPLOAD_TOKEN", - "HF_TOKEN", - "HF_ADMIN_TOKEN", - ): - val = os.environ.get(var) - if val: - return val - return None - - -def _load_build_kpis(): - """Import ``scripts/build_kpis.py`` without putting ``scripts/`` on sys.path.""" - spec = importlib.util.spec_from_file_location( - "build_kpis", - _PROJECT_ROOT / "scripts" / "build_kpis.py", - ) - mod = importlib.util.module_from_spec(spec) - assert spec.loader is not None - spec.loader.exec_module(mod) - return mod - - -async def _run_hour(hour_dt: datetime) -> None: - """Run one hourly rollup off the event loop. Best-effort, never raises.""" - token = _resolve_token() - if not token: - logger.warning("kpis_scheduler: no HF token available, skipping %s", hour_dt) - return - try: - mod = _load_build_kpis() - from huggingface_hub import HfApi - - api = HfApi() - source = os.environ.get("KPI_SOURCE_REPO", "smolagents/ml-intern-sessions") - target = os.environ.get("KPI_TARGET_REPO", "smolagents/ml-intern-kpis") - await asyncio.to_thread(mod.run_for_hour, api, source, target, hour_dt, token) - except Exception as e: - logger.warning("kpis_scheduler: rollup for %s failed: %s", hour_dt, e) - - -async def run_last_completed_hour() -> None: - """The scheduled-at-:05 job. Rolls up the previous whole hour.""" - now = datetime.now(timezone.utc).replace(minute=0, second=0, microsecond=0) - await _run_hour(now - timedelta(hours=1)) - - -async def backfill(hours: int = 6) -> None: - """Catch-up pass for hours the Space was down. Idempotent (overwrites).""" - now = datetime.now(timezone.utc).replace(minute=0, second=0, microsecond=0) - for i in range(1, hours + 1): - await _run_hour(now - timedelta(hours=i)) - - -def start(backfill_hours: int = 6) -> None: - """Called from FastAPI lifespan startup.""" - global _scheduler - if os.environ.get("ML_INTERN_KPIS_DISABLED"): - logger.info("kpis_scheduler: disabled via ML_INTERN_KPIS_DISABLED") - return - if _scheduler is not None: - return - - try: - from apscheduler.schedulers.asyncio import AsyncIOScheduler - from apscheduler.triggers.cron import CronTrigger - except ImportError: - logger.warning("kpis_scheduler: apscheduler not installed, skipping") - return - - _scheduler = AsyncIOScheduler(timezone="UTC") - _scheduler.add_job( - run_last_completed_hour, - CronTrigger(minute=5), - id="kpis_hourly", - misfire_grace_time=600, # tolerate a 10-min misfire window - coalesce=True, # collapse multiple missed fires into one - max_instances=1, - replace_existing=True, - ) - _scheduler.start() - logger.info("kpis_scheduler: started (cron '5 * * * *' UTC)") - - # Non-blocking backfill. Hold a strong ref until done so asyncio doesn't - # GC the task before it finishes. - try: - task = asyncio.get_running_loop().create_task(backfill(backfill_hours)) - _background_tasks.add(task) - task.add_done_callback(_background_tasks.discard) - except RuntimeError: - # Not in an event loop (tests); skip backfill. - pass - - -async def shutdown() -> None: - """Called from FastAPI lifespan shutdown.""" - global _scheduler - if _scheduler is None: - return - _scheduler.shutdown(wait=False) - _scheduler = None - logger.info("kpis_scheduler: stopped") diff --git a/backend/main.py b/backend/main.py index 3a8983055db6e1dbd8c263d6bba6022edc3a1a01..fc75ab9e11696664776cc2370d68e589196af7ad 100644 --- a/backend/main.py +++ b/backend/main.py @@ -6,17 +6,19 @@ from contextlib import asynccontextmanager from pathlib import Path from dotenv import load_dotenv + +load_dotenv() + +# Ensure HF_TOKEN is set β€” fall back to HF_ADMIN_TOKEN if available (HF Spaces) +if not os.environ.get("HF_TOKEN") and os.environ.get("HF_ADMIN_TOKEN"): + os.environ["HF_TOKEN"] = os.environ["HF_ADMIN_TOKEN"] + from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles -# Load .env before importing routes/session_manager so persistence and quota -# modules see local Mongo settings during startup. -load_dotenv(Path(__file__).parent.parent / ".env") - -from routes.agent import router as agent_router # noqa: E402 -from routes.auth import router as auth_router # noqa: E402 -from session_manager import session_manager # noqa: E402 +from routes.agent import router as agent_router +from routes.auth import router as auth_router # Configure logging logging.basicConfig( @@ -30,54 +32,15 @@ logger = logging.getLogger(__name__) async def lifespan(app: FastAPI): """Application lifespan handler.""" logger.info("Starting HF Agent backend...") - await session_manager.start() - # Start in-process hourly KPI rollup. Replaces an external cron so the - # rollup lives next to the data and reuses the Space's HF token. - try: - import kpis_scheduler - - kpis_scheduler.start() - except Exception as e: - logger.warning("KPI scheduler failed to start: %s", e) yield - logger.info("Shutting down HF Agent backend...") - try: - import kpis_scheduler - - await kpis_scheduler.shutdown() - except Exception as e: - logger.warning("KPI scheduler shutdown failed: %s", e) - - # Final-flush: save every still-active session so we don't lose traces on - # server restart. Uploads are detached subprocesses β€” this is fast. - try: - for sid, agent_session in list(session_manager.sessions.items()): - sess = agent_session.session - if sess.config.save_sessions: - try: - sess.save_and_upload_detached(sess.config.session_dataset_repo) - logger.info("Flushed session %s on shutdown", sid) - except Exception as e: - logger.warning("Failed to flush session %s: %s", sid, e) - except Exception as e: - logger.warning("Lifespan final-flush skipped: %s", e) - await session_manager.close() - - -# Disable FastAPI auto-docs when running on HF Spaces (SPACE_ID is set by the -# platform) to avoid exposing the full API surface to anonymous visitors. Local -# dev keeps /docs and /redoc available. -_DOCS_DISABLED = os.environ.get("SPACE_ID") is not None + app = FastAPI( title="HF Agent", description="ML Engineering Assistant API", version="1.0.0", lifespan=lifespan, - docs_url=None if _DOCS_DISABLED else "/docs", - redoc_url=None if _DOCS_DISABLED else "/redoc", - openapi_url=None if _DOCS_DISABLED else "/openapi.json", ) # CORS middleware for development diff --git a/backend/models.py b/backend/models.py index 1126c2b92b93e1f7f429abfd02e97be4b145ae1f..4ebf0caa01cb48675bd55a5fffeadbacf53669c2 100644 --- a/backend/models.py +++ b/backend/models.py @@ -3,7 +3,7 @@ from enum import Enum from typing import Any -from pydantic import BaseModel, Field +from pydantic import BaseModel class OpType(str, Enum): @@ -37,8 +37,6 @@ class ToolApproval(BaseModel): tool_call_id: str approved: bool feedback: str | None = None - edited_script: str | None = None - namespace: str | None = None class ApprovalRequest(BaseModel): @@ -52,16 +50,7 @@ class SubmitRequest(BaseModel): """Request to submit user input.""" session_id: str - # Cap text size to prevent context-bloat / cost-amplification: a malicious - # or runaway client could otherwise attach megabytes that then ride along - # in every subsequent turn until /api/compact is called. - text: str = Field(..., min_length=1, max_length=100_000) - - -class TruncateRequest(BaseModel): - """Request to truncate conversation history to before a specific user message.""" - - user_message_index: int + text: str class SessionResponse(BaseModel): @@ -69,24 +58,6 @@ class SessionResponse(BaseModel): session_id: str ready: bool = True - model: str | None = None - - -class PendingApprovalTool(BaseModel): - """A tool waiting for user approval.""" - - tool: str - tool_call_id: str - arguments: dict[str, Any] = {} - - -class SessionAutoApprovalInfo(BaseModel): - """Per-session auto-approval budget state.""" - - enabled: bool = False - cost_cap_usd: float | None = None - estimated_spend_usd: float = 0.0 - remaining_usd: float | None = None class SessionInfo(BaseModel): @@ -95,29 +66,8 @@ class SessionInfo(BaseModel): session_id: str created_at: str is_active: bool - is_processing: bool = False message_count: int user_id: str = "dev" - pending_approval: list[PendingApprovalTool] | None = None - model: str | None = None - title: str | None = None - notification_destinations: list[str] = Field(default_factory=list) - auto_approval: SessionAutoApprovalInfo = Field( - default_factory=SessionAutoApprovalInfo - ) - - -class SessionNotificationsRequest(BaseModel): - """Replace the session's auto-notification destinations.""" - - destinations: list[str] - - -class SessionYoloRequest(BaseModel): - """Update a session's auto-approval policy.""" - - enabled: bool - cost_cap_usd: float | None = Field(default=None, ge=0) class HealthResponse(BaseModel): @@ -134,6 +84,4 @@ class LLMHealthResponse(BaseModel): status: str # "ok" | "error" model: str error: str | None = None - error_type: str | None = ( - None # "auth" | "credits" | "rate_limit" | "network" | "unknown" - ) + error_type: str | None = None # "auth" | "credits" | "rate_limit" | "network" | "unknown" diff --git a/backend/routes/agent.py b/backend/routes/agent.py index 0a742b7c793a1949c2f8dcd72f7213bc4403eaf2..381ae17dbb1cf9690effd378be892953b71830ed 100644 --- a/backend/routes/agent.py +++ b/backend/routes/agent.py @@ -1,235 +1,71 @@ -"""Agent API routes β€” REST + SSE endpoints. +"""Agent API routes - WebSocket and REST endpoints. All routes (except /health) require authentication via the get_current_user dependency. In dev mode (no OAUTH_CLIENT_ID), auth is bypassed automatically. """ -import asyncio -import json import logging +import os from typing import Any -from dependencies import ( - INTERNAL_HF_TOKEN_KEY, - get_current_user, -) +from dependencies import get_current_user, get_ws_user from fastapi import ( APIRouter, Depends, HTTPException, Request, + WebSocket, + WebSocketDisconnect, ) -from fastapi.exceptions import RequestValidationError -from fastapi.responses import StreamingResponse from litellm import acompletion -from pydantic import ValidationError from models import ( ApprovalRequest, HealthResponse, LLMHealthResponse, SessionInfo, - SessionNotificationsRequest, SessionResponse, - SessionYoloRequest, SubmitRequest, - TruncateRequest, -) -from session_manager import ( - MAX_SESSIONS, - AgentSession, - SessionCapacityError, - session_manager, ) - -import user_quotas - -from agent.core.hf_access import get_jobs_access -from agent.core.hf_tokens import resolve_hf_request_token, resolve_hf_router_token -from agent.core.llm_params import _resolve_llm_params +from session_manager import MAX_SESSIONS, SessionCapacityError, session_manager +from websocket import manager as ws_manager logger = logging.getLogger(__name__) router = APIRouter(prefix="/api", tags=["agent"]) -_background_teardown_tasks: set[asyncio.Task] = set() - -DEFAULT_CLAUDE_MODEL_ID = "bedrock/us.anthropic.claude-opus-4-6-v1" -DEFAULT_FREE_MODEL_ID = "moonshotai/Kimi-K2.6" -PREMIUM_MODEL_IDS = { - DEFAULT_CLAUDE_MODEL_ID, - "openai/gpt-5.5", -} - - -def _claude_picker_model_id() -> str: - """Return the model ID used by the Claude option in the UI. - The frontend config sets ``session_manager.config.model_name`` from - ``ML_INTERN_CLAUDE_MODEL_ID`` when that env var is present, otherwise it - falls back to the production Bedrock Claude model. This function only - exposes that resolved config value for the Claude picker; non-Claude models - are listed separately in the model switcher. - """ - return session_manager.config.model_name - - -def _available_models() -> list[dict[str, Any]]: - models = [ - { - "id": "moonshotai/Kimi-K2.6", - "label": "Kimi K2.6", - "provider": "huggingface", - "tier": "free", - "recommended": True, - }, - { - "id": _claude_picker_model_id(), - "label": "Claude Opus 4.6", - "provider": "anthropic", - "tier": "pro", - "recommended": True, - }, - { - "id": "openai/gpt-5.5", - "label": "GPT-5.5", - "provider": "openai", - "tier": "pro", - }, - { - "id": "MiniMaxAI/MiniMax-M2.7", - "label": "MiniMax M2.7", - "provider": "huggingface", - "tier": "free", - }, - { - "id": "zai-org/GLM-5.1", - "label": "GLM 5.1", - "provider": "huggingface", - "tier": "free", - }, - { - "id": "deepseek-ai/DeepSeek-V4-Pro:deepinfra", - "label": "DeepSeek V4 Pro", - "provider": "huggingface", - "tier": "free", - }, - ] - return models - - -AVAILABLE_MODELS = _available_models() - - -def _is_premium_model(model_id: str) -> bool: - return model_id in PREMIUM_MODEL_IDS - - -async def _model_override_for_new_session( - request: Request, - requested_model: str | None, -) -> str | None: - """Return the model override to use when creating a new session. - - Explicit premium model requests are allowed and charged at message-submit - time. Implicit default sessions are more forgiving: when the configured - default is premium, start them on the first free model instead of spending - premium quota accidentally. - """ - resolved_model = requested_model or session_manager.config.model_name - if not _is_premium_model(resolved_model): - return requested_model - if requested_model: - return requested_model - - logger.info( - "Default premium model %s would spend quota; " - "creating session with free fallback %s", - resolved_model, - DEFAULT_FREE_MODEL_ID, - ) - return DEFAULT_FREE_MODEL_ID - - -async def _enforce_premium_model_quota( - user: dict[str, Any], - agent_session: AgentSession, -) -> None: - """Charge the user's daily premium-model quota on first use in a session. - - Runs at *message-submit* time, not session-create time β€” so spinning up a - premium-model session to look around doesn't burn quota. The - ``claude_counted`` flag on ``AgentSession`` guards against re-counting the - same session; the stored field name is kept for persistence compatibility. - - No-ops when the session's current model isn't premium, or when this - session has already been charged. Raises 429 when the user has hit - their daily cap. - """ - if agent_session.claude_counted: - return - model_name = agent_session.session.config.model_name - if not _is_premium_model(model_name): - return - user_id = user["user_id"] - plan = user.get("plan", "free") - cap = user_quotas.daily_cap_for(plan) - new_count = await user_quotas.try_increment_claude(user_id, cap) - if new_count is None: - if plan == "pro": - message = ( - "Daily premium model limit reached. Use a free model and try " - "premium models again tomorrow." - ) - else: - message = ( - "Daily premium model limit reached. Upgrade to HF Pro for " - f"{user_quotas.CLAUDE_PRO_DAILY}/day or use a free model." - ) - raise HTTPException( - status_code=429, - detail={ - "error": "premium_model_daily_cap", - "plan": plan, - "cap": cap, - "message": message, - }, - ) - agent_session.claude_counted = True - await session_manager.persist_session_snapshot(agent_session) - - -def _user_hf_token(user: dict[str, Any] | None) -> str | None: - if not isinstance(user, dict): - return None - return user.get(INTERNAL_HF_TOKEN_KEY) - - -async def _check_session_access( - session_id: str, - user: dict[str, Any], - request: Request | None = None, - preload_sandbox: bool = True, -) -> AgentSession: - """Verify and lazily load the user's session. Raises 403 or 404.""" - hf_token = ( - resolve_hf_request_token(request) - if request is not None - else _user_hf_token(user) - ) - agent_session = await session_manager.ensure_session_loaded( - session_id, - user["user_id"], - hf_token=hf_token, - hf_username=user.get("username"), - preload_sandbox=preload_sandbox, - ) - if not agent_session: +AVAILABLE_MODELS = [ + { + "id": "huggingface/novita/MiniMaxAI/MiniMax-M2.1", + "label": "MiniMax M2.1", + "provider": "huggingface", + "recommended": True, + }, + { + "id": "anthropic/claude-opus-4-5-20251101", + "label": "Claude Opus 4.5", + "provider": "anthropic", + "recommended": True, + }, + { + "id": "huggingface/novita/moonshotai/Kimi-K2.5", + "label": "Kimi K2.5", + "provider": "huggingface", + }, + { + "id": "huggingface/novita/zai-org/GLM-5", + "label": "GLM 5", + "provider": "huggingface", + }, +] + + +def _check_session_access(session_id: str, user: dict[str, Any]) -> None: + """Verify the user has access to the given session. Raises 403 or 404.""" + info = session_manager.get_session_info(session_id) + if not info: raise HTTPException(status_code=404, detail="Session not found") - if user["user_id"] != "dev" and agent_session.user_id not in { - user["user_id"], - "dev", - }: + if not session_manager.verify_session_access(session_id, user["user_id"]): raise HTTPException(status_code=403, detail="Access denied to this session") - return agent_session @router.get("/health", response_model=HealthResponse) @@ -253,13 +89,14 @@ async def llm_health_check() -> LLMHealthResponse: - timeout / network β†’ provider unreachable """ model = session_manager.config.model_name + hf_key = os.environ.get("INFERENCE_TOKEN") try: - llm_params = _resolve_llm_params(model, reasoning_effort="high") await acompletion( + model=model, messages=[{"role": "user", "content": "hi"}], max_tokens=1, timeout=10, - **llm_params, + api_key=hf_key if hf_key and model.startswith("huggingface/") else None, ) return LLMHealthResponse(status="ok", model=model) except Exception as e: @@ -304,71 +141,56 @@ async def get_model() -> dict: } -_TITLE_STRIP_CHARS = str.maketrans("", "", "`*_~#[]()") +@router.post("/config/model") +async def set_model(body: dict, user: dict = Depends(get_current_user)) -> dict: + """Set the LLM model. Applies to new conversations.""" + model_id = body.get("model") + if not model_id: + raise HTTPException(status_code=400, detail="Missing 'model' field") + valid_ids = {m["id"] for m in AVAILABLE_MODELS} + if model_id not in valid_ids: + raise HTTPException(status_code=400, detail=f"Unknown model: {model_id}") + session_manager.config.model_name = model_id + logger.info(f"Model changed to {model_id} by {user.get('username', 'unknown')}") + return {"model": model_id} @router.post("/title") async def generate_title( request: SubmitRequest, user: dict = Depends(get_current_user) ) -> dict: - """Generate a short title for a chat session based on the first user message. - - Always uses gpt-oss-120b via Cerebras on the HF router. The tab headline - renders as plain text, so the model is told to avoid markdown and any - stray formatting characters are stripped before returning. gpt-oss is a - reasoning model β€” reasoning_effort=low keeps the reasoning budget small - so the 60-token output budget isn't consumed before the title is written. - """ - api_key = resolve_hf_router_token(_user_hf_token(user)) + """Generate a short title for a chat session based on the first user message.""" + model = session_manager.config.model_name + hf_key = os.environ.get("INFERENCE_TOKEN") try: response = await acompletion( - # Double openai/ prefix: LiteLLM strips the first as its provider - # prefix, leaving the HF model id on the wire for the router. - model="openai/openai/gpt-oss-120b:cerebras", - api_base="https://router.huggingface.co/v1", - api_key=api_key, + model=model, messages=[ { "role": "system", "content": ( "Generate a very short title (max 6 words) for a chat conversation " "that starts with the following user message. " - "Reply with ONLY the title in plain text. " - "Do NOT use markdown, backticks, asterisks, quotes, brackets, or any " - "formatting characters. No punctuation at the end." + "Reply with ONLY the title, no quotes, no punctuation at the end." ), }, {"role": "user", "content": request.text[:500]}, ], - max_tokens=60, + max_tokens=20, temperature=0.3, - timeout=10, - reasoning_effort="low", + timeout=8, + api_key=hf_key if hf_key and model.startswith("huggingface/") else None, ) title = response.choices[0].message.content.strip().strip('"').strip("'") - title = title.translate(_TITLE_STRIP_CHARS).strip() + # Safety: cap at 50 chars if len(title) > 50: title = title[:50].rstrip() + "…" - try: - await _check_session_access(request.session_id, user) - await session_manager.update_session_title(request.session_id, title) - except Exception: - logger.debug( - "Skipping title persistence for missing session %s", request.session_id - ) return {"title": title} except Exception as e: logger.warning(f"Title generation failed: {e}") + # Fallback: truncate the message fallback = request.text.strip() title = fallback[:40].rstrip() + "…" if len(fallback) > 40 else fallback - try: - await _check_session_access(request.session_id, user) - await session_manager.update_session_title(request.session_id, title) - except Exception: - logger.debug( - "Skipping fallback title persistence for missing session %s", - request.session_id, - ) return {"title": title} @@ -382,103 +204,23 @@ async def create_session( and stored in the session so that tools (e.g. hf_jobs) can act on behalf of the user. - Optional body ``{"model"?: }`` selects the session's LLM; unknown - ids are rejected (400). The premium-model quota runs at message-submit - time, not here β€” spinning up a session to look around is free. - Returns 503 if the server or user has reached the session limit. """ - # Extract the user's HF token (Bearer header, HttpOnly cookie, or env var) - hf_token = resolve_hf_request_token(request) - - # Optional model override. Empty body falls back to the config default. - model: str | None = None - try: - body = await request.json() - except Exception: - body = None - if isinstance(body, dict): - model = body.get("model") - - valid_ids = {m["id"] for m in AVAILABLE_MODELS} - if model and model not in valid_ids: - raise HTTPException(status_code=400, detail=f"Unknown model: {model}") - - # Explicit premium selections are allowed. If the implicit configured - # default is premium, start the session on a free model instead. - model = await _model_override_for_new_session(request, model) - - try: - session_id = await session_manager.create_session( - user_id=user["user_id"], - hf_username=user.get("username"), - hf_token=hf_token, - model=model, - is_pro=user.get("plan") == "pro", - ) - except SessionCapacityError as e: - raise HTTPException(status_code=503, detail=str(e)) - - return SessionResponse( - session_id=session_id, - ready=True, - model=model or session_manager.config.model_name, - ) - - -@router.post("/session/restore-summary", response_model=SessionResponse) -async def restore_session_summary( - request: Request, body: dict, user: dict = Depends(get_current_user) -) -> SessionResponse: - """Create a new session seeded with a summary of the caller's prior - conversation. The client sends its cached messages; we run the standard - summarization prompt on them and drop the result into the new - session's context as a user-role system note. - - Optional ``"model"`` in the body overrides the session's LLM. The - premium-model quota runs at message-submit time, not here. - """ - messages = body.get("messages") - if not isinstance(messages, list) or not messages: - raise HTTPException(status_code=400, detail="Missing 'messages' array") - - hf_token = resolve_hf_request_token(request) - - model = body.get("model") - valid_ids = {m["id"] for m in AVAILABLE_MODELS} - if model and model not in valid_ids: - raise HTTPException(status_code=400, detail=f"Unknown model: {model}") - - model = await _model_override_for_new_session(request, model) + # Extract the user's HF token (Bearer header or HttpOnly cookie) + hf_token = None + auth_header = request.headers.get("Authorization", "") + if auth_header.startswith("Bearer "): + hf_token = auth_header[7:] + if not hf_token: + hf_token = request.cookies.get("hf_access_token") try: session_id = await session_manager.create_session( - user_id=user["user_id"], - hf_username=user.get("username"), - hf_token=hf_token, - model=model, - is_pro=user.get("plan") == "pro", + user_id=user["user_id"], hf_token=hf_token ) except SessionCapacityError as e: raise HTTPException(status_code=503, detail=str(e)) - - try: - summarized = await session_manager.seed_from_summary(session_id, messages) - except ValueError as e: - raise HTTPException(status_code=500, detail=str(e)) - except Exception as e: - logger.exception("seed_from_summary failed") - raise HTTPException(status_code=500, detail=f"Summary failed: {e}") - - logger.info( - f"Seeded session {session_id} for {user.get('username', 'unknown')} " - f"(summary of {summarized} messages)" - ) - return SessionResponse( - session_id=session_id, - ready=True, - model=model or session_manager.config.model_name, - ) + return SessionResponse(session_id=session_id, ready=True) @router.get("/session/{session_id}", response_model=SessionInfo) @@ -486,142 +228,24 @@ async def get_session( session_id: str, user: dict = Depends(get_current_user) ) -> SessionInfo: """Get session information. Only accessible by the session owner.""" - await _check_session_access(session_id, user) + _check_session_access(session_id, user) info = session_manager.get_session_info(session_id) return SessionInfo(**info) -@router.post("/session/{session_id}/model") -async def set_session_model( - session_id: str, - body: dict, - request: Request, - user: dict = Depends(get_current_user), -) -> dict: - """Switch the active model for a single session (tab-scoped). - - Takes effect on the next LLM call in that session β€” other sessions - (including other browser tabs) are unaffected. Model switches don't - charge quota β€” the premium-model quota only fires at message-submit time. - """ - agent_session = await _check_session_access(session_id, user, request) - model_id = body.get("model") - if not model_id: - raise HTTPException(status_code=400, detail="Missing 'model' field") - valid_ids = {m["id"] for m in AVAILABLE_MODELS} - if model_id not in valid_ids: - raise HTTPException(status_code=400, detail=f"Unknown model: {model_id}") - if not agent_session: - raise HTTPException(status_code=404, detail="Session not found") - await session_manager.update_session_model(session_id, model_id) - logger.info( - f"Session {session_id} model β†’ {model_id} " - f"(by {user.get('username', 'unknown')})" - ) - return {"session_id": session_id, "model": model_id} - - -@router.post("/session/{session_id}/notifications") -async def set_session_notifications( - session_id: str, - body: SessionNotificationsRequest, - user: dict = Depends(get_current_user), -) -> dict: - """Replace the session's auto-notification destinations.""" - agent_session = await _check_session_access(session_id, user) - try: - destinations = session_manager.set_notification_destinations( - session_id, body.destinations - ) - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) - await session_manager.persist_session_snapshot(agent_session) - return { - "session_id": session_id, - "notification_destinations": destinations, - } - - -@router.patch("/session/{session_id}/yolo") -async def set_session_yolo( - session_id: str, - body: SessionYoloRequest, - user: dict = Depends(get_current_user), -) -> dict: - """Update the session-scoped auto-approval policy.""" - await _check_session_access(session_id, user) - try: - summary = await session_manager.update_session_auto_approval( - session_id, - enabled=body.enabled, - cost_cap_usd=body.cost_cap_usd, - cap_provided="cost_cap_usd" in body.model_fields_set, - ) - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) - return {"session_id": session_id, **summary} - - -@router.get("/user/quota") -async def get_user_quota(user: dict = Depends(get_current_user)) -> dict: - """Return the user's plan tier and today's premium-model quota state.""" - plan = user.get("plan", "free") - used = await user_quotas.get_claude_used_today(user["user_id"]) - cap = user_quotas.daily_cap_for(plan) - remaining = max(0, cap - used) - return { - "plan": plan, - "premium_used_today": used, - "premium_daily_cap": cap, - "premium_remaining": remaining, - } - - -@router.get("/user/jobs-access") -async def get_jobs_access_info( - request: Request, user: dict = Depends(get_current_user) -) -> dict: - """Return the namespaces the current token can run HF Jobs under. - - Credits are enforced by the HF API at job-creation time, not here β€” - the response only describes which wallets the caller is allowed to - pick from. Pro is irrelevant. - """ - token = resolve_hf_request_token(request) - - access = await get_jobs_access(token or "") - return { - "eligible_namespaces": access.eligible_namespaces if access else [], - "default_namespace": access.default_namespace if access else None, - "billing_url": "https://huggingface.co/settings/billing", - } - - @router.get("/sessions", response_model=list[SessionInfo]) async def list_sessions(user: dict = Depends(get_current_user)) -> list[SessionInfo]: """List sessions belonging to the authenticated user.""" - sessions = await session_manager.list_sessions(user_id=user["user_id"]) + sessions = session_manager.list_sessions(user_id=user["user_id"]) return [SessionInfo(**s) for s in sessions] -@router.post("/session/{session_id}/sandbox/teardown") -async def teardown_session_sandbox( - session_id: str, user: dict = Depends(get_current_user) -) -> dict: - """Best-effort sandbox teardown that preserves durable chat history.""" - await _check_session_access(session_id, user, preload_sandbox=False) - task = asyncio.create_task(session_manager.teardown_sandbox(session_id)) - _background_teardown_tasks.add(task) - task.add_done_callback(_background_teardown_tasks.discard) - return {"status": "teardown_requested", "session_id": session_id} - - @router.delete("/session/{session_id}") async def delete_session( session_id: str, user: dict = Depends(get_current_user) ) -> dict: """Delete a session. Only accessible by the session owner.""" - await _check_session_access(session_id, user, preload_sandbox=False) + _check_session_access(session_id, user) success = await session_manager.delete_session(session_id) if not success: raise HTTPException(status_code=404, detail="Session not found") @@ -630,41 +254,14 @@ async def delete_session( @router.post("/submit") async def submit_input( - request: Request, user: dict = Depends(get_current_user) + request: SubmitRequest, user: dict = Depends(get_current_user) ) -> dict: """Submit user input to a session. Only accessible by the session owner.""" - # Parse the body manually so session ownership can be checked before the - # text-length constraints fire β€” otherwise a non-owner sending an empty - # or oversized text gets a 422 leaking the constraint instead of the 404 - # they'd get for any other access to a session they don't own. - try: - payload = await request.json() - except (json.JSONDecodeError, TypeError) as exc: - raise HTTPException(status_code=422, detail=str(exc)) - if not isinstance(payload, dict): - raise HTTPException(status_code=422, detail="Body must be a JSON object") - raw_session_id = payload.get("session_id") - if not isinstance(raw_session_id, str) or not raw_session_id: - raise RequestValidationError( - [ - { - "type": "missing", - "loc": ("body", "session_id"), - "msg": "Field required", - "input": payload, - } - ] - ) - agent_session = await _check_session_access(raw_session_id, user) - try: - body = SubmitRequest(**payload) - except ValidationError as exc: - raise RequestValidationError(exc.errors()) from exc - await _enforce_premium_model_quota(user, agent_session) - success = await session_manager.submit_user_input(body.session_id, body.text) + _check_session_access(request.session_id, user) + success = await session_manager.submit_user_input(request.session_id, request.text) if not success: raise HTTPException(status_code=404, detail="Session not found or inactive") - return {"status": "submitted", "session_id": body.session_id} + return {"status": "submitted", "session_id": request.session_id} @router.post("/approve") @@ -672,14 +269,12 @@ async def submit_approval( request: ApprovalRequest, user: dict = Depends(get_current_user) ) -> dict: """Submit tool approvals to a session. Only accessible by the session owner.""" - await _check_session_access(request.session_id, user) + _check_session_access(request.session_id, user) approvals = [ { "tool_call_id": a.tool_call_id, "approved": a.approved, "feedback": a.feedback, - "edited_script": a.edited_script, - "namespace": a.namespace, } for a in request.approvals ] @@ -689,286 +284,34 @@ async def submit_approval( return {"status": "submitted", "session_id": request.session_id} -@router.post("/chat/{session_id}") -async def chat_sse( - session_id: str, - request: Request, - user: dict = Depends(get_current_user), -) -> StreamingResponse: - """SSE endpoint: submit input or approval, then stream events until turn ends.""" - agent_session = await _check_session_access(session_id, user, request) - if not agent_session or not agent_session.is_active: - raise HTTPException(status_code=404, detail="Session not found or inactive") - - # Parse body - body = await request.json() - - # Subscribe BEFORE submitting so we never miss events β€” even if the - # agent loop processes the submission before this coroutine continues. - broadcaster = agent_session.broadcaster - sub_id, event_queue = broadcaster.subscribe() - - # Submit the operation - text = body.get("text") - approvals = body.get("approvals") - - # Gate user-message sends against the daily premium-model quota. Approvals are - # continuations of an in-progress turn β€” the session was already charged - # on its first message, so we skip the gate there. - if text is not None and not approvals: - try: - await _enforce_premium_model_quota(user, agent_session) - except HTTPException: - broadcaster.unsubscribe(sub_id) - raise - - try: - if approvals: - formatted = [ - { - "tool_call_id": a["tool_call_id"], - "approved": a["approved"], - "feedback": a.get("feedback"), - "edited_script": a.get("edited_script"), - "namespace": a.get("namespace"), - } - for a in approvals - ] - success = await session_manager.submit_approval(session_id, formatted) - elif text is not None: - success = await session_manager.submit_user_input(session_id, text) - else: - broadcaster.unsubscribe(sub_id) - raise HTTPException( - status_code=400, detail="Must provide 'text' or 'approvals'" - ) - - if not success: - broadcaster.unsubscribe(sub_id) - raise HTTPException(status_code=404, detail="Session not found or inactive") - except HTTPException: - broadcaster.unsubscribe(sub_id) - raise - except Exception: - broadcaster.unsubscribe(sub_id) - raise - - return _sse_response(broadcaster, event_queue, sub_id) - - -@router.post("/pro-click/{session_id}") -async def record_pro_click( - session_id: str, - body: dict, - user: dict = Depends(get_current_user), -) -> dict: - """Record a click on a Pro upgrade CTA shown from inside a session.""" - agent_session = await _check_session_access(session_id, user) - - from agent.core import telemetry - - await telemetry.record_pro_cta_click( - agent_session.session, - source=str(body.get("source") or "unknown"), - target=str(body.get("target") or "pro_pricing"), - ) - if agent_session.session.config.save_sessions: - agent_session.session.save_and_upload_detached( - agent_session.session.config.session_dataset_repo - ) - return {"status": "ok"} - - -# --------------------------------------------------------------------------- -# Shared SSE helpers -# --------------------------------------------------------------------------- -_TERMINAL_EVENTS = { - "turn_complete", - "approval_required", - "error", - "interrupted", - "shutdown", -} -_SSE_KEEPALIVE_SECONDS = 15 - - -def _last_event_seq(request: Request) -> int: - raw = ( - request.headers.get("last-event-id") or request.query_params.get("after") or "0" - ) - try: - return max(0, int(raw)) - except (TypeError, ValueError): - return 0 - - -def _format_sse(msg: dict[str, Any]) -> str: - seq = msg.get("seq") - body = {"event_type": msg.get("event_type"), "data": msg.get("data") or {}} - if seq is not None: - body["seq"] = seq - return f"id: {seq}\ndata: {json.dumps(body)}\n\n" - return f"data: {json.dumps(body)}\n\n" - - -def _event_doc_to_msg(doc: dict[str, Any]) -> dict[str, Any]: - return { - "event_type": doc.get("event_type"), - "data": doc.get("data") or {}, - "seq": doc.get("seq"), - } - - -def _sse_response( - broadcaster, - event_queue, - sub_id, - *, - replay_events: list[dict[str, Any]] | None = None, - after_seq: int = 0, -) -> StreamingResponse: - """Build a StreamingResponse that drains *event_queue* as SSE, - sending keepalive comments every 15 s to prevent proxy timeouts.""" - - async def event_generator(): - try: - for doc in replay_events or []: - msg = _event_doc_to_msg(doc) - seq = msg.get("seq") - if isinstance(seq, int) and seq <= after_seq: - continue - yield _format_sse(msg) - if msg.get("event_type", "") in _TERMINAL_EVENTS: - return - - while True: - try: - msg = await asyncio.wait_for( - event_queue.get(), timeout=_SSE_KEEPALIVE_SECONDS - ) - except asyncio.TimeoutError: - # SSE comment β€” ignored by parsers, keeps connection alive - yield ": keepalive\n\n" - continue - event_type = msg.get("event_type", "") - yield _format_sse(msg) - if event_type in _TERMINAL_EVENTS: - break - finally: - broadcaster.unsubscribe(sub_id) - - return StreamingResponse( - event_generator(), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - }, - ) - - -@router.get("/events/{session_id}") -async def subscribe_events( - session_id: str, - request: Request, - user: dict = Depends(get_current_user), -) -> StreamingResponse: - """Subscribe to events for a running session without submitting new input. - - Used by the frontend to re-attach after a connection drop (e.g. screen - sleep). Returns 404 if the session isn't active or isn't processing. - """ - agent_session = await _check_session_access(session_id, user, request) - if not agent_session or not agent_session.is_active: - raise HTTPException(status_code=404, detail="Session not found or inactive") - - after_seq = _last_event_seq(request) - replay_events = await session_manager._store().load_events_after( - session_id, after_seq - ) - broadcaster = agent_session.broadcaster - sub_id, event_queue = broadcaster.subscribe() - return _sse_response( - broadcaster, - event_queue, - sub_id, - replay_events=replay_events, - after_seq=after_seq, - ) - - @router.post("/interrupt/{session_id}") async def interrupt_session( session_id: str, user: dict = Depends(get_current_user) ) -> dict: """Interrupt the current operation in a session.""" - await _check_session_access(session_id, user) + _check_session_access(session_id, user) success = await session_manager.interrupt(session_id) if not success: raise HTTPException(status_code=404, detail="Session not found or inactive") return {"status": "interrupted", "session_id": session_id} -@router.get("/session/{session_id}/messages") -async def get_session_messages( - session_id: str, user: dict = Depends(get_current_user) -) -> list[dict]: - """Return the session's message history from memory.""" - agent_session = await _check_session_access(session_id, user) - if not agent_session or not agent_session.is_active: - raise HTTPException(status_code=404, detail="Session not found or inactive") - return [ - msg.model_dump(mode="json") - for msg in agent_session.session.context_manager.items - ] - - @router.post("/undo/{session_id}") async def undo_session(session_id: str, user: dict = Depends(get_current_user)) -> dict: """Undo the last turn in a session.""" - await _check_session_access(session_id, user) + _check_session_access(session_id, user) success = await session_manager.undo(session_id) if not success: raise HTTPException(status_code=404, detail="Session not found or inactive") return {"status": "undo_requested", "session_id": session_id} -@router.post("/truncate/{session_id}") -async def truncate_session( - session_id: str, - request: Request, - user: dict = Depends(get_current_user), -) -> dict: - """Truncate conversation to before a specific user message.""" - # Check session ownership before parsing the request body so a 404 on a - # non-existent / non-owned session_id beats the 422 schema-validation error - # (otherwise the response leaks the required field name to non-owners). - await _check_session_access(session_id, user) - try: - body = TruncateRequest(**(await request.json())) - except ValidationError as exc: - # Re-raise as RequestValidationError so FastAPI returns its standard - # structured 422 schema (`{"detail": [{"type":..., "loc":..., ...}]}`) - # instead of a string-stringified Pydantic dump. - raise RequestValidationError(exc.errors()) from exc - except (json.JSONDecodeError, TypeError) as exc: - raise HTTPException(status_code=422, detail=str(exc)) - success = await session_manager.truncate(session_id, body.user_message_index) - if not success: - raise HTTPException( - status_code=404, - detail="Session not found, inactive, or message index out of range", - ) - return {"status": "truncated", "session_id": session_id} - - @router.post("/compact/{session_id}") async def compact_session( session_id: str, user: dict = Depends(get_current_user) ) -> dict: """Compact the context in a session.""" - await _check_session_access(session_id, user) + _check_session_access(session_id, user) success = await session_manager.compact(session_id) if not success: raise HTTPException(status_code=404, detail="Session not found or inactive") @@ -980,44 +323,82 @@ async def shutdown_session( session_id: str, user: dict = Depends(get_current_user) ) -> dict: """Shutdown a session.""" - await _check_session_access(session_id, user) + _check_session_access(session_id, user) success = await session_manager.shutdown_session(session_id) if not success: raise HTTPException(status_code=404, detail="Session not found or inactive") return {"status": "shutdown_requested", "session_id": session_id} -@router.post("/feedback/{session_id}") -async def submit_feedback( - session_id: str, - body: dict, - user: dict = Depends(get_current_user), -) -> dict: - """Attach a user feedback signal to a session's event log. +@router.websocket("/ws/{session_id}") +async def websocket_endpoint(websocket: WebSocket, session_id: str) -> None: + """WebSocket endpoint for real-time events. + + Authentication is done via: + - ?token= query parameter (for browsers that can't send WS headers) + - Cookie (automatic for same-origin connections) + - Dev mode bypass (when OAUTH_CLIENT_ID is not set) - Body: {rating: "up"|"down"|"outcome_success"|"outcome_fail", - turn_index?: int, comment?: str, message_id?: str} - Appended as a `feedback` event and saved with the session trajectory. + NOTE: We must accept() before close() so the browser receives our custom + close codes (4001, 4003, 4004). If we close() before accept(), Starlette + sends HTTP 403 and the browser only sees code 1006 (abnormal closure). """ - agent_session = await _check_session_access(session_id, user) + logger.info(f"WebSocket connection request for session {session_id}") + + # Authenticate the WebSocket connection + user = await get_ws_user(websocket) + if not user: + logger.warning( + f"WebSocket rejected: authentication failed for session {session_id}" + ) + await websocket.accept() + await websocket.close(code=4001, reason="Authentication required") + return + + # Verify session exists + info = session_manager.get_session_info(session_id) + if not info: + logger.warning(f"WebSocket rejected: session {session_id} not found") + await websocket.accept() + await websocket.close(code=4004, reason="Session not found") + return - rating = body.get("rating") - if rating not in {"up", "down", "outcome_success", "outcome_fail"}: - raise HTTPException(status_code=400, detail="invalid rating") + # Verify user owns the session + if not session_manager.verify_session_access(session_id, user["user_id"]): + logger.warning( + f"WebSocket rejected: user {user['user_id']} denied access to session {session_id}" + ) + await websocket.accept() + await websocket.close(code=4003, reason="Access denied") + return - from agent.core import telemetry + await ws_manager.connect(websocket, session_id) - await telemetry.record_feedback( - agent_session.session, - rating=rating, - turn_index=body.get("turn_index"), - message_id=body.get("message_id"), - comment=body.get("comment"), - ) - # Fire-and-forget save so feedback reaches the dataset even if the user - # closes the tab right after clicking. - if agent_session.session.config.save_sessions: - agent_session.session.save_and_upload_detached( - agent_session.session.config.session_dataset_repo + # Send "ready" immediately on WebSocket connection so the frontend + # knows the session is alive. The original ready event from _run_session + # fires before the WS is connected and is always lost. + try: + await websocket.send_json( + { + "event_type": "ready", + "data": {"message": "Agent initialized"}, + } ) - return {"status": "ok"} + except Exception as e: + logger.error(f"Failed to send ready event for session {session_id}: {e}") + + try: + while True: + # Keep connection alive, handle ping/pong + data = await websocket.receive_json() + + # Handle client messages (e.g., ping) + if data.get("type") == "ping": + await websocket.send_json({"type": "pong"}) + + except WebSocketDisconnect: + logger.info(f"WebSocket disconnected for session {session_id}") + except Exception as e: + logger.error(f"WebSocket error for session {session_id}: {e}") + finally: + ws_manager.disconnect(session_id) diff --git a/backend/routes/auth.py b/backend/routes/auth.py index d736deff1841dcc89594f2abb8728bc7306741f5..224febf4b926890eb58943e3103a985fe0ed4626 100644 --- a/backend/routes/auth.py +++ b/backend/routes/auth.py @@ -4,47 +4,28 @@ Handles the OAuth 2.0 authorization code flow with HF as provider. After successful auth, sets an HttpOnly cookie with the access token. """ -import logging import os import secrets import time from urllib.parse import urlencode import httpx -from dependencies import ( - AUTH_ENABLED, - OAUTH_SCOPE_COOKIE, - REQUIRED_OAUTH_SCOPES, - configured_oauth_scopes, - get_current_user, - oauth_scope_fingerprint, -) +from dependencies import AUTH_ENABLED, get_current_user from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.responses import RedirectResponse router = APIRouter(prefix="/auth", tags=["auth"]) -logger = logging.getLogger(__name__) # OAuth configuration from environment OAUTH_CLIENT_ID = os.environ.get("OAUTH_CLIENT_ID", "") OAUTH_CLIENT_SECRET = os.environ.get("OAUTH_CLIENT_SECRET", "") OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co") -OAUTH_SCOPES = configured_oauth_scopes() # In-memory OAuth state store with expiry (5 min TTL) _OAUTH_STATE_TTL = 300 oauth_states: dict[str, dict] = {} -def _missing_required_scopes(token_data: dict) -> set[str]: - raw_scopes = token_data.get("scope") - if not isinstance(raw_scopes, str) or not raw_scopes.strip(): - logger.debug("OAuth token response omitted a usable scope field") - return set() - granted = set(raw_scopes.split()) - return set(REQUIRED_OAUTH_SCOPES) - granted - - def _cleanup_expired_states() -> None: """Remove expired OAuth states to prevent memory growth.""" now = time.time() @@ -82,15 +63,16 @@ async def oauth_login(request: Request) -> RedirectResponse: "expires_at": time.time() + _OAUTH_STATE_TTL, } - # Build authorization URL. We no longer suggest a default `orgIds` β€” - # users no longer need to join the ML Agent Explorers org to use the - # app, and HF Jobs are billed per-namespace via credits. + # Build authorization URL params = { "client_id": OAUTH_CLIENT_ID, "redirect_uri": get_redirect_uri(request), - "scope": " ".join(OAUTH_SCOPES), + "scope": "openid profile read-repos write-repos contribute-repos manage-repos inference-api jobs write-discussions", "response_type": "code", "state": state, + "orgIds": os.environ.get( + "HF_OAUTH_ORG_ID", "698dbf55845d85df163175f1" + ), # ml-agent-explorers } auth_url = f"{OPENID_PROVIDER_URL}/oauth/authorize?{urlencode(params)}" @@ -138,15 +120,6 @@ async def oauth_callback( status_code=500, detail="Token exchange succeeded but no access_token was returned.", ) - missing_scopes = _missing_required_scopes(token_data) - if missing_scopes: - raise HTTPException( - status_code=403, - detail=( - "OAuth token is missing required scopes: " - + ", ".join(sorted(missing_scopes)) - ), - ) # Fetch user info (optional β€” failure is not fatal) async with httpx.AsyncClient() as client: @@ -169,16 +142,7 @@ async def oauth_callback( httponly=True, secure=is_production, # Secure flag only in production (HTTPS) samesite="lax", - max_age=3600 * 24 * 7, # 7 days - path="/", - ) - response.set_cookie( - key=OAUTH_SCOPE_COOKIE, - value=oauth_scope_fingerprint(OAUTH_SCOPES), - httponly=True, - secure=is_production, - samesite="lax", - max_age=3600 * 24 * 7, + max_age=3600 * 24, # 24 hours path="/", ) return response @@ -189,7 +153,6 @@ async def logout() -> RedirectResponse: """Log out the user by clearing the auth cookie.""" response = RedirectResponse(url="/") response.delete_cookie(key="hf_access_token", path="/") - response.delete_cookie(key=OAUTH_SCOPE_COOKIE, path="/") return response @@ -205,4 +168,4 @@ async def get_me(user: dict = Depends(get_current_user)) -> dict: Uses the shared auth dependency which handles cookie + Bearer token. """ - return {key: value for key, value in user.items() if not key.startswith("_")} + return user diff --git a/backend/session_manager.py b/backend/session_manager.py index 3c992c9c09d6334e23db7c2b06a2b6b1cd8e662e..03d9b2d9b8d706f1fa391f69e43e759d77246b86 100644 --- a/backend/session_manager.py +++ b/backend/session_manager.py @@ -1,25 +1,23 @@ """Session manager for handling multiple concurrent agent sessions.""" import asyncio -import json import logging -import os import uuid from dataclasses import dataclass, field from datetime import datetime from pathlib import Path from typing import Any, Optional +from websocket import manager as ws_manager + from agent.config import load_config from agent.core.agent_loop import process_submission from agent.core.session import Event, OpType, Session -from agent.core.session_persistence import get_session_store from agent.core.tools import ToolRouter -from agent.messaging.gateway import NotificationGateway # Get project root (parent of backend directory) PROJECT_ROOT = Path(__file__).parent.parent -DEFAULT_CONFIG_PATH = str(PROJECT_ROOT / "configs" / "frontend_agent_config.json") +DEFAULT_CONFIG_PATH = str(PROJECT_ROOT / "configs" / "main_agent_config.json") # These dataclasses match agent/main.py structure @@ -42,47 +40,6 @@ class Submission: logger = logging.getLogger(__name__) -class EventBroadcaster: - """Reads from the agent's event queue and fans out to SSE subscribers. - - Events that arrive when no subscribers are listening are discarded by - this in-memory fanout. Durable replay is handled by session_persistence. - """ - - def __init__(self, event_queue: asyncio.Queue): - self._source = event_queue - self._subscribers: dict[int, asyncio.Queue] = {} - self._counter = 0 - - def subscribe(self) -> tuple[int, asyncio.Queue]: - """Create a new subscriber. Returns (id, queue).""" - self._counter += 1 - sub_id = self._counter - q: asyncio.Queue = asyncio.Queue() - self._subscribers[sub_id] = q - return sub_id, q - - def unsubscribe(self, sub_id: int) -> None: - self._subscribers.pop(sub_id, None) - - async def run(self) -> None: - """Main loop β€” reads from source queue and broadcasts.""" - while True: - try: - event: Event = await self._source.get() - msg = { - "event_type": event.event_type, - "data": event.data, - "seq": event.seq, - } - for q in self._subscribers.values(): - await q.put(msg) - except asyncio.CancelledError: - break - except Exception as e: - logger.error(f"EventBroadcaster error: {e}") - - @dataclass class AgentSession: """Wrapper for an agent session with its associated resources.""" @@ -92,18 +49,10 @@ class AgentSession: tool_router: ToolRouter submission_queue: asyncio.Queue user_id: str = "dev" # Owner of this session - hf_username: str | None = None # HF namespace used for personal trace uploads hf_token: str | None = None # User's HF OAuth token for tool execution task: asyncio.Task | None = None created_at: datetime = field(default_factory=datetime.utcnow) is_active: bool = True - is_processing: bool = False # True while a submission is being executed - broadcaster: Any = None - title: str | None = None - # True once this session has been counted against the user's daily - # Claude quota. Guards double-counting when the user re-selects an - # Anthropic model mid-session. - claude_counted: bool = False class SessionCapacityError(Exception): @@ -115,15 +64,10 @@ class SessionCapacityError(Exception): # ── Capacity limits ───────────────────────────────────────────────── -# Sized for HF Spaces 8 vCPU / 32 GB RAM. -# Each session uses ~10-20 MB (context, tools, queues, task); 200 Γ— 20 MB -# = 4 GB worst case, leaving plenty of headroom for the Python runtime -# and per-request overhead. -MAX_SESSIONS: int = 200 +# Estimated for HF Spaces cpu-basic (2 vCPU, 16 GB RAM). +# Each session uses ~10-20 MB (context, tools, queues, task). +MAX_SESSIONS: int = 50 MAX_SESSIONS_PER_USER: int = 10 -DEFAULT_YOLO_COST_CAP_USD: float = 5.0 -SANDBOX_SHUTDOWN_CLEANUP_CONCURRENCY: int = 10 -SANDBOX_SHUTDOWN_CLEANUP_TIMEOUT_S: float = 60.0 class SessionManager: @@ -131,563 +75,18 @@ class SessionManager: def __init__(self, config_path: str | None = None) -> None: self.config = load_config(config_path or DEFAULT_CONFIG_PATH) - self.messaging_gateway = NotificationGateway(self.config.messaging) self.sessions: dict[str, AgentSession] = {} self._lock = asyncio.Lock() - self.persistence_store = None - - async def start(self) -> None: - """Start shared background resources.""" - self.persistence_store = get_session_store() - await self.persistence_store.init() - await self.messaging_gateway.start() - - async def close(self) -> None: - """Flush and close shared background resources.""" - await self._cleanup_all_sandboxes_on_close() - await self.messaging_gateway.close() - if self.persistence_store is not None: - await self.persistence_store.close() - - def _store(self): - if self.persistence_store is None: - self.persistence_store = get_session_store() - return self.persistence_store def _count_user_sessions(self, user_id: str) -> int: """Count active sessions owned by a specific user.""" return sum( - 1 for s in self.sessions.values() if s.user_id == user_id and s.is_active - ) - - def _create_session_sync( - self, - *, - session_id: str, - user_id: str, - hf_username: str | None, - hf_token: str | None, - model: str | None, - event_queue: asyncio.Queue, - notification_destinations: list[str] | None = None, - ) -> tuple[ToolRouter, Session]: - """Build blocking per-session resources in a worker thread.""" - import time as _time - - t0 = _time.monotonic() - tool_router = ToolRouter(self.config.mcpServers, hf_token=hf_token) - # Deep-copy config so each session's model switches independently β€” - # tab A picking GLM doesn't flip tab B off Claude. - session_config = self.config.model_copy(deep=True) - if model: - session_config.model_name = model - session = Session( - event_queue=event_queue, - config=session_config, - tool_router=tool_router, - hf_token=hf_token, - user_id=user_id, - hf_username=hf_username, - notification_gateway=self.messaging_gateway, - notification_destinations=notification_destinations or [], - session_id=session_id, - persistence_store=self._store(), - ) - t1 = _time.monotonic() - logger.info("Session initialized in %.2fs", t1 - t0) - return tool_router, session - - def _serialize_messages(self, session: Session) -> list[dict[str, Any]]: - return [msg.model_dump(mode="json") for msg in session.context_manager.items] - - def _serialize_pending_approval(self, session: Session) -> list[dict[str, Any]]: - pending = session.pending_approval or {} - tool_calls = pending.get("tool_calls") or [] - serialized: list[dict[str, Any]] = [] - for tc in tool_calls: - if hasattr(tc, "model_dump"): - serialized.append(tc.model_dump(mode="json")) - elif isinstance(tc, dict): - serialized.append(tc) - return serialized - - @staticmethod - def _pending_tools_for_api(session: Session) -> list[dict[str, Any]] | None: - pending = session.pending_approval or {} - tool_calls = pending.get("tool_calls") or [] - if not tool_calls: - return None - result: list[dict[str, Any]] = [] - for tc in tool_calls: - try: - args = json.loads(tc.function.arguments) - except (json.JSONDecodeError, AttributeError, TypeError): - args = {} - result.append( - { - "tool": getattr(tc.function, "name", None), - "tool_call_id": getattr(tc, "id", None), - "arguments": args, - } - ) - return result - - def _restore_pending_approval( - self, session: Session, pending_approval: list[dict[str, Any]] | None - ) -> None: - if not pending_approval: - session.pending_approval = None - return - from litellm import ChatCompletionMessageToolCall as ToolCall - - restored = [] - for raw in pending_approval: - try: - if "function" in raw: - restored.append(ToolCall(**raw)) - else: - restored.append( - ToolCall( - id=raw["tool_call_id"], - type="function", - function={ - "name": raw["tool"], - "arguments": json.dumps(raw.get("arguments") or {}), - }, - ) - ) - except Exception as e: - logger.warning("Dropping malformed pending approval: %s", e) - session.pending_approval = {"tool_calls": restored} if restored else None - - @staticmethod - def _pending_docs_for_api( - pending_approval: list[dict[str, Any]] | None, - ) -> list[dict[str, Any]] | None: - if not pending_approval: - return None - result: list[dict[str, Any]] = [] - for raw in pending_approval: - if "function" in raw: - function = raw.get("function") or {} - try: - args = json.loads(function.get("arguments") or "{}") - except (json.JSONDecodeError, TypeError): - args = {} - result.append( - { - "tool": function.get("name"), - "tool_call_id": raw.get("id"), - "arguments": args, - } - ) - elif {"tool", "tool_call_id"}.issubset(raw): - result.append( - { - "tool": raw.get("tool"), - "tool_call_id": raw.get("tool_call_id"), - "arguments": raw.get("arguments") or {}, - } - ) - return result or None - - @staticmethod - def _runtime_state(agent_session: AgentSession) -> str: - if agent_session.session.pending_approval: - return "waiting_approval" - if agent_session.is_processing: - return "processing" - if not agent_session.is_active: - return "ended" - return "idle" - - @staticmethod - def _auto_approval_summary(session: Session) -> dict[str, Any]: - if hasattr(session, "auto_approval_policy_summary"): - return session.auto_approval_policy_summary() - cap = getattr(session, "auto_approval_cost_cap_usd", None) - estimated = float( - getattr(session, "auto_approval_estimated_spend_usd", 0.0) or 0.0 - ) - remaining = None if cap is None else round(max(0.0, float(cap) - estimated), 4) - return { - "enabled": bool(getattr(session, "auto_approval_enabled", False)), - "cost_cap_usd": cap, - "estimated_spend_usd": round(estimated, 4), - "remaining_usd": remaining, - } - - async def _start_agent_session( - self, - *, - agent_session: AgentSession, - event_queue: asyncio.Queue, - tool_router: ToolRouter, - ) -> AgentSession: - async with self._lock: - existing = self.sessions.get(agent_session.session_id) - if existing: - return existing - self.sessions[agent_session.session_id] = agent_session - - task = asyncio.create_task( - self._run_session( - agent_session.session_id, - agent_session.submission_queue, - event_queue, - tool_router, - ) - ) - agent_session.task = task - return agent_session - - @staticmethod - def _start_cpu_sandbox_preload(agent_session: AgentSession) -> None: - """Kick off a best-effort cpu-basic sandbox for the session.""" - try: - from agent.tools.sandbox_tool import start_cpu_sandbox_preload - - start_cpu_sandbox_preload(agent_session.session) - except Exception as e: - logger.warning( - "Failed to start CPU sandbox preload for %s: %s", - agent_session.session_id, - e, - ) - - @staticmethod - def _can_access_session(agent_session: AgentSession, user_id: str) -> bool: - return ( - user_id == "dev" - or agent_session.user_id == "dev" - or agent_session.user_id == user_id - ) - - @staticmethod - def _update_hf_identity( - agent_session: AgentSession, - *, - hf_token: str | None, - hf_username: str | None, - ) -> None: - if hf_token: - agent_session.hf_token = hf_token - agent_session.session.hf_token = hf_token - if hf_username: - agent_session.hf_username = hf_username - agent_session.session.hf_username = hf_username - - @staticmethod - def _has_active_sandbox_preload(agent_session: AgentSession) -> bool: - task = getattr(agent_session.session, "sandbox_preload_task", None) - return bool(task and not task.done()) - - @staticmethod - def _preload_failed_for_missing_hf_token(agent_session: AgentSession) -> bool: - error = getattr(agent_session.session, "sandbox_preload_error", None) - return isinstance(error, str) and error.startswith("No HF token available") - - def _restart_cpu_preload_if_token_recovered( - self, - agent_session: AgentSession, - *, - preload_sandbox: bool, - ) -> None: - if not preload_sandbox: - return - session = agent_session.session - if getattr(session, "sandbox", None): - return - if self._has_active_sandbox_preload(agent_session): - return - if not (agent_session.hf_token or getattr(session, "hf_token", None)): - return - - if not self._preload_failed_for_missing_hf_token(agent_session): - return - - session.sandbox_preload_error = None - session.sandbox_preload_task = None - session.sandbox_preload_cancel_event = None - self._start_cpu_sandbox_preload(agent_session) - - async def _clear_persisted_sandbox_metadata(self, session_id: str) -> None: - try: - await self._store().update_session_fields( - session_id, - sandbox_space_id=None, - sandbox_hardware=None, - sandbox_owner=None, - sandbox_created_at=None, - sandbox_status="destroyed", - ) - except Exception as e: - logger.warning("Failed to clear sandbox metadata for %s: %s", session_id, e) - - async def _cleanup_persisted_sandbox( - self, - session_id: str, - metadata: dict[str, Any], - *, - hf_token: str | None, - ) -> None: - """Delete a sandbox recorded by a previous backend process, if any.""" - space_id = metadata.get("sandbox_space_id") - if not isinstance(space_id, str) or not space_id: - return - if metadata.get("sandbox_status") == "destroyed": - return - - tokens: list[tuple[str, str]] = [] - seen: set[str] = set() - for label, token in ( - ("user", hf_token), - ("admin", os.environ.get("HF_ADMIN_TOKEN")), - ): - if token and token not in seen: - tokens.append((label, token)) - seen.add(token) - - if not tokens: - logger.warning( - "Cannot clean persisted sandbox %s for session %s: no HF token available", - space_id, - session_id, - ) - return - - last_err: Exception | None = None - for label, token in tokens: - try: - from huggingface_hub import HfApi - - api = HfApi(token=token) - await asyncio.to_thread( - api.delete_repo, - repo_id=space_id, - repo_type="space", - ) - logger.info( - "Deleted persisted sandbox %s for session %s with %s token", - space_id, - session_id, - label, - ) - await self._clear_persisted_sandbox_metadata(session_id) - return - except Exception as e: - status_code = getattr(getattr(e, "response", None), "status_code", None) - if status_code == 404: - logger.info( - "Persisted sandbox %s for session %s is already gone", - space_id, - session_id, - ) - await self._clear_persisted_sandbox_metadata(session_id) - return - last_err = e - - logger.warning( - "Failed to delete persisted sandbox %s for session %s: %s", - space_id, - session_id, - last_err, - ) - - async def persist_session_snapshot( - self, - agent_session: AgentSession, - *, - runtime_state: str | None = None, - status: str = "active", - ) -> None: - """Persist the current runtime context snapshot.""" - store = self._store() - if not getattr(store, "enabled", False): - return - try: - await store.save_snapshot( - session_id=agent_session.session_id, - user_id=agent_session.user_id, - model=agent_session.session.config.model_name, - title=agent_session.title, - messages=self._serialize_messages(agent_session.session), - runtime_state=runtime_state or self._runtime_state(agent_session), - status=status, - turn_count=agent_session.session.turn_count, - pending_approval=self._serialize_pending_approval( - agent_session.session - ), - claude_counted=agent_session.claude_counted, - created_at=agent_session.created_at, - notification_destinations=list( - agent_session.session.notification_destinations - ), - auto_approval_enabled=bool( - getattr(agent_session.session, "auto_approval_enabled", False) - ), - auto_approval_cost_cap_usd=getattr( - agent_session.session, "auto_approval_cost_cap_usd", None - ), - auto_approval_estimated_spend_usd=float( - getattr( - agent_session.session, - "auto_approval_estimated_spend_usd", - 0.0, - ) - or 0.0 - ), - ) - except Exception as e: - logger.warning( - "Failed to persist snapshot for %s: %s", - agent_session.session_id, - e, - ) - - async def ensure_session_loaded( - self, - session_id: str, - user_id: str, - hf_token: str | None = None, - hf_username: str | None = None, - preload_sandbox: bool = True, - ) -> AgentSession | None: - """Return a live runtime session, lazily restoring it from Mongo.""" - async with self._lock: - existing = self.sessions.get(session_id) - if existing: - if self._can_access_session(existing, user_id): - self._update_hf_identity( - existing, - hf_token=hf_token, - hf_username=hf_username, - ) - self._restart_cpu_preload_if_token_recovered( - existing, - preload_sandbox=preload_sandbox, - ) - return existing - return None - - store = self._store() - loaded = await store.load_session(session_id) - if not loaded: - return None - - async with self._lock: - existing = self.sessions.get(session_id) - if existing: - if self._can_access_session(existing, user_id): - self._update_hf_identity( - existing, - hf_token=hf_token, - hf_username=hf_username, - ) - self._restart_cpu_preload_if_token_recovered( - existing, - preload_sandbox=preload_sandbox, - ) - return existing - return None - - meta = loaded.get("metadata") or {} - owner = str(meta.get("user_id") or "") - if user_id != "dev" and owner != "dev" and owner != user_id: - return None - - await self._cleanup_persisted_sandbox( - session_id, - meta, - hf_token=hf_token, - ) - - from litellm import Message - - model = meta.get("model") or self.config.model_name - event_queue: asyncio.Queue = asyncio.Queue() - submission_queue: asyncio.Queue = asyncio.Queue() - tool_router, session = await asyncio.to_thread( - self._create_session_sync, - session_id=session_id, - user_id=owner or user_id, - hf_username=hf_username, - hf_token=hf_token, - model=model, - event_queue=event_queue, - notification_destinations=meta.get("notification_destinations") or [], + 1 + for s in self.sessions.values() + if s.user_id == user_id and s.is_active ) - restored_messages: list[Message] = [] - for raw in loaded.get("messages") or []: - if not isinstance(raw, dict) or raw.get("role") == "system": - continue - try: - restored_messages.append(Message.model_validate(raw)) - except Exception as e: - logger.warning("Dropping malformed restored message: %s", e) - if restored_messages: - # Keep the freshly-rendered system prompt, then attach the durable - # non-system context so tools/date/user context stay current. - session.context_manager.items = [ - session.context_manager.items[0], - *restored_messages, - ] - - self._restore_pending_approval(session, meta.get("pending_approval") or []) - session.turn_count = int(meta.get("turn_count") or 0) - session.auto_approval_enabled = bool(meta.get("auto_approval_enabled", False)) - raw_cap = meta.get("auto_approval_cost_cap_usd") - session.auto_approval_cost_cap_usd = ( - float(raw_cap) if isinstance(raw_cap, int | float) else None - ) - session.auto_approval_estimated_spend_usd = float( - meta.get("auto_approval_estimated_spend_usd") or 0.0 - ) - - created_at = meta.get("created_at") - if not isinstance(created_at, datetime): - created_at = datetime.utcnow() - - agent_session = AgentSession( - session_id=session_id, - session=session, - tool_router=tool_router, - submission_queue=submission_queue, - user_id=owner or user_id, - hf_username=hf_username, - hf_token=hf_token, - created_at=created_at, - is_active=True, - is_processing=False, - claude_counted=bool(meta.get("claude_counted")), - title=meta.get("title"), - ) - started = await self._start_agent_session( - agent_session=agent_session, - event_queue=event_queue, - tool_router=tool_router, - ) - if started is not agent_session: - self._update_hf_identity( - started, - hf_token=hf_token, - hf_username=hf_username, - ) - return started - if preload_sandbox: - self._start_cpu_sandbox_preload(agent_session) - logger.info("Restored session %s for user %s", session_id, owner or user_id) - return agent_session - - async def create_session( - self, - user_id: str = "dev", - hf_username: str | None = None, - hf_token: str | None = None, - model: str | None = None, - is_pro: bool | None = None, - ) -> str: + async def create_session(self, user_id: str = "dev", hf_token: str | None = None) -> str: """Create a new agent session and return its ID. Session() and ToolRouter() constructors contain blocking I/O @@ -696,11 +95,6 @@ class SessionManager: Args: user_id: The ID of the user who owns this session. - hf_username: The HF username/namespace used for personal trace uploads. - hf_token: The user's HF OAuth token, stored for tool execution. - model: Optional model override. When set, replaces ``model_name`` - on the per-session config clone. None falls back to the - config default. Raises: SessionCapacityError: If the server or user has reached the @@ -731,15 +125,22 @@ class SessionManager: event_queue: asyncio.Queue = asyncio.Queue() # Run blocking constructors in a thread to keep the event loop responsive. - tool_router, session = await asyncio.to_thread( - self._create_session_sync, - session_id=session_id, - user_id=user_id, - hf_username=hf_username, - hf_token=hf_token, - model=model, - event_queue=event_queue, - ) + # Without this, Session.__init__ β†’ ContextManager β†’ litellm.get_max_tokens() + # blocks all HTTP/WebSocket handling. + import time as _time + + def _create_session_sync(): + t0 = _time.monotonic() + tool_router = ToolRouter(self.config.mcpServers) + session = Session(event_queue, config=self.config, tool_router=tool_router) + t1 = _time.monotonic() + logger.info(f"Session initialized in {t1 - t0:.2f}s") + return tool_router, session + + tool_router, session = await asyncio.to_thread(_create_session_sync) + + # Store user's HF token on the session so tools can use it + session.hf_token = hf_token # Create wrapper agent_session = AgentSession( @@ -748,165 +149,21 @@ class SessionManager: tool_router=tool_router, submission_queue=submission_queue, user_id=user_id, - hf_username=hf_username, hf_token=hf_token, ) - await self._start_agent_session( - agent_session=agent_session, - event_queue=event_queue, - tool_router=tool_router, - ) - await self.persist_session_snapshot(agent_session, runtime_state="idle") - self._start_cpu_sandbox_preload(agent_session) + async with self._lock: + self.sessions[session_id] = agent_session - if is_pro is not None and user_id and user_id != "dev": - await self._track_pro_status(agent_session, is_pro=is_pro) + # Start the agent loop task + task = asyncio.create_task( + self._run_session(session_id, submission_queue, event_queue, tool_router) + ) + agent_session.task = task logger.info(f"Created session {session_id} for user {user_id}") return session_id - async def _track_pro_status( - self, agent_session: AgentSession, *, is_pro: bool - ) -> None: - """Update Mongo per-user Pro state and emit a one-shot conversion - event if the store reports a freeβ†’Pro transition. Best-effort: any - Mongo failure is swallowed so we never fail session creation on - telemetry.""" - store = self._store() - if not getattr(store, "enabled", False): - return - try: - result = await store.mark_pro_seen(agent_session.user_id, is_pro=is_pro) - except Exception as e: - logger.debug("mark_pro_seen failed: %s", e) - return - if not result or not result.get("converted"): - return - try: - from agent.core import telemetry - - await telemetry.record_pro_conversion( - agent_session.session, - first_seen_at=result.get("first_seen_at"), - ) - except Exception as e: - logger.debug("record_pro_conversion failed: %s", e) - - async def seed_from_summary(self, session_id: str, messages: list[dict]) -> int: - """Rehydrate a session from cached prior messages via summarization. - - Runs the standard summarization prompt (same one compaction uses) - over the provided messages, then seeds the new session's context - with that summary. Tool-call pairing concerns disappear because the - output is plain text. Returns the number of messages summarized. - """ - from litellm import Message - - from agent.context_manager.manager import _RESTORE_PROMPT, summarize_messages - - agent_session = self.sessions.get(session_id) - if not agent_session: - raise ValueError(f"Session {session_id} not found") - - # Parse into Message objects, tolerating malformed entries. - parsed: list[Message] = [] - for raw in messages: - if raw.get("role") == "system": - continue # the new session has its own system prompt - try: - parsed.append(Message.model_validate(raw)) - except Exception as e: - logger.warning("Dropping malformed message during seed: %s", e) - - if not parsed: - return 0 - - session = agent_session.session - # Pass the real tool specs so the summarizer sees what the agent - # actually has β€” otherwise Anthropic's modify_params injects a - # dummy tool and the summarizer editorializes that the original - # tool calls were fabricated. - tool_specs = None - try: - tool_specs = agent_session.tool_router.get_tool_specs_for_llm() - except Exception: - pass - try: - summary, _ = await summarize_messages( - parsed, - model_name=session.config.model_name, - hf_token=session.hf_token, - max_tokens=4000, - prompt=_RESTORE_PROMPT, - tool_specs=tool_specs, - session=session, - kind="restore", - ) - except Exception as e: - logger.error("Summary call failed during seed: %s", e) - raise - - seed = Message( - role="user", - content=( - "[SYSTEM: Your prior memory of this conversation β€” written " - "in your own voice right before restart. Continue from here.]\n\n" - + (summary or "(no summary returned)") - ), - ) - session.context_manager.items.append(seed) - await self.persist_session_snapshot(agent_session, runtime_state="idle") - return len(parsed) - - @staticmethod - async def _cleanup_sandbox(session: Session) -> None: - """Delete the sandbox Space if one was created for this session. - - Retries on transient failures (HF API 5xx, rate-limit, network blips) - with exponential backoff. A single missed delete = a permanently - orphaned Space, so the cost of an extra retry beats the alternative. - """ - from agent.tools.sandbox_tool import teardown_session_sandbox - - await teardown_session_sandbox(session) - - async def _cleanup_all_sandboxes_on_close(self) -> None: - """Best-effort sandbox cleanup for graceful backend shutdown.""" - async with self._lock: - agent_sessions = list(self.sessions.values()) - if not agent_sessions: - return - - semaphore = asyncio.Semaphore(SANDBOX_SHUTDOWN_CLEANUP_CONCURRENCY) - - async def _cleanup_one(agent_session: AgentSession) -> None: - async with semaphore: - try: - await self._cleanup_sandbox(agent_session.session) - except Exception as e: - logger.warning( - "Shutdown sandbox cleanup failed for %s: %s", - agent_session.session_id, - e, - ) - - tasks = [ - asyncio.create_task(_cleanup_one(agent_session)) - for agent_session in agent_sessions - ] - try: - await asyncio.wait_for( - asyncio.gather(*tasks, return_exceptions=True), - timeout=SANDBOX_SHUTDOWN_CLEANUP_TIMEOUT_S, - ) - except asyncio.TimeoutError: - logger.warning( - "Timed out after %.0fs cleaning up sandboxes on shutdown; " - "orphan sweeper will handle any stragglers", - SANDBOX_SHUTDOWN_CLEANUP_TIMEOUT_S, - ) - async def _run_session( self, session_id: str, @@ -914,7 +171,7 @@ class SessionManager: event_queue: asyncio.Queue, tool_router: ToolRouter, ) -> None: - """Run the agent loop for a session and broadcast events via EventBroadcaster.""" + """Run the agent loop for a session and forward events to WebSocket.""" agent_session = self.sessions.get(session_id) if not agent_session: logger.error(f"Session {session_id} not found") @@ -922,10 +179,10 @@ class SessionManager: session = agent_session.session - # Start event broadcaster task - broadcaster = EventBroadcaster(event_queue) - agent_session.broadcaster = broadcaster - broadcast_task = asyncio.create_task(broadcaster.run()) + # Start event forwarder task + event_forwarder = asyncio.create_task( + self._forward_events(session_id, event_queue) + ) try: async with tool_router: @@ -940,14 +197,7 @@ class SessionManager: submission = await asyncio.wait_for( submission_queue.get(), timeout=1.0 ) - agent_session.is_processing = True - try: - should_continue = await process_submission( - session, submission - ) - finally: - agent_session.is_processing = False - await self.persist_session_snapshot(agent_session) + should_continue = await process_submission(session, submission) if not should_continue: break except asyncio.TimeoutError: @@ -962,36 +212,31 @@ class SessionManager: ) finally: - broadcast_task.cancel() + event_forwarder.cancel() try: - await broadcast_task + await event_forwarder except asyncio.CancelledError: pass - await self._cleanup_sandbox(session) - - # Final-flush: always save on session death so we capture ended - # sessions even if the client disconnects without /shutdown. - # Idempotent via session_id key; detached subprocess. - if session.config.save_sessions: - try: - session.save_and_upload_detached( - session.config.session_dataset_repo - ) - except Exception as e: - logger.warning(f"Final-flush failed for {session_id}: {e}") - async with self._lock: if session_id in self.sessions: self.sessions[session_id].is_active = False - await self.persist_session_snapshot( - self.sessions[session_id], - runtime_state="ended", - status="ended", - ) logger.info(f"Session {session_id} ended") + async def _forward_events( + self, session_id: str, event_queue: asyncio.Queue + ) -> None: + """Forward events from the agent to the WebSocket.""" + while True: + try: + event: Event = await event_queue.get() + await ws_manager.send_event(session_id, event.event_type, event.data) + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Error forwarding event for {session_id}: {e}") + async def submit(self, session_id: str, operation: Operation) -> bool: """Submit an operation to a session.""" async with self._lock: @@ -1020,31 +265,15 @@ class SessionManager: return await self.submit(session_id, operation) async def interrupt(self, session_id: str) -> bool: - """Interrupt a session by signalling cancellation directly (bypasses queue).""" - agent_session = self.sessions.get(session_id) - if not agent_session or not agent_session.is_active: - return False - agent_session.session.cancel() - return True + """Interrupt a session.""" + operation = Operation(op_type=OpType.INTERRUPT) + return await self.submit(session_id, operation) async def undo(self, session_id: str) -> bool: """Undo last turn in a session.""" operation = Operation(op_type=OpType.UNDO) return await self.submit(session_id, operation) - async def truncate(self, session_id: str, user_message_index: int) -> bool: - """Truncate conversation to before a specific user message (direct, no queue).""" - async with self._lock: - agent_session = self.sessions.get(session_id) - if not agent_session or not agent_session.is_active: - return False - success = agent_session.session.context_manager.truncate_to_user_message( - user_message_index - ) - if success: - await self.persist_session_snapshot(agent_session, runtime_state="idle") - return success - async def compact(self, session_id: str) -> bool: """Compact context in a session.""" operation = Operation(op_type=OpType.COMPACT) @@ -1068,18 +297,12 @@ class SessionManager: return success async def delete_session(self, session_id: str) -> bool: - """Soft-delete a session and stop its runtime resources.""" + """Delete a session entirely.""" async with self._lock: agent_session = self.sessions.pop(session_id, None) if not agent_session: - await self._store().soft_delete_session(session_id) - return True - - await self._store().soft_delete_session(session_id) - - # Clean up sandbox Space before cancelling the task - await self._cleanup_sandbox(agent_session.session) + return False # Cancel the task if running if agent_session.task and not agent_session.task.done(): @@ -1091,68 +314,6 @@ class SessionManager: return True - async def teardown_sandbox(self, session_id: str) -> bool: - """Delete only this session's sandbox runtime, preserving chat state.""" - async with self._lock: - agent_session = self.sessions.get(session_id) - - if not agent_session or not agent_session.is_active: - return False - - await self._cleanup_sandbox(agent_session.session) - await self.persist_session_snapshot(agent_session, runtime_state="idle") - return True - - async def update_session_title(self, session_id: str, title: str | None) -> None: - """Persist a user-visible title for sidebar rehydration.""" - agent_session = self.sessions.get(session_id) - if agent_session: - agent_session.title = title - await self._store().update_session_fields(session_id, title=title) - - async def update_session_model(self, session_id: str, model_id: str) -> bool: - agent_session = self.sessions.get(session_id) - if not agent_session or not agent_session.is_active: - return False - agent_session.session.update_model(model_id) - await self.persist_session_snapshot(agent_session, runtime_state="idle") - return True - - async def update_session_auto_approval( - self, - session_id: str, - *, - enabled: bool, - cost_cap_usd: float | None, - cap_provided: bool = False, - ) -> dict[str, Any]: - agent_session = self.sessions.get(session_id) - if not agent_session or not agent_session.is_active: - raise ValueError("Session not found or inactive") - - session = agent_session.session - if enabled: - if not cap_provided and cost_cap_usd is None: - cost_cap_usd = getattr(session, "auto_approval_cost_cap_usd", None) - if cost_cap_usd is None: - cost_cap_usd = DEFAULT_YOLO_COST_CAP_USD - elif cost_cap_usd is None: - cost_cap_usd = DEFAULT_YOLO_COST_CAP_USD - else: - if not cap_provided: - cost_cap_usd = getattr(session, "auto_approval_cost_cap_usd", None) - - if hasattr(session, "set_auto_approval_policy"): - session.set_auto_approval_policy( - enabled=enabled, - cost_cap_usd=cost_cap_usd, - ) - else: - session.auto_approval_enabled = bool(enabled) - session.auto_approval_cost_cap_usd = cost_cap_usd - await self.persist_session_snapshot(agent_session) - return self._auto_approval_summary(session) - def get_session_owner(self, session_id: str) -> str | None: """Get the user_id that owns a session, or None if session doesn't exist.""" agent_session = self.sessions.get(session_id) @@ -1180,117 +341,22 @@ class SessionManager: if not agent_session: return None - pending_approval = self._pending_tools_for_api(agent_session.session) - return { "session_id": session_id, "created_at": agent_session.created_at.isoformat(), "is_active": agent_session.is_active, - "is_processing": agent_session.is_processing, "message_count": len(agent_session.session.context_manager.items), "user_id": agent_session.user_id, - "pending_approval": pending_approval, - "model": agent_session.session.config.model_name, - "title": agent_session.title, - "notification_destinations": list( - agent_session.session.notification_destinations - ), - "auto_approval": self._auto_approval_summary(agent_session.session), } - def set_notification_destinations( - self, session_id: str, destinations: list[str] - ) -> list[str]: - """Replace the session's opted-in auto-notification destinations.""" - agent_session = self.sessions.get(session_id) - if not agent_session or not agent_session.is_active: - raise ValueError("Session not found or inactive") - - normalized: list[str] = [] - seen: set[str] = set() - for raw_name in destinations: - name = raw_name.strip() - if not name: - raise ValueError("Destination names must not be empty") - destination = self.config.messaging.get_destination(name) - if destination is None: - raise ValueError(f"Unknown destination '{name}'") - if not destination.allow_auto_events: - raise ValueError(f"Destination '{name}' is not enabled for auto events") - if name not in seen: - normalized.append(name) - seen.add(name) - - agent_session.session.set_notification_destinations(normalized) - return normalized - - async def list_sessions(self, user_id: str | None = None) -> list[dict[str, Any]]: + def list_sessions(self, user_id: str | None = None) -> list[dict[str, Any]]: """List sessions, optionally filtered by user. Args: user_id: If provided, only return sessions owned by this user. If "dev", return all sessions (dev mode). """ - results: list[dict[str, Any]] = [] - store = self._store() - if getattr(store, "enabled", False): - for row in await store.list_sessions(user_id or "dev"): - sid = row.get("session_id") or row.get("_id") - if not sid: - continue - runtime_info = self.get_session_info(str(sid)) - if runtime_info: - results.append(runtime_info) - continue - created_at = row.get("created_at") - if isinstance(created_at, datetime): - created_at_str = created_at.isoformat() - else: - created_at_str = str(created_at or datetime.utcnow().isoformat()) - pending = self._pending_docs_for_api(row.get("pending_approval") or []) - results.append( - { - "session_id": str(sid), - "created_at": created_at_str, - "is_active": row.get("status") != "ended", - "is_processing": row.get("runtime_state") == "processing", - "message_count": int(row.get("message_count") or 0), - "user_id": row.get("user_id") or "dev", - "pending_approval": pending or None, - "model": row.get("model"), - "title": row.get("title"), - "notification_destinations": row.get( - "notification_destinations" - ) - or [], - "auto_approval": { - "enabled": bool(row.get("auto_approval_enabled", False)), - "cost_cap_usd": row.get("auto_approval_cost_cap_usd"), - "estimated_spend_usd": float( - row.get("auto_approval_estimated_spend_usd") or 0.0 - ), - "remaining_usd": ( - None - if row.get("auto_approval_cost_cap_usd") is None - else round( - max( - 0.0, - float( - row.get("auto_approval_cost_cap_usd") or 0.0 - ) - - float( - row.get("auto_approval_estimated_spend_usd") - or 0.0 - ), - ), - 4, - ) - ), - }, - } - ) - return results - + results = [] for sid in self.sessions: info = self.get_session_info(sid) if not info: diff --git a/backend/start.sh b/backend/start.sh deleted file mode 100755 index 72b35198f89ef41a73c3119843d2ac21a9cf0a42..0000000000000000000000000000000000000000 --- a/backend/start.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash -# Entrypoint for HF Spaces dev mode compatibility. -# Dev mode spawns CMD multiple times simultaneously on restart. -# Only the first instance can bind port 7860 β€” the rest must exit -# with code 0 so the dev mode daemon doesn't mark the app as crashed. - -# Run uvicorn; if it fails due to port conflict, exit cleanly. -uvicorn main:app --host 0.0.0.0 --port 7860 -EXIT_CODE=$? - -if [ $EXIT_CODE -ne 0 ]; then - # Check if this was a port-in-use failure (another instance already running) - echo "uvicorn exited with code $EXIT_CODE, exiting gracefully." - exit 0 -fi diff --git a/backend/user_quotas.py b/backend/user_quotas.py deleted file mode 100644 index 4da4a8d91b755b64e165f82ba8940b0e1b19ae38..0000000000000000000000000000000000000000 --- a/backend/user_quotas.py +++ /dev/null @@ -1,128 +0,0 @@ -"""Daily quota for premium model session creations. - -Tracks per-user premium model session starts against a daily cap derived from -the user's HF plan. MongoDB is the source of truth when configured; the -in-process dict remains the fallback for local/dev/test runs. - -The public names still say ``claude`` because this quota bucket originally -only covered Claude and the persisted session field uses that name. - -Unit: session *creations*, not messages. A user who sends with a premium model -in a new session consumes one quota point; switching an already-counted session -back to a premium model doesn't (`AgentSession.claude_counted` guards that). - -Cap tiers: - free user β†’ CLAUDE_FREE_DAILY (1) - pro user β†’ CLAUDE_PRO_DAILY (20) -""" - -import asyncio -import os -from datetime import UTC, datetime - -from agent.core.session_persistence import ( - NoopSessionStore, - get_session_store, - _reset_store_for_tests, -) - -CLAUDE_FREE_DAILY: int = int(os.environ.get("CLAUDE_FREE_DAILY", "1")) -CLAUDE_PRO_DAILY: int = int(os.environ.get("CLAUDE_PRO_DAILY", "20")) - -# user_id -> (day_utc_iso, count_for_that_day) -_claude_counts: dict[str, tuple[str, int]] = {} -_lock = asyncio.Lock() - - -def _today() -> str: - return datetime.now(UTC).date().isoformat() - - -def daily_cap_for(plan: str | None) -> int: - """Return the daily Claude-session cap for the given plan.""" - return CLAUDE_PRO_DAILY if plan == "pro" else CLAUDE_FREE_DAILY - - -async def get_claude_used_today(user_id: str) -> int: - """Return today's Claude session count for the user (0 if none / stale day).""" - store = get_session_store() - if getattr(store, "enabled", False): - db_count = await store.get_quota(user_id, _today()) - return db_count or 0 - - async with _lock: - entry = _claude_counts.get(user_id) - if entry is None: - return 0 - day, count = entry - if day != _today(): - # Stale day β€” drop the entry so the first increment starts fresh. - _claude_counts.pop(user_id, None) - return 0 - return count - - -async def increment_claude(user_id: str) -> int: - """Bump today's Claude session count for the user. Returns the new value.""" - store = get_session_store() - if getattr(store, "enabled", False): - db_count = await store.try_increment_quota(user_id, _today(), cap=10**9) - return db_count or 0 - - async with _lock: - today = _today() - day, count = _claude_counts.get(user_id, (today, 0)) - if day != today: - count = 0 - count += 1 - _claude_counts[user_id] = (today, count) - return count - - -async def try_increment_claude(user_id: str, cap: int) -> int | None: - """Atomically bump today's count if below *cap*. - - Returns the new count, or None when the user is already at the cap. - """ - store = get_session_store() - if getattr(store, "enabled", False): - return await store.try_increment_quota(user_id, _today(), cap) - - async with _lock: - today = _today() - day, count = _claude_counts.get(user_id, (today, 0)) - if day != today: - count = 0 - if count >= cap: - return None - count += 1 - _claude_counts[user_id] = (today, count) - return count - - -async def refund_claude(user_id: str) -> None: - """Decrement today's count β€” used when session creation fails after a successful gate.""" - store = get_session_store() - if getattr(store, "enabled", False): - await store.refund_quota(user_id, _today()) - return - - async with _lock: - entry = _claude_counts.get(user_id) - if entry is None: - return - day, count = entry - if day != _today(): - _claude_counts.pop(user_id, None) - return - new_count = max(0, count - 1) - if new_count == 0: - _claude_counts.pop(user_id, None) - else: - _claude_counts[user_id] = (day, new_count) - - -def _reset_for_tests() -> None: - """Test-only: clear the in-memory store.""" - _claude_counts.clear() - _reset_store_for_tests(NoopSessionStore()) diff --git a/backend/websocket.py b/backend/websocket.py new file mode 100644 index 0000000000000000000000000000000000000000..bc09ed747b164bbe99ddebd6d35a36ae6a2faad8 --- /dev/null +++ b/backend/websocket.py @@ -0,0 +1,62 @@ +"""WebSocket connection manager for real-time communication.""" + +import logging +from typing import Any + +from fastapi import WebSocket + +logger = logging.getLogger(__name__) + + +class ConnectionManager: + """Manages WebSocket connections for multiple sessions.""" + + def __init__(self) -> None: + # session_id -> WebSocket + self.active_connections: dict[str, WebSocket] = {} + + async def connect(self, websocket: WebSocket, session_id: str) -> None: + """Accept a WebSocket connection and register it.""" + logger.info(f"Attempting to accept WebSocket for session {session_id}") + await websocket.accept() + self.active_connections[session_id] = websocket + logger.info(f"WebSocket connected and registered for session {session_id}") + + def disconnect(self, session_id: str) -> None: + """Remove a WebSocket connection.""" + if session_id in self.active_connections: + del self.active_connections[session_id] + logger.info(f"WebSocket disconnected for session {session_id}") + + async def send_event( + self, session_id: str, event_type: str, data: dict[str, Any] | None = None + ) -> None: + """Send an event to a specific session's WebSocket.""" + if session_id not in self.active_connections: + logger.warning(f"No active connection for session {session_id}") + return + + message = {"event_type": event_type} + if data is not None: + message["data"] = data + + try: + await self.active_connections[session_id].send_json(message) + except Exception as e: + logger.error(f"Error sending to session {session_id}: {e}") + self.disconnect(session_id) + + async def broadcast( + self, event_type: str, data: dict[str, Any] | None = None + ) -> None: + """Broadcast an event to all connected sessions.""" + for session_id in list(self.active_connections.keys()): + await self.send_event(session_id, event_type, data) + + def is_connected(self, session_id: str) -> bool: + """Check if a session has an active WebSocket connection.""" + return session_id in self.active_connections + + +# Global connection manager instance +manager = ConnectionManager() diff --git a/configs/cli_agent_config.json b/configs/cli_agent_config.json deleted file mode 100644 index ed247998688a102f143b22b1a76538d0aa02520b..0000000000000000000000000000000000000000 --- a/configs/cli_agent_config.json +++ /dev/null @@ -1,21 +0,0 @@ -{ - "model_name": "anthropic/claude-opus-4-6", - "save_sessions": true, - "session_dataset_repo": "smolagents/ml-intern-sessions", - "share_traces": true, - "personal_trace_repo_template": "{hf_user}/ml-intern-sessions", - "yolo_mode": false, - "confirm_cpu_jobs": true, - "auto_file_upload": true, - "messaging": { - "enabled": false, - "auto_event_types": ["approval_required", "error", "turn_complete"], - "destinations": {} - }, - "mcpServers": { - "hf-mcp-server": { - "transport": "http", - "url": "https://huggingface.co/mcp?login" - } - } -} diff --git a/configs/frontend_agent_config.json b/configs/frontend_agent_config.json deleted file mode 100644 index c674a223b018967b7ab4482f3228b0b58d054dd3..0000000000000000000000000000000000000000 --- a/configs/frontend_agent_config.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "model_name": "${ML_INTERN_CLAUDE_MODEL_ID:-bedrock/us.anthropic.claude-opus-4-6-v1}", - "save_sessions": true, - "session_dataset_repo": "smolagents/ml-intern-sessions", - "share_traces": true, - "personal_trace_repo_template": "{hf_user}/ml-intern-sessions", - "yolo_mode": false, - "confirm_cpu_jobs": true, - "auto_file_upload": true, - "mcpServers": { - "hf-mcp-server": { - "transport": "http", - "url": "https://huggingface.co/mcp?login" - } - } -} diff --git a/configs/main_agent_config.json b/configs/main_agent_config.json new file mode 100644 index 0000000000000000000000000000000000000000..18a414b3bfced18b47d2737579e3db9c9d137cd6 --- /dev/null +++ b/configs/main_agent_config.json @@ -0,0 +1,17 @@ +{ + "model_name": "anthropic/claude-opus-4-5-20251101", + "save_sessions": true, + "session_dataset_repo": "akseljoonas/hf-agent-sessions", + "yolo_mode": false, + "confirm_cpu_jobs": false, + "auto_file_upload": true, + "mcpServers": { + "hf-mcp-server": { + "transport": "http", + "url": "https://huggingface.co/mcp?login", + "headers": { + "Authorization": "Bearer ${HF_TOKEN}" + } + } + } +} diff --git a/eval/README.md b/eval/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b1342632a8079eef7038c39095caab6e6708a86a --- /dev/null +++ b/eval/README.md @@ -0,0 +1,100 @@ +# HF-Agent Eval + +Rubric-based evaluation pipeline implementing [Rubrics as Rewards](https://arxiv.org/abs/2507.17746) paper (RaR-Explicit formula). + +## Components + +| Component | Purpose | Long Term Goal | +|-----------|---------|----------------| +| **`generate_rubrics.py`** | Generates instance-specific evaluation criteria (7-20 weighted rubrics) from QA pairs using LLM, following the RaR paper methodology | Improve rubric quality with few-shot examples, domain-specific templates, and iterative refinement | +| **`rubric_eval.py`** | Scores responses using RaR-Explicit formula: checks each criterion independently via LLM judge, computes weighted normalized score | Support batch evaluation, caching, and alternative scoring formulas (RaR-Holistic) | +| **`task.py`** | Defines Inspect AI task `hf-benchmark-with-rubrics` that wires dataset, solver, and rubric scorer into a single evaluation pipeline | Add more task variants for different benchmarks (code generation, tool use, multi-turn) | +| **`solvers.py`** | Registry of solver implementations (`hf_agent`, `claude_code`, `claude_code+hf_mcp`) that can be swapped via CLI args | Expand solver library to benchmark more agents (OpenAI Codex, Gemini, open-source agents) | +| **`hf_agent_connector.py`** | Lightweight bridge that spins up the hf-agent stack (tools, MCP, LiteLLM loop) and returns the final assistant response | Enable streaming, intermediate step logging, and cost tracking per evaluation | +| **`leaderboard.py`** | Utilities to build records and append scores to a HuggingFace dataset for tracking performance over time | Add score breakdowns, visualizations, and automatic regression detection | +| **`run_eval_with_leaderboard.py`** | CLI wrapper that runs `inspect eval`, parses scores from logs, and pushes results to the leaderboard dataset | Support scheduled CI runs, PR-gated benchmarks, and multi-dataset aggregation | +| **`hf_io.py`** | Helper utilities for pushing DataFrames to HuggingFace Hub | Extend with dataset versioning and diff tracking | +| **`models.py`** | Shared Pydantic models for evaluation data structures | Centralize all eval schemas for consistency across components | + +## Pipeline + +``` +QA pairs β†’ generate_rubrics.py β†’ run `inspect-ai eval eval/task.py@hf-benchmark-with-rubrics` β†’ scores +``` + +### 1. Generate Rubrics (if not already generated) + +Creates instance-specific evaluation criteria from question + reference answer. + +```bash +python eval/generate_rubrics.py \ + --infile qa_pairs.jsonl \ + --outfile qa_rubrics.jsonl \ + --model anthropic/claude-sonnet-4-5-20250929 \ + --push-to-hub akseljoonas/hf-agent-benchmark@rubrics +``` + +**Input format:** +```json +{"question": "...", "solution": "...", "thread": [...]} +``` + +**Output:** 7-20 weighted criteria per question (Essential: +5, Important: +3-4, Optional: +1-2, Pitfall: -1 to -2) + +### 2. Response evaluation + +Files: +- `eval/hf_agent_connector.py` contains a lightweight bridge that spins up + the existing hf-agent stack in `agent/` (tools, MCP, LiteLLM loop) and returns the assistant reply. +- `eval/solvers.py` keeps the solver implementations (e.g. `hf_agent`, + `claude_code`). If additional solvers are needed, register them there and pass + `-T solver_name=` to swap them in without touching the task. +- `eval/task.py` registers `hf-benchmark-with-rubrics`, which wires + the dataset, solver, and rubric scorer into a single Inspect task and does the eval. + +### Running the hf-agent (implemented in `agent/`) (args are optional) +```bash +uv run inspect eval eval/task.py@hf-benchmark-with-rubrics \ + -T dataset_name=akseljoonas/hf-agent-rubrics \ + -T dataset_split=train \ + -T limit=25 \ + -T solver_name=hf_agent \ + -T solver_kwargs='{"config_path":"agent/config_mcp_example.json","max_iterations":10}' \ + --log-dir logs/inspect +``` + +Different benchmarks can be used by making/running a new task in `eval/task.py`. + +### Running Claude Code headlessly + +The `claude_code` solver shell-outs to the `claude` CLI (`claude -p ... --output-format json`) +so you can benchmark Claude Code without any interactive UI. Example: + +Claude Code command example (kwargs are optional): +```bash +uv run inspect eval eval/task.py@hf-benchmark-with-rubrics \ + -T solver_name=claude_code \ + -T solver_kwargs='{"allowed_tools":"Bash,Read","output_format":"json"}' +``` + +### Leaderboard + +Scores can be pushed to a Hugging Face dataset automatically by wrapping the run +with `eval/run_eval_with_leaderboard.py` (it executes `inspect eval ...` under the hood +and only appends results when the command succeeds): + +```bash +uv run python eval/run_eval_with_leaderboard.py \ + --hf-dataset akseljoonas/hf-agent-leaderboard \ + --hf-token $HF_TOKEN \ + --solver-name hf_agent \ + --solver-kwargs '{"config_path":"agent/config_mcp_example.json","max_iterations":10}' \ + --dataset akseljoonas/hf-agent-rubrics@train \ + --limit 25 +``` + +## Scoring (implemented in `eval/rubric_eval.py`) + +The scoring is implemented in `eval/rubric_eval.py` and is based on the RaR-Explicit formula: `score = Ξ£(weight Γ— satisfied) / Ξ£(positive_weights)`. + +The score is normalized to [0, 1] and clipped if pitfalls make it negative. diff --git a/eval/__init__.py b/eval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c661b764c811bd99adc8cdabbc29e8275774c97b --- /dev/null +++ b/eval/__init__.py @@ -0,0 +1,3 @@ +from eval.task import hf_benchmark_with_rubrics + +__all__ = ["hf_benchmark_with_rubrics"] diff --git a/eval/check_completeness.py b/eval/check_completeness.py new file mode 100644 index 0000000000000000000000000000000000000000..94790bce8c28a54da6b927ecff284bc6966dd3b7 --- /dev/null +++ b/eval/check_completeness.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python3 +""" +Minimal script to check if tasks in solved_tasks.jsonl were fully completed and verified. +Uses an LLM to assess completion status and adds the result to each row. +""" + +import argparse +import json +import sys +from concurrent.futures import ThreadPoolExecutor, as_completed + +import litellm +from dotenv import load_dotenv +from pydantic import BaseModel + +load_dotenv() + + +class CompletionCheck(BaseModel): + reasoning: str + completed: bool + verified: bool + + +PROMPT = """You are evaluating whether an AI agent fully completed a task AND verified its completion. + +Task: {question} + +Agent's final answer: {solution} + +Agent's trace (tool calls and responses): +{trace} + +Evaluate: +1. **completed**: Did the agent actually complete the task? (not just explain what could be done, but actually do it) +2. **verified**: Did the agent verify/confirm that the task was completed correctly? (e.g., checked output, validated results, confirmed success) + +Be strict: +- If the agent asked for more information or said "please provide...", it's NOT completed. +- If the agent only explained how to do something but didn't do it, it's NOT completed. +- If the agent just made a plan of how to complete it but didn't do it, it's NOT completed. +- If there's an error in the trace and no recovery, it's NOT completed. +- If the agent didn't check/confirm the code/command completed succesfully or the result is correct somehow, it's NOT verified. + +Return JSON with: completed (bool), verified (bool), reasoning (brief explanation).""" + + +def format_trace(messages: list) -> str: + """Format messages trace for the prompt.""" + if not messages: + return "(No trace)" + + parts = [] + for msg in messages: + role = msg.get("role", "unknown") + if role == "system": + continue + + content = msg.get("content", "") + tool_calls = msg.get("tool_calls", []) + + if tool_calls: + for tc in tool_calls: + if isinstance(tc, dict) and "function" in tc: + name = tc["function"].get("name", "?") + parts.append(f"[TOOL CALL] {name}") + + if content: + # Truncate long content + if len(content) > 5000: + content = content[:4000] + "..." + content[-1000:] + parts.append(f"[{role.upper()}] {content}") + + return "\n".join(parts) if parts else "(Empty trace)" + + +def check_row(row: dict, model: str) -> CompletionCheck | None: + """Check if a single task was completed and verified.""" + prompt = PROMPT.format( + question=row["question"], + solution=row.get("solution", "(No solution)"), + trace=format_trace(row.get("messages", [])), + ) + + try: + response = litellm.completion( + model=model, + messages=[{"role": "user", "content": prompt}], + response_format=CompletionCheck, + timeout=60, + ) + return CompletionCheck.model_validate_json(response.choices[0].message.content) + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + return None + + +def main(): + parser = argparse.ArgumentParser(description="Check task completion status") + parser.add_argument("--infile", type=str, default="eval/solved_tasks.jsonl") + parser.add_argument( + "--outfile", type=str, default="eval/solved_tasks_checked.jsonl" + ) + parser.add_argument( + "--model", type=str, default="anthropic/claude-sonnet-4-5-20250929" + ) + parser.add_argument("--max-concurrent", type=int, default=30) + args = parser.parse_args() + + # Load data + print(f"Loading {args.infile}...") + rows = [] + with open(args.infile) as f: + for line in f: + rows.append(json.loads(line)) + print(f"Loaded {len(rows)} rows") + + # Process in parallel + print(f"Checking completion with {args.model}...") + with ThreadPoolExecutor(max_workers=args.max_concurrent) as executor: + futures = { + executor.submit(check_row, row, args.model): i for i, row in enumerate(rows) + } + results = [None] * len(rows) + + for future in as_completed(futures): + idx = futures[future] + results[idx] = future.result() + print( + f"Done: {sum(1 for r in results if r is not None)}/{len(rows)}", + end="\r", + ) + + print() + + # Merge results + output_rows = [] + for row, result in zip(rows, results): + if result: + row["task_completed"] = result.completed + row["task_verified"] = result.verified + row["completion_reasoning"] = result.reasoning + else: + row["task_completed"] = None + row["task_verified"] = None + row["completion_reasoning"] = "Error during check" + output_rows.append(row) + + # Write output + print(f"Writing to {args.outfile}...") + with open(args.outfile, "w") as f: + for row in output_rows: + f.write(json.dumps(row, default=str) + "\n") + + # Summary + completed = sum(1 for r in results if r and r.completed) + verified = sum(1 for r in results if r and r.verified) + print("\nSummary:") + print(f" Completed: {completed}/{len(rows)}") + print(f" Verified: {verified}/{len(rows)}") + + +if __name__ == "__main__": + main() diff --git a/eval/claude_batch_solve.py b/eval/claude_batch_solve.py new file mode 100644 index 0000000000000000000000000000000000000000..154e23fd3b8a7f27e6b7559eaf3c7933e04cae39 --- /dev/null +++ b/eval/claude_batch_solve.py @@ -0,0 +1,230 @@ +import asyncio +import json +import os +import threading +from pathlib import Path +from typing import Any + +from claude_agent_sdk import ( + AssistantMessage, + ClaudeAgentOptions, + ResultMessage, + SystemMessage, + TextBlock, + ToolResultBlock, + ToolUseBlock, + UserMessage, + query, +) +from dotenv import load_dotenv + +load_dotenv() + +# Thread-safe file writing +file_lock = threading.Lock() + + +def convert_message_to_chat_format(message: Any) -> dict | None: + """Convert SDK message to standard chat format with role/content/tool_calls.""" + + if isinstance(message, SystemMessage): + # Extract tools list from init data for system message + if message.subtype == "init": + tools = message.data.get("tools", []) + tools_desc = "\n".join(f"- {tool}" for tool in tools) + return { + "role": "system", + "content": f"You are a helpful assistant with access to the following tools:\n{tools_desc}", + } + return None + + elif isinstance(message, AssistantMessage): + text_content = "" + tool_calls = [] + + for block in message.content: + if isinstance(block, TextBlock): + text_content += block.text + elif isinstance(block, ToolUseBlock): + tool_calls.append( + { + "id": block.id, + "function": { + "name": block.name, + "arguments": block.input, + }, + } + ) + + result = {"role": "assistant", "content": text_content} + if tool_calls: + result["tool_calls"] = tool_calls + return result + + elif isinstance(message, UserMessage): + # UserMessage can contain tool results or text + if isinstance(message.content, str): + return {"role": "user", "content": message.content} + elif isinstance(message.content, list): + # Check for tool results + tool_results = [] + text_content = "" + for block in message.content: + if isinstance(block, ToolResultBlock): + # Format tool result content + if isinstance(block.content, str): + content = block.content + elif isinstance(block.content, list): + content = json.dumps(block.content) + else: + content = str(block.content) if block.content else "" + + tool_results.append( + { + "tool_use_id": block.tool_use_id, + "content": content, + "is_error": block.is_error, + } + ) + elif isinstance(block, TextBlock): + text_content += block.text + + if tool_results: + return { + "role": "user", + "content": f"\n{json.dumps(tool_results, indent=2)}\n", + } + else: + return {"role": "user", "content": text_content} + return None + + elif isinstance(message, ResultMessage): + # ResultMessage is metadata, not a conversation message + return None + + return None + + +async def solve_task( + question: str, + difficulty: str, + task_idx: int, + total: int, + semaphore: asyncio.Semaphore, +) -> dict: + """Solve a single task using Claude Agent SDK.""" + async with semaphore: + print(f"[{task_idx}/{total}] Starting: {question[:60]}...") + + messages = [] + solution = None + + try: + async for message in query( + prompt=question, + options=ClaudeAgentOptions( + cwd=os.getcwd(), + permission_mode="bypassPermissions", + disallowed_tools=["Write", "Edit", "Bash", "Glob", "Grep"], + mcp_servers={ + "huggingface": { + "type": "http", + "url": "https://huggingface.co/mcp", + "headers": { + "Authorization": f"Bearer {os.environ['HF_TOKEN']}" + }, + } + }, + ), + ): + # Convert to chat format and append if valid + chat_msg = convert_message_to_chat_format(message) + if chat_msg: + messages.append(chat_msg) + + # Extract text from assistant messages + if isinstance(message, AssistantMessage): + for block in message.content: + if isinstance(block, TextBlock): + solution = block.text + # Check for result messages + elif isinstance(message, ResultMessage): + if message.is_error: + print(f"[{task_idx}/{total}] βœ— Agent error: {message.subtype}") + return { + "question": question, + "difficulty": difficulty, + "solution": None, + "messages": messages, + "error": f"Agent error: {message.subtype}", + } + elif message.result: + solution = message.result + + print(f"[{task_idx}/{total}] βœ“ Done: {question[:60]}...") + return { + "question": question, + "difficulty": difficulty, + "solution": solution, + "messages": messages, + "error": None, + } + except Exception as e: + print(f"[{task_idx}/{total}] βœ— Error: {e}") + return { + "question": question, + "difficulty": difficulty, + "solution": None, + "messages": messages, + "error": str(e), + } + + +def write_result(output_path: Path, result: dict): + """Thread-safe write to output file.""" + with file_lock: + with open(output_path, "a") as f: + f.write(json.dumps(result) + "\n") + + +async def main(): + # Load tasks from filled_tasks.jsonl + tasks_path = Path(__file__).parent / "filled_tasks.jsonl" + tasks = [] + with open(tasks_path) as f: + for line in f: + tasks.append(json.loads(line)) + + # Output file - clear it first + output_path = Path(__file__).parent / "solved_tasks.jsonl" + output_path.write_text("") + + # Semaphore to limit concurrency + max_concurrent = 5 + semaphore = asyncio.Semaphore(max_concurrent) + + total = len(tasks) + print(f"Processing {total} tasks with {max_concurrent} concurrent agents...") + + async def process_and_save(task: dict, idx: int): + result = await solve_task( + task["question"], task["difficulty"], idx, total, semaphore + ) + write_result(output_path, result) + return result + + # Create all tasks + coroutines = [process_and_save(task, i + 1) for i, task in enumerate(tasks)] + + # Run all concurrently (semaphore limits actual parallelism) + results = await asyncio.gather(*coroutines, return_exceptions=True) + + successful = sum( + 1 for r in results if isinstance(r, dict) and r.get("error") is None + ) + print(f"\nCompleted: {successful}/{total} successful") + print(f"Results saved to {output_path}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/eval/create_eval_dataset.py b/eval/create_eval_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..7a56c3d83914034aa1c0b070d6cece13327ad7fc --- /dev/null +++ b/eval/create_eval_dataset.py @@ -0,0 +1,160 @@ +from itertools import product + +from datasets import Dataset + +# Task templates (excluding Very hard difficulty) +tasks = [ + { + "task": "Evaluate models {M} on benchmarks {B}", + "difficulty": "Easy", + "category": "Evaluation", + "params": ["M", "B"], + }, + { + "task": "Train models {M} on datasets {D} evaluating them on benchmarks {B}", + "difficulty": "Medium", + "category": "Training", + "params": ["M", "D", "B"], + }, + { + "task": "Run an ablation for hyperparameter {P} for model {M} on dataset {D}", + "difficulty": "Hard", + "category": "Ablation", + "params": ["P", "M", "D"], + }, + { + "task": "Generate completions with model {M} on benchmarks {B} using engine {E}", + "difficulty": "Medium", + "category": "Generation", + "params": ["M", "B", "E"], + }, + # { + # "task": "Merge models {M} using linear averaging to find the best result on benchmarks {B}", + # "difficulty": "Hard", + # "category": "Model Merging", + # "params": ["M", "B"], + # }, + { + "task": "Decontaminate dataset {D} against benchmarks {B}", + "difficulty": "Hard", + "category": "Data Processing", + "params": ["D", "B"], + }, + { + "task": "Format dataset {D} for compatibility with framework {F} on task {T}", + "difficulty": "Easy", + "category": "Data Formatting", + "params": ["D", "F", "T"], + }, +] + +# Parameter values +values = { + "M": [ + "Qwen/Qwen3-4B-Instruct-2507", + "openai/gpt-oss-20b", + "gpt-4o-mini", + "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", + "anthropic's latest model", + ], + "B": [ + "Idavidrein/gpqa", + "HuggingFaceH4/MATH-500", + "lighteval/SimpleQA", + "TIGER-Lab/MMLU-Pro", + ], + "D": [ + "HuggingFaceH4/multi_turn_if", + "HuggingFaceH4/ultrachat_200k", + "HuggingFaceH4/AceReason-1.1-SFT config: math_no_think", + ], + "E": [ + "vllm", + "sglang", + ], + "F": [ + "trl", + "axolotl", + "verl", + ], + "P": [ + "learning_rate", + "batch_size", + "num_epochs", + ], + "T": [ + "SFT", + "GRPO", + ], +} + +# Task-specific instance limits +# For each task, specify which parameter(s) to pivot on and how many instances per pivot combination +# pivot can be a single parameter string or a list of parameters +task_limits = [ + {"pivot": "B", "instances_per_pivot": 1}, # Task 0: 1 instance per + {"pivot": ["M", "B"], "instances_per_pivot": 3}, # Task 1: 3 instances per model + {"pivot": ["P", "D"], "instances_per_pivot": 3}, # Task 2: + {"pivot": "E", "instances_per_pivot": 2}, # Task 3: 2 instances per benchmark + # {"pivot": "M", "instances_per_pivot": 2}, # Task 4 + {"pivot": "D", "instances_per_pivot": 2}, # Task 5: 2 instances per dataset + {"pivot": ["D", "F", "T"], "instances_per_pivot": 2}, # Task 6: +] + + +def main(): + eval_data = [] + + for task_idx, task_dict in enumerate(tasks): + template = task_dict["task"] + params = task_dict["params"] + limit_config = task_limits[task_idx] + + pivot_params = limit_config["pivot"] + instances_per_pivot = limit_config["instances_per_pivot"] + + # Normalize pivot to list + if isinstance(pivot_params, str): + pivot_params = [pivot_params] + + # Get all combinations of pivot values + pivot_param_values = [values[p] for p in pivot_params] + pivot_combinations = product(*pivot_param_values) + + # For each pivot combination, generate limited instances + for pivot_combo in pivot_combinations: + # Get combinations of other (non-pivot) parameters + other_params = [p for p in params if p not in pivot_params] + other_param_values = [values[p] for p in other_params] + other_combinations = list(product(*other_param_values)) + + # Limit to specified number of instances per pivot combination + limited_combinations = other_combinations[:instances_per_pivot] + + # Generate instances + for combo in limited_combinations: + # Build kwargs with pivot values and other values + kwargs = dict(zip(pivot_params, pivot_combo)) + kwargs.update(dict(zip(other_params, combo))) + + concrete_task = template.format(**kwargs) + eval_data.append( + { + "task": concrete_task, + "difficulty": task_dict["difficulty"], + "category": task_dict["category"], + } + ) + + print(f"Generated {len(eval_data)} instances from {len(tasks)} templates") + + dataset = Dataset.from_list(eval_data) + print(f"\nDataset: {len(dataset)} rows") + print(f"Sample: {dataset[0]['task']}") + + dataset.push_to_hub("akseljoonas/qyestions", private=False) + print("\nβœ“ Pushed to akseljoonas/qyestions") + + +if __name__ == "__main__": + main() diff --git a/eval/eval_set.ipynb b/eval/eval_set.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..3515f34727e3ff64b71b07fbc9cfd036aacf4995 --- /dev/null +++ b/eval/eval_set.ipynb @@ -0,0 +1,755 @@ +{ + "cells": [ + { + "cell_type": "code", + "id": "febne6uj10o", + "source": "#!/usr/bin/env python3\n\"\"\"Script to create HuggingFace 401 error fix documentation\"\"\"\n\nimport os\nfrom pathlib import Path\n\n# The full content of the documentation\ndocumentation_content = \"\"\"# HuggingFace 401 Unauthorized Error Fix - Dataset Push to HuggingFaceFW/fineweb-edu\n\n## Problem Summary\n\nWhen attempting to push to the HuggingFace dataset repository `HuggingFaceFW/fineweb-edu`, users may encounter a **401 Unauthorized** error. This is a large-scale educational dataset (1.3T tokens, 5.4TB) that requires proper authentication, token permissions, and git-lfs configuration for successful uploads.\n\n**Authenticated User:** akseljoonas \n**Repository:** HuggingFaceFW/fineweb-edu (dataset) \n**Repository Stats:** 5.3M downloads | 873 likes | Last updated: July 11, 2025\n\n---\n\n## Root Causes of 401 Errors\n\nBased on recent issues (2025) and HuggingFace documentation, 401 errors typically stem from:\n\n### 1. **Insufficient Token Permissions**\n- Token lacks **write** permission (only has read access)\n- Token is expired or invalid\n- Using organization token instead of personal access token\n\n### 2. **Git Credential Configuration Issues**\n- Token not saved to git credential helper\n- Git attempting to use cached incorrect credentials\n- Missing `--add-to-git-credential` flag during login\n\n### 3. **Git-LFS Authentication Failures**\n- Git-LFS not properly configured\n- LFS files not tracked correctly (threshold issues)\n- Token not being passed to git-lfs operations\n- CAS (Content Addressable Storage) service authentication failures (new in 2025)\n\n### 4. **API Version Compatibility (2025 Issue)**\n- Modern access tokens only work with API v2 endpoints\n- `huggingface_hub` may internally use API v1 endpoints causing 401 errors\n- Reported as recently as October 2025\n\n### 5. **Large File Upload Issues**\n- Authorization errors when uploading many files (~1000+ files, 300GB+)\n- Timeout issues with LFS authentication on large batches\n\n---\n\n## Diagnostic Steps\n\n### Step 1: Verify Authentication Status\n\n```bash\n# Check who you're authenticated as\nhuggingface-cli whoami\n\n# Or using Python\npython3 -c \"from huggingface_hub import whoami; print(whoami())\"\n```\n\n**Expected Output:** Should show username `akseljoonas` and token permissions\n\n### Step 2: Check Token Permissions\n\n```bash\n# Login and verify token has WRITE permission\nhuggingface-cli login --token YOUR_TOKEN\n\n# Look for this line in output:\n# Token is valid (permission: write).\n```\n\n**Important:** If you see `(permission: read)`, your token is insufficient for pushing!\n\n### Step 3: Verify Git Configuration\n\n```bash\n# Check git credential configuration\ngit config --global --list | grep credential\n\n# Check for git-lfs installation\ngit lfs version\n\n# Check git-lfs environment\ngit lfs env\n```\n\n### Step 4: Check Repository Access\n\n```python\nfrom huggingface_hub import HfApi, auth_check\n\ntry:\n # Verify you have access to the repository\n auth_check(\"HuggingFaceFW/fineweb-edu\", repo_type=\"dataset\")\n print(\"βœ“ Access granted to repository\")\nexcept Exception as e:\n print(f\"βœ— Access denied: {e}\")\n```\n\n### Step 5: Inspect Local Repository (if cloned)\n\n```bash\n# Navigate to your local repo\ncd /path/to/fineweb-edu\n\n# Check git remote\ngit remote -v\n\n# Check git-lfs tracking\ngit lfs track\n\n# Check .gitattributes file\ncat .gitattributes\n```\n\n---\n\n## Complete Fix Solutions\n\n### Solution 1: Re-authenticate with Correct Token Scope βœ… RECOMMENDED\n\nThis is the most common fix for 401 errors.\n\n```bash\n# Step 1: Create a new token with WRITE permissions\n# Go to: https://huggingface.co/settings/tokens\n# Click \"New token\"\n# Select role: \"write\" (NOT \"read\")\n# Give it a name like \"dataset-push-token\"\n# Copy the token (starts with hf_...)\n\n# Step 2: Login with the token AND add to git credentials\nhuggingface-cli login --token YOUR_WRITE_TOKEN --add-to-git-credential\n\n# Step 3: Verify the login\nhuggingface-cli whoami\n```\n\n**Expected Output:**\n```\nToken is valid (permission: write).\nYour token has been saved in your configured git credential helpers (store).\nYour token has been saved to /home/username/.cache/huggingface/token\nLogin successful\n```\n\n**Python Alternative:**\n```python\nfrom huggingface_hub import login\n\n# Login with write token and save to git credentials\nlogin(token=\"hf_YOUR_WRITE_TOKEN\", add_to_git_credential=True)\n```\n\n---\n\n### Solution 2: Configure Git Credentials Manually\n\nIf `--add-to-git-credential` doesn't work automatically:\n\n```bash\n# Step 1: Configure git credential store\ngit config --global credential.helper store\n\n# Step 2: Create/edit the credentials file\n# Location: ~/.git-credentials (Linux/Mac) or C:\\\\Users\\\\\\\\.git-credentials (Windows)\necho \"https://YOUR_USERNAME:YOUR_HF_TOKEN@huggingface.co\" >> ~/.git-credentials\n\n# Step 3: Verify\ncat ~/.git-credentials | grep huggingface\n```\n\n**Format for credentials file:**\n```\nhttps://akseljoonas:hf_YOUR_TOKEN@huggingface.co\n```\n\n---\n\n### Solution 3: Fix Git-LFS Configuration\n\nFor large datasets like fineweb-edu, git-lfs is essential:\n\n```bash\n# Step 1: Install git-lfs (if not installed)\n# Ubuntu/Debian:\nsudo apt-get install git-lfs\n\n# macOS:\nbrew install git-lfs\n\n# Windows: Download from https://git-lfs.github.com/\n\n# Step 2: Initialize git-lfs globally\ngit lfs install\n\n# Step 3: In your repository, track large files\ncd /path/to/fineweb-edu\n\n# Track common large file types for datasets\ngit lfs track \"*.parquet\"\ngit lfs track \"*.arrow\"\ngit lfs track \"*.bin\"\ngit lfs track \"*.safetensors\"\ngit lfs track \"*.h5\"\ngit lfs track \"*.json.gz\"\n\n# Step 4: Verify tracking\ngit lfs track\n\n# Step 5: Check .gitattributes was updated\ncat .gitattributes\n```\n\n**Default Large File Threshold:**\n- HuggingFace automatically uses LFS for files > 10MB\n- Files under 10MB are stored as regular git objects\n\n---\n\n### Solution 4: Use HuggingFace Hub API Instead of Git (RECOMMENDED for Large Datasets)\n\nFor very large datasets like fineweb-edu, using the Python API is more reliable than git push:\n\n```python\nfrom huggingface_hub import HfApi, login\nfrom pathlib import Path\n\n# Step 1: Authenticate\nlogin(token=\"hf_YOUR_WRITE_TOKEN\", add_to_git_credential=True)\n\n# Step 2: Initialize API client\napi = HfApi()\n\n# Step 3: Upload files to the dataset repository\n# For a single file:\napi.upload_file(\n path_or_fileobj=\"/path/to/local/file.parquet\",\n path_in_repo=\"data/file.parquet\",\n repo_id=\"HuggingFaceFW/fineweb-edu\",\n repo_type=\"dataset\",\n)\n\n# For multiple files in a folder:\napi.upload_folder(\n folder_path=\"/path/to/local/folder\",\n repo_id=\"HuggingFaceFW/fineweb-edu\",\n repo_type=\"dataset\",\n commit_message=\"Add new data files\",\n)\n\n# For very large uploads, use multi_commits=True:\napi.upload_large_folder(\n folder_path=\"/path/to/large/dataset\",\n repo_id=\"HuggingFaceFW/fineweb-edu\",\n repo_type=\"dataset\",\n multi_commits=True,\n commit_message=\"Upload large dataset batch\",\n)\n```\n\n**Benefits over git push:**\n- Better handling of large files (no LFS authentication issues)\n- Automatic retry on failures\n- Progress tracking\n- No credential caching problems\n- Works around 2025 API v1/v2 compatibility issues\n\n---\n\n### Solution 5: Handle CAS Service Errors (2025 Issue)\n\nIf you see errors mentioning \"CAS service\" or \"Content Addressable Storage\":\n\n```python\nfrom huggingface_hub import HfApi\nimport time\n\napi = HfApi()\n\n# Use smaller batch sizes with delays\nfiles_to_upload = list(Path(\"/your/dataset\").glob(\"*.parquet\"))\n\nfor file_path in files_to_upload:\n try:\n api.upload_file(\n path_or_fileobj=str(file_path),\n path_in_repo=f\"data/{file_path.name}\",\n repo_id=\"HuggingFaceFW/fineweb-edu\",\n repo_type=\"dataset\",\n )\n print(f\"βœ“ Uploaded {file_path.name}\")\n time.sleep(2) # Small delay to avoid overwhelming CAS service\n except Exception as e:\n print(f\"βœ— Failed to upload {file_path.name}: {e}\")\n```\n\n---\n\n### Solution 6: Check Repository Permissions\n\nVerify you have write access to the repository:\n\n```python\nfrom huggingface_hub import HfApi, whoami\n\napi = HfApi()\n\n# Check your user info\nuser_info = whoami()\nprint(f\"Username: {user_info['name']}\")\nprint(f\"Organizations: {user_info.get('orgs', [])}\")\n\n# Check if you're part of HuggingFaceFW organization\norgs = user_info.get('orgs', [])\nhas_access = any(org.get('name') == 'HuggingFaceFW' for org in orgs)\n\nif has_access:\n print(\"βœ“ You are a member of HuggingFaceFW organization\")\nelse:\n print(\"βœ— You are NOT a member of HuggingFaceFW organization\")\n print(\" You may need to request access or use a PR instead\")\n```\n\n**If you don't have write access:**\n```bash\n# Create a pull request instead of pushing directly\nhuggingface-cli upload HuggingFaceFW/fineweb-edu /path/to/file --create-pr\n```\n\nOr with Python:\n```python\napi.upload_file(\n path_or_fileobj=\"/path/to/file\",\n path_in_repo=\"data/file.parquet\",\n repo_id=\"HuggingFaceFW/fineweb-edu\",\n repo_type=\"dataset\",\n create_pr=True, # Creates a PR instead of direct push\n)\n```\n\n---\n\n## Git-LFS Configuration Details\n\n### File Size Thresholds\n\n| File Size | Storage Method | Configuration |\n|-----------|---------------|---------------|\n| < 10 MB | Regular Git | No special config needed |\n| > 10 MB | Git-LFS | Automatically tracked by HF |\n| > 5 GB | Git-LFS + Special handling | Use API upload methods |\n\n### Common .gitattributes for Datasets\n\n```gitattributes\n# Large data files\n*.parquet filter=lfs diff=lfs merge=lfs -text\n*.arrow filter=lfs diff=lfs merge=lfs -text\n*.bin filter=lfs diff=lfs merge=lfs -text\n*.safetensors filter=lfs diff=lfs merge=lfs -text\n*.h5 filter=lfs diff=lfs merge=lfs -text\n*.hdf5 filter=lfs diff=lfs merge=lfs -text\n\n# Compressed files\n*.tar.gz filter=lfs diff=lfs merge=lfs -text\n*.zip filter=lfs diff=lfs merge=lfs -text\n*.json.gz filter=lfs diff=lfs merge=lfs -text\n\n# Model files\n*.onnx filter=lfs diff=lfs merge=lfs -text\n*.pb filter=lfs diff=lfs merge=lfs -text\n*.pt filter=lfs diff=lfs merge=lfs -text\n*.pth filter=lfs diff=lfs merge=lfs -text\n```\n\n### Verify LFS is Working\n\n```bash\n# Check which files are tracked by LFS\ngit lfs ls-files\n\n# Check LFS status\ngit lfs status\n\n# Verify a specific file is using LFS\ngit lfs ls-files | grep \"your-file.parquet\"\n\n# See LFS configuration\ngit lfs env\n```\n\n---\n\n## Environment Variables\n\nUseful environment variables for debugging:\n\n```bash\n# Set HuggingFace token via environment variable\nexport HF_TOKEN=\"hf_YOUR_TOKEN\"\n\n# Disable implicit token sending (for debugging)\nexport HF_HUB_DISABLE_IMPLICIT_TOKEN=1\n\n# Enable verbose git LFS output\nexport GIT_TRACE=1\nexport GIT_CURL_VERBOSE=1\nexport GIT_LFS_TRACE=1\n\n# Set custom cache directory\nexport HF_HOME=\"/path/to/custom/cache\"\n```\n\n---\n\n## Testing the Fix\n\nAfter applying the fixes, test with a small file first:\n\n```python\nfrom huggingface_hub import HfApi\nimport tempfile\nfrom pathlib import Path\n\napi = HfApi()\n\n# Create a small test file\nwith tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:\n f.write(\"Test file for authentication verification\")\n test_file = f.name\n\ntry:\n # Try uploading to a test repository you own\n # DO NOT test on fineweb-edu directly!\n result = api.upload_file(\n path_or_fileobj=test_file,\n path_in_repo=\"test_auth.txt\",\n repo_id=\"YOUR_USERNAME/test-repo\", # Use your own test repo\n repo_type=\"dataset\",\n )\n print(f\"βœ“ Authentication working! File uploaded to: {result}\")\nexcept Exception as e:\n print(f\"βœ— Authentication failed: {e}\")\nfinally:\n Path(test_file).unlink() # Clean up test file\n```\n\n---\n\n## Quick Reference - Commands Checklist\n\n```bash\n# 1. Check current authentication\nhuggingface-cli whoami\n\n# 2. Re-login with write token\nhuggingface-cli login --token YOUR_WRITE_TOKEN --add-to-git-credential\n\n# 3. Verify git credentials\ngit config --global credential.helper store\ncat ~/.git-credentials | grep huggingface\n\n# 4. Check git-lfs\ngit lfs version\ngit lfs install\n\n# 5. In your repo, verify LFS tracking\ncd /path/to/repo\ngit lfs track\ncat .gitattributes\n\n# 6. Test authentication with Python\npython3 -c \"from huggingface_hub import whoami; print(whoami())\"\n```\n\n---\n\n## Common Error Messages and Solutions\n\n| Error Message | Cause | Solution |\n|---------------|-------|----------|\n| `401 Unauthorized` | Invalid or read-only token | Use Solution 1: Re-authenticate with write token |\n| `403 Forbidden` | No access to repository | Check repository permissions (Solution 6) |\n| `Repository not found` | Wrong repo ID or private repo without access | Verify repo exists and you have access |\n| `LFS authentication failed` | Git credentials not configured | Use Solution 2: Configure git credentials |\n| `CAS service error` | 2025 API issue | Use Solution 5: Smaller batches with delays |\n| `This repository requires LFS` | Missing git-lfs | Use Solution 3: Install and configure git-lfs |\n| `batch response: This repository is over its data limit` | Repository quota exceeded | Contact repository owner |\n\n---\n\n## Best Practices for Large Datasets\n\nFor datasets like fineweb-edu (1.3T tokens):\n\n1. **Use the HuggingFace Hub API** instead of git push\n2. **Upload in batches** rather than all at once\n3. **Use `upload_large_folder()`** with `multi_commits=True`\n4. **Monitor upload progress** and implement retry logic\n5. **Test with small files first** before uploading large batches\n6. **Use fine-grained tokens** for production environments\n7. **Keep tokens secure** - use environment variables or secure vaults\n\n---\n\n## Additional Resources\n\n- [HuggingFace Hub Python Library](https://huggingface.co/docs/huggingface_hub)\n- [Security Tokens Documentation](https://huggingface.co/docs/hub/security-tokens)\n- [Git-LFS Documentation](https://git-lfs.github.com/)\n- [HuggingFace CLI Guide](https://huggingface.co/docs/huggingface_hub/guides/cli)\n\n---\n\n## Document Version\n\n- **Created:** December 18, 2025\n- **Last Updated:** December 18, 2025\n- **Tested Against:** HuggingFace Hub API v1.2.3+\n- **Authenticated User:** akseljoonas\n- **Target Repository:** HuggingFaceFW/fineweb-edu (dataset)\n\n---\n\n## Sources & References\n\n- [I got Authorization error - Hugging Face Forums](https://discuss.huggingface.co/t/i-got-authorization-error/32881)\n- [Can't push to a dataset repository - Hugging Face Forums](https://discuss.huggingface.co/t/cant-push-to-a-dataset-repository/36611)\n- [LFS: Authorization error when uploading large files](https://lightrun.com/answers/huggingface-huggingface_hub-lfs-authorization-error-when-uploading-manylarge-files)\n- [401 Client Error - huggingface_hub Issue #2586](https://github.com/huggingface/huggingface_hub/issues/2586)\n- [Modern Access Tokens API v2 issue - Issue #3479](https://github.com/huggingface/huggingface_hub/issues/3479)\n- [Hugging Face Hub Dataset Upload CAS Error - Issue #7760](https://github.com/huggingface/datasets/issues/7760)\n- [HuggingFace Security Tokens Documentation](https://huggingface.co/docs/hub/security-tokens)\n\"\"\"\n\n# Expand the ~ to the user's home directory\noutput_path = Path.home() / \"huggingface_401_fix_documentation.md\"\n\n# Write the documentation to the file\ntry:\n with open(output_path, 'w', encoding='utf-8') as f:\n f.write(documentation_content)\n print(f\"βœ“ Successfully created documentation at: {output_path}\")\n print(f\"βœ“ File size: {output_path.stat().st_size} bytes\")\nexcept Exception as e:\n print(f\"βœ— Error creating file: {e}\")\n raise", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "x6z3pkwzo8", + "source": "import csv\n\n# Model data collected from Hugging Face API - Apache-2.0 licensed text-classification models under 500MB\nmodel_data = [\n {'model_id': 'kmack/malicious-url-detection', 'downloads': 2000000, 'likes': 1, 'size_mb': 255.2, 'license': 'apache-2.0'},\n {'model_id': 'mixedbread-ai/mxbai-rerank-xsmall-v1', 'downloads': 960600, 'likes': 49, 'size_mb': 491.7, 'license': 'apache-2.0'},\n {'model_id': 'cross-encoder/ms-marco-TinyBERT-L2-v2', 'downloads': 598100, 'likes': 36, 'size_mb': 172.09, 'license': 'apache-2.0'},\n {'model_id': 'cybersectony/phishing-email-detection-distilbert_v2.4.1', 'downloads': 300500, 'likes': 23, 'size_mb': 255.26, 'license': 'apache-2.0'},\n {'model_id': 'jamal-ibrahim/risk_assesment', 'downloads': 98700, 'likes': 0, 'size_mb': 255.42, 'license': 'apache-2.0'},\n {'model_id': 'agufsamudra/indo-sentiment-analysis', 'downloads': 92100, 'likes': 0, 'size_mb': 475.0, 'license': 'apache-2.0'}\n]\n\n# Already sorted by downloads descending\ncsv_path = '/tmp/apache2_text_classification_models.csv'\nwith open(csv_path, 'w', newline='') as f:\n writer = csv.DictWriter(f, fieldnames=['model_id', 'downloads', 'likes', 'size_mb', 'license'])\n writer.writeheader()\n writer.writerows(model_data)\n\nprint(f'βœ“ CSV file created at: {csv_path}')\nprint(f'βœ“ Total models: {len(model_data)}')\nprint(f'βœ“ All models are Apache-2.0 licensed and under 500MB')\nprint(f'βœ“ Sorted by downloads (descending)')", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "t9et9n50wgr", + "source": "# This is just to check the notebook structure\nprint(\"test\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "n4awck8w5ok", + "source": "# Write the KV Cache benchmark script\nbenchmark_script = '''#!/usr/bin/env python3\n\"\"\"\nKV Cache Quantization Benchmark Script\nCompares FP16 vs INT8 quantized KV cache performance on CNN/DailyMail summarization task\n\"\"\"\n\nimport json\nimport time\nimport torch\nfrom datasets import load_dataset\nfrom transformers import AutoTokenizer, AutoModelForCausalLM\nfrom rouge_score import rouge_scorer\nimport gc\nfrom typing import Dict, List, Tuple\nimport numpy as np\n\n# Configuration\nMODEL_NAME = \"meta-llama/Llama-3.2-1B\"\nDATASET_NAME = \"cnn_dailymail\"\nDATASET_CONFIG = \"3.0.0\"\nNUM_SAMPLES = 100\nMAX_NEW_TOKENS = 128\nDO_SAMPLE = False\nDEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\nprint(f\"Using device: {DEVICE}\")\nprint(f\"PyTorch version: {torch.__version__}\")\n\n# Install required packages (instructions for user)\nprint(\"\\\\nRequired packages:\")\nprint(\"pip install transformers datasets rouge-score torch hqq accelerate\")\nprint(\"-\" * 80)\n\n\ndef load_model_and_tokenizer():\n \"\"\"Load the model and tokenizer\"\"\"\n print(f\"\\\\nLoading model: {MODEL_NAME}\")\n \n tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n \n # Set padding token if not set\n if tokenizer.pad_token is None:\n tokenizer.pad_token = tokenizer.eos_token\n \n model = AutoModelForCausalLM.from_pretrained(\n MODEL_NAME,\n torch_dtype=torch.float16 if DEVICE == \"cuda\" else torch.float32,\n device_map=\"auto\" if DEVICE == \"cuda\" else None,\n )\n \n if DEVICE == \"cpu\":\n model = model.to(DEVICE)\n \n model.eval()\n \n print(f\"Model loaded successfully on {DEVICE}\")\n return model, tokenizer\n\n\ndef load_data() -> List[Dict]:\n \"\"\"Load CNN/DailyMail dataset\"\"\"\n print(f\"\\\\nLoading {NUM_SAMPLES} samples from {DATASET_NAME} dataset...\")\n \n dataset = load_dataset(DATASET_NAME, DATASET_CONFIG, split=\"test\")\n samples = dataset.select(range(min(NUM_SAMPLES, len(dataset))))\n \n data = []\n for sample in samples:\n data.append({\n \"article\": sample[\"article\"],\n \"highlights\": sample[\"highlights\"],\n })\n \n print(f\"Loaded {len(data)} samples\")\n return data\n\n\ndef prepare_prompt(article: str) -> str:\n \"\"\"Prepare prompt for summarization\"\"\"\n prompt = f\"\"\"Summarize the following article in one or two sentences:\n\nArticle: {article[:1000]}\n\nSummary:\"\"\"\n return prompt\n\n\ndef generate_summaries(\n model, \n tokenizer, \n data: List[Dict], \n cache_implementation: str = \"default\",\n cache_config: Dict = None\n) -> Tuple[List[str], float, float]:\n \"\"\"\n Generate summaries and measure performance\n \n Returns:\n summaries: List of generated summaries\n tokens_per_sec: Throughput in tokens/second\n peak_memory_mb: Peak memory usage in MB\n \"\"\"\n summaries = []\n total_tokens = 0\n start_time = time.time()\n \n if DEVICE == \"cuda\":\n torch.cuda.reset_peak_memory_stats()\n initial_memory = torch.cuda.memory_allocated()\n \n print(f\"\\\\nGenerating summaries with cache_implementation='{cache_implementation}'...\")\n \n for i, sample in enumerate(data):\n prompt = prepare_prompt(sample[\"article\"])\n \n inputs = tokenizer(\n prompt, \n return_tensors=\"pt\", \n truncation=True, \n max_length=2048\n ).to(DEVICE)\n \n # Generate with specified cache configuration\n generation_kwargs = {\n \"max_new_tokens\": MAX_NEW_TOKENS,\n \"do_sample\": DO_SAMPLE,\n \"pad_token_id\": tokenizer.pad_token_id,\n }\n \n if cache_implementation != \"default\":\n generation_kwargs[\"cache_implementation\"] = cache_implementation\n if cache_config:\n generation_kwargs[\"cache_config\"] = cache_config\n \n with torch.no_grad():\n outputs = model.generate(**inputs, **generation_kwargs)\n \n # Decode only the generated tokens (exclude prompt)\n generated_tokens = outputs[0][inputs.input_ids.shape[1]:]\n summary = tokenizer.decode(generated_tokens, skip_special_tokens=True)\n summaries.append(summary.strip())\n \n total_tokens += len(generated_tokens)\n \n if (i + 1) % 10 == 0:\n print(f\" Processed {i + 1}/{len(data)} samples\")\n \n end_time = time.time()\n elapsed_time = end_time - start_time\n tokens_per_sec = total_tokens / elapsed_time\n \n if DEVICE == \"cuda\":\n peak_memory = torch.cuda.max_memory_allocated()\n peak_memory_mb = (peak_memory - initial_memory) / (1024 * 1024)\n else:\n peak_memory_mb = 0.0\n \n print(f\" Generated {total_tokens} tokens in {elapsed_time:.2f}s\")\n print(f\" Throughput: {tokens_per_sec:.2f} tokens/sec\")\n if DEVICE == \"cuda\":\n print(f\" Peak memory: {peak_memory_mb:.2f} MB\")\n \n return summaries, tokens_per_sec, peak_memory_mb\n\n\ndef calculate_rouge_scores(predictions: List[str], references: List[str]) -> Dict[str, float]:\n \"\"\"Calculate ROUGE-L scores\"\"\"\n print(\"\\\\nCalculating ROUGE-L scores...\")\n \n scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)\n scores = []\n \n for pred, ref in zip(predictions, references):\n score = scorer.score(ref, pred)\n scores.append(score['rougeL'].fmeasure)\n \n avg_score = np.mean(scores)\n std_score = np.std(scores)\n \n print(f\" ROUGE-L: {avg_score:.4f} Β± {std_score:.4f}\")\n \n return {\n \"mean\": float(avg_score),\n \"std\": float(std_score),\n \"scores\": [float(s) for s in scores]\n }\n\n\ndef benchmark_cache(\n model,\n tokenizer,\n data: List[Dict],\n cache_type: str,\n cache_implementation: str = \"default\",\n cache_config: Dict = None\n) -> Dict:\n \"\"\"Run benchmark for a specific cache configuration\"\"\"\n print(f\"\\\\n{'='*80}\")\n print(f\"Benchmarking {cache_type}\")\n print(f\"{'='*80}\")\n \n # Clear cache\n if DEVICE == \"cuda\":\n torch.cuda.empty_cache()\n gc.collect()\n \n # Generate summaries\n summaries, tokens_per_sec, peak_memory_mb = generate_summaries(\n model, \n tokenizer, \n data,\n cache_implementation=cache_implementation,\n cache_config=cache_config\n )\n \n # Calculate ROUGE scores\n references = [sample[\"highlights\"] for sample in data]\n rouge_scores = calculate_rouge_scores(summaries, references)\n \n results = {\n \"cache_type\": cache_type,\n \"cache_implementation\": cache_implementation,\n \"cache_config\": cache_config,\n \"tokens_per_sec\": float(tokens_per_sec),\n \"peak_memory_mb\": float(peak_memory_mb),\n \"rouge_l_mean\": rouge_scores[\"mean\"],\n \"rouge_l_std\": rouge_scores[\"std\"],\n \"num_samples\": len(data),\n \"total_tokens_generated\": len(summaries) * MAX_NEW_TOKENS,\n }\n \n return results, summaries\n\n\ndef main():\n \"\"\"Main benchmark function\"\"\"\n print(\"=\"*80)\n print(\"KV Cache Quantization Benchmark\")\n print(\"=\"*80)\n print(f\"Model: {MODEL_NAME}\")\n print(f\"Dataset: {DATASET_NAME}\")\n print(f\"Num samples: {NUM_SAMPLES}\")\n print(f\"Max new tokens: {MAX_NEW_TOKENS}\")\n \n # Load model and data\n model, tokenizer = load_model_and_tokenizer()\n data = load_data()\n \n # Benchmark FP16 (default) cache\n fp16_results, fp16_summaries = benchmark_cache(\n model, \n tokenizer, \n data,\n cache_type=\"FP16 (Default)\",\n cache_implementation=\"default\",\n cache_config=None\n )\n \n # Benchmark INT8 quantized cache with HQQ\n int8_results, int8_summaries = benchmark_cache(\n model,\n tokenizer,\n data,\n cache_type=\"INT8 (HQQ Quantized)\",\n cache_implementation=\"quantized\",\n cache_config={\n \"backend\": \"HQQ\",\n \"nbits\": 8,\n \"axis_key\": 1,\n \"axis_value\": 1\n }\n )\n \n # Compare results\n print(\"\\\\n\" + \"=\"*80)\n print(\"COMPARISON RESULTS\")\n print(\"=\"*80)\n \n speedup = int8_results[\"tokens_per_sec\"] / fp16_results[\"tokens_per_sec\"]\n rouge_diff = int8_results[\"rouge_l_mean\"] - fp16_results[\"rouge_l_mean\"]\n \n if fp16_results[\"peak_memory_mb\"] > 0:\n memory_savings_pct = (1 - int8_results[\"peak_memory_mb\"] / fp16_results[\"peak_memory_mb\"]) * 100\n else:\n memory_savings_pct = 0.0\n \n print(f\"\\\\nFP16 Cache:\")\n print(f\" Throughput: {fp16_results['tokens_per_sec']:.2f} tokens/sec\")\n print(f\" ROUGE-L: {fp16_results['rouge_l_mean']:.4f} Β± {fp16_results['rouge_l_std']:.4f}\")\n print(f\" Peak Memory: {fp16_results['peak_memory_mb']:.2f} MB\")\n \n print(f\"\\\\nINT8 Quantized Cache (HQQ):\")\n print(f\" Throughput: {int8_results['tokens_per_sec']:.2f} tokens/sec\")\n print(f\" ROUGE-L: {int8_results['rouge_l_mean']:.4f} Β± {int8_results['rouge_l_std']:.4f}\")\n print(f\" Peak Memory: {int8_results['peak_memory_mb']:.2f} MB\")\n \n print(f\"\\\\nComparison:\")\n print(f\" Speedup: {speedup:.2f}x\")\n print(f\" ROUGE-L Difference: {rouge_diff:+.4f}\")\n print(f\" Memory Savings: {memory_savings_pct:.1f}%\")\n \n # Prepare final results\n final_results = {\n \"config\": {\n \"model\": MODEL_NAME,\n \"dataset\": DATASET_NAME,\n \"num_samples\": NUM_SAMPLES,\n \"max_new_tokens\": MAX_NEW_TOKENS,\n \"do_sample\": DO_SAMPLE,\n \"device\": DEVICE,\n },\n \"fp16_cache\": fp16_results,\n \"int8_quantized_cache\": int8_results,\n \"comparison\": {\n \"speedup\": float(speedup),\n \"rouge_l_difference\": float(rouge_diff),\n \"memory_savings_percent\": float(memory_savings_pct),\n }\n }\n \n # Save results to JSON\n output_file = \"/tmp/kv_cache_benchmark_results.json\"\n with open(output_file, \"w\") as f:\n json.dump(final_results, f, indent=2)\n \n print(f\"\\\\n{'='*80}\")\n print(f\"Results saved to: {output_file}\")\n print(f\"{'='*80}\")\n \n return final_results\n\n\nif __name__ == \"__main__\":\n main()\n'''\n\n# Write benchmark script\nwith open('/tmp/kv_cache_benchmark.py', 'w') as f:\n f.write(benchmark_script)\n\nimport os\nos.chmod('/tmp/kv_cache_benchmark.py', 0o755)\n\n# Write README\nreadme_content = '''# KV Cache Quantization Benchmark\n\nThis benchmark compares FP16 (default) vs INT8 quantized KV cache performance using Llama-3.2-1B on the CNN/DailyMail summarization task.\n\n## Overview\n\nThe script evaluates:\n- **Throughput**: Tokens generated per second\n- **Memory Usage**: Peak memory consumption during generation\n- **Quality**: ROUGE-L scores comparing generated summaries to reference summaries\n\n## Requirements\n\nInstall the required packages:\n\n```bash\npip install transformers datasets rouge-score torch hqq accelerate\n```\n\n### GPU Requirements\n- CUDA-compatible GPU recommended (script will fall back to CPU if no GPU is available)\n- At least 8GB VRAM for Llama-3.2-1B with FP16\n- At least 4GB VRAM for INT8 quantized cache\n\n## Usage\n\n### Basic Usage\n\nRun the benchmark with default settings (100 samples):\n\n```bash\npython /tmp/kv_cache_benchmark.py\n```\n\n### Configuration\n\nYou can modify the configuration variables at the top of the script:\n\n```python\nMODEL_NAME = \"meta-llama/Llama-3.2-1B\" # Model to benchmark\nDATASET_NAME = \"cnn_dailymail\" # Dataset name\nDATASET_CONFIG = \"3.0.0\" # Dataset version\nNUM_SAMPLES = 100 # Number of test samples\nMAX_NEW_TOKENS = 128 # Max tokens to generate per sample\nDO_SAMPLE = False # Use greedy decoding\n```\n\n### Output\n\nThe script will:\n1. Load the model and dataset\n2. Run FP16 (default) cache benchmark\n3. Run INT8 quantized cache benchmark with HQQ\n4. Calculate ROUGE-L scores for both configurations\n5. Display comparison results\n6. Save detailed results to `/tmp/kv_cache_benchmark_results.json`\n\n## Results Format\n\nThe output JSON file contains:\n- Configuration details\n- FP16 cache results (throughput, memory, ROUGE-L)\n- INT8 quantized cache results\n- Comparison metrics (speedup, quality difference, memory savings)\n\nExample output:\n```json\n{\n \"config\": {\n \"model\": \"meta-llama/Llama-3.2-1B\",\n \"dataset\": \"cnn_dailymail\",\n \"num_samples\": 100,\n \"max_new_tokens\": 128,\n \"device\": \"cuda\"\n },\n \"fp16_cache\": {\n \"tokens_per_sec\": 150.5,\n \"peak_memory_mb\": 2048.3,\n \"rouge_l_mean\": 0.3245\n },\n \"int8_quantized_cache\": {\n \"tokens_per_sec\": 180.2,\n \"peak_memory_mb\": 1024.1,\n \"rouge_l_mean\": 0.3198\n },\n \"comparison\": {\n \"speedup\": 1.20,\n \"rouge_l_difference\": -0.0047,\n \"memory_savings_percent\": 50.0\n }\n}\n```\n\n## Understanding the Results\n\n### Speedup\n- Values > 1.0 indicate INT8 quantization is faster\n- Typical range: 1.1x - 1.5x speedup\n\n### Memory Savings\n- Percentage reduction in peak memory usage\n- Typical range: 40% - 50% reduction\n\n### ROUGE-L Difference\n- Negative values indicate slight quality degradation\n- Small differences (< 0.01) are generally acceptable\n- ROUGE-L measures overlap between generated and reference summaries\n\n## Troubleshooting\n\n### CUDA Out of Memory\nIf you encounter OOM errors:\n1. Reduce `NUM_SAMPLES`\n2. Reduce `MAX_NEW_TOKENS`\n3. Ensure no other processes are using GPU memory\n\n### ImportError for HQQ\nMake sure you have installed the HQQ package:\n```bash\npip install hqq\n```\n\n### Slow Performance on CPU\nThe benchmark is designed for GPU. CPU performance will be significantly slower but still functional.\n\n## Advanced Usage\n\n### Custom Cache Configurations\n\nYou can modify the cache configuration in the `benchmark_cache` function:\n\n```python\n# Example: Different quantization settings\nint4_results, int4_summaries = benchmark_cache(\n model,\n tokenizer,\n data,\n cache_type=\"INT4 (HQQ Quantized)\",\n cache_implementation=\"quantized\",\n cache_config={\n \"backend\": \"HQQ\",\n \"nbits\": 4, # 4-bit quantization\n \"axis_key\": 1,\n \"axis_value\": 1\n }\n)\n```\n\n### Different Models\n\nTo benchmark different models, change the `MODEL_NAME`:\n\n```python\nMODEL_NAME = \"meta-llama/Llama-3.2-3B\" # Larger model\n# or\nMODEL_NAME = \"microsoft/phi-2\" # Different architecture\n```\n\nMake sure the model supports the quantized cache implementation.\n\n## Technical Details\n\n### HQQ (Half-Quadratic Quantization)\n- Quantizes KV cache tensors to lower precision (INT8 or INT4)\n- Reduces memory footprint and bandwidth requirements\n- Minimal impact on generation quality\n- Supported by Hugging Face Transformers\n\n### KV Cache\n- Stores key and value tensors from attention layers\n- Avoids recomputing previous token representations\n- Memory usage grows with sequence length\n- Quantization reduces per-token memory cost\n\n## References\n\n- [Hugging Face Transformers Documentation](https://huggingface.co/docs/transformers)\n- [HQQ: Half-Quadratic Quantization](https://github.com/mobiusml/hqq)\n- [CNN/DailyMail Dataset](https://huggingface.co/datasets/cnn_dailymail)\n- [ROUGE Metric](https://huggingface.co/spaces/evaluate-metric/rouge)\n\n## License\n\nThis script is provided as-is for benchmarking purposes. Please refer to the licenses of the individual components (model, dataset, libraries) for their usage terms.\n'''\n\nwith open('/tmp/KV_CACHE_README.md', 'w') as f:\n f.write(readme_content)\n\nprint(\"βœ“ Created /tmp/kv_cache_benchmark.py\")\nprint(\"βœ“ Created /tmp/KV_CACHE_README.md\")\nprint(\"\\\\nFiles are ready to use!\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "u8gky39qvr", + "source": "# Create CSV with apache-2.0 models under 500MB\nimport csv\n\n# All apache-2.0 licensed models from search, excluding BART (>500MB)\nmodels = [\n [\"nateraw/codecarbon-text-classification\", 1000],\n [\"nickmuchi/distilroberta-finetuned-financial-text-classification\", 48900],\n [\"jxuhf/Fine-tuning-text-classification-model-Habana-Gaudi\", 736],\n [\"krupper/text-complexity-classification\", 8200],\n [\"DenilsenAxel/nlp-text-classification\", 999],\n [\"lucasgbezerra/classification_text_model\", 340],\n [\"fatenghali/text_classification_model\", 506],\n [\"maayansharon/climate_text_classification_mini_model\", 343],\n [\"Aaryan562/distilbert-base-uncased-fine-tuned-text-classification\", 283],\n [\"dmjimenezbravo/electra-small-discriminator-text-classification-en-finetuned-amazon_reviews_multi-en\", 312],\n [\"ratish/bert-textClassification_v1.1\", 278],\n [\"ratish/bert-textClassification_v1.4\", 313],\n [\"Amite5h/TextClassificationmulticlass\", 298],\n [\"Sleoruiz/roberta-base-fine-tuned-text-classification-pesos-fixed\", 7],\n [\"Sleoruiz/roberta-base-fine-tuned-text-classification-pesos-fixed-2\", 9],\n [\"Sleoruiz/roberta-bne-fine-tuned-text-classification-SL-data-augmentation-dss\", 314],\n [\"Sleoruiz/roberta-bne-fine-tuned-text-classification-SL-dss\", 454],\n [\"Cynthiaiii4/Text_classification_HW\", 8],\n [\"tKah/Textclassification-Bert\", 245],\n [\"Sleoruiz/roberta-bne-fine-tuned-text-classification-SL-1200samples\", 287],\n [\"Leslie123/stackoverflow-text-classification\", 225],\n [\"Cynthiaiii4/Text_classification_bert-base-uncased\", 6],\n [\"Cynthiaiii4/Text_classification_model_blu\", 7],\n [\"Cynthiaiii4/Text_classification_model_bbc\", 6],\n [\"sfurkan/LexBERT-textclassification-turkish-uncased\", 8],\n]\n\n# Write to CSV\noutput_path = \"/Users/akseljoonas/Documents/hf-agent/text_classification_models.csv\"\nwith open(output_path, \"w\", newline=\"\", encoding=\"utf-8\") as f:\n writer = csv.writer(f)\n writer.writerow([\"model_id\", \"downloads\"])\n writer.writerows(models)\n\nprint(f\"βœ“ CSV file created: {output_path}\")\nprint(f\"βœ“ Total models: {len(models)}\")\nprint(f\"βœ“ Excluded: IT-community/BART_cnn_news_text_classification (>500MB)\")\n\n# Show first few rows\nprint(\"\\nFirst 5 rows:\")\nfor i, (model_id, downloads) in enumerate(models[:5], 1):\n print(f\" {i}. {model_id}: {downloads:,} downloads\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "njt45dhwbfb", + "source": "# Execute the cell above to create the CSV\n# Then verify it was created\nimport os\ncsv_path = \"/Users/akseljoonas/Documents/hf-agent/text_classification_models.csv\"\nif os.path.exists(csv_path):\n print(f\"βœ“ CSV file exists at: {csv_path}\")\n print(f\"βœ“ File size: {os.path.getsize(csv_path)} bytes\")\n \n # Read and display first few lines\n with open(csv_path, \"r\") as f:\n lines = f.readlines()\n print(f\"βœ“ Total lines: {len(lines)}\")\n print(\"\\nFirst 10 lines:\")\n for line in lines[:10]:\n print(f\" {line.rstrip()}\")\nelse:\n print(f\"βœ— CSV file not found at: {csv_path}\")\n print(\"Run the cell above first to create it.\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "704sq89c26n", + "source": "# Direct CSV creation without dependencies\ncsv_content = \"\"\"model_id,downloads\nnateraw/codecarbon-text-classification,1000\nnickmuchi/distilroberta-finetuned-financial-text-classification,48900\njxuhf/Fine-tuning-text-classification-model-Habana-Gaudi,736\nkrupper/text-complexity-classification,8200\nDenilsenAxel/nlp-text-classification,999\nlucasgbezerra/classification_text_model,340\nfatenghali/text_classification_model,506\nmaayansharon/climate_text_classification_mini_model,343\nAaryan562/distilbert-base-uncased-fine-tuned-text-classification,283\ndmjimenezbravo/electra-small-discriminator-text-classification-en-finetuned-amazon_reviews_multi-en,312\nratish/bert-textClassification_v1.1,278\nratish/bert-textClassification_v1.4,313\nAmite5h/TextClassificationmulticlass,298\nSleoruiz/roberta-base-fine-tuned-text-classification-pesos-fixed,7\nSleoruiz/roberta-base-fine-tuned-text-classification-pesos-fixed-2,9\nSleoruiz/roberta-bne-fine-tuned-text-classification-SL-data-augmentation-dss,314\nSleoruiz/roberta-bne-fine-tuned-text-classification-SL-dss,454\nCynthiaiii4/Text_classification_HW,8\ntKah/Textclassification-Bert,245\nSleoruiz/roberta-bne-fine-tuned-text-classification-SL-1200samples,287\nLeslie123/stackoverflow-text-classification,225\nCynthiaiii4/Text_classification_bert-base-uncased,6\nCynthiaiii4/Text_classification_model_blu,7\nCynthiaiii4/Text_classification_model_bbc,6\nsfurkan/LexBERT-textclassification-turkish-uncased,8\"\"\"\n\n# Write directly\nwith open(\"/Users/akseljoonas/Documents/hf-agent/text_classification_models.csv\", \"w\") as f:\n f.write(csv_content)\n\nprint(\"βœ“ CSV created successfully!\")\nprint(f\"βœ“ 25 models (apache-2.0 license, <500MB)\")\nprint(\"βœ“ 1 model excluded: IT-community/BART_cnn_news_text_classification (>500MB)\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "155tkweh88r", + "source": "# Create train_dpo.py file\nscript_content = '''\"\"\"DPO Training Script - Complete Implementation\"\"\"\nimport torch\nfrom datasets import load_dataset\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\nfrom trl import DPOTrainer, DPOConfig\n\nprint(\"=\"*80)\nprint(\"DPO Training - End-to-End Validation\")\nprint(\"=\"*80)\n\n# Configuration\nMODEL_NAME = \"Qwen/Qwen2-0.5B-Instruct\"\nDATASET_NAME = \"trl-lib/ultrafeedback_binarized\"\nOUTPUT_DIR = \"./dpo_output\"\nMAX_STEPS = 10\nBATCH_SIZE = 2\n\nprint(f\"\\\\n[CONFIG] Model: {MODEL_NAME}\")\nprint(f\"[CONFIG] Dataset: {DATASET_NAME}\")\nprint(f\"[CONFIG] Max steps: {MAX_STEPS}\")\nprint(f\"[CONFIG] Batch size: {BATCH_SIZE}\")\n\n# Step 1: Load tokenizer\nprint(\"\\\\n[1/6] Loading tokenizer...\")\ntokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\nif tokenizer.pad_token is None:\n tokenizer.pad_token = tokenizer.eos_token\nprint(f\"βœ“ Tokenizer loaded\")\n\n# Step 2: Load dataset\nprint(\"\\\\n[2/6] Loading dataset...\")\ndataset = load_dataset(DATASET_NAME, split=\"train[:100]\")\nprint(f\"βœ“ Dataset loaded: {len(dataset)} samples\")\n\n# Step 3: Load model\nprint(\"\\\\n[3/6] Loading model...\")\nmodel = AutoModelForCausalLM.from_pretrained(\n MODEL_NAME,\n torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,\n device_map=\"auto\",\n)\nprint(f\"βœ“ Model loaded: {model.num_parameters()/1e6:.1f}M parameters\")\n\n# Step 4: Configure training\nprint(\"\\\\n[4/6] Configuring DPO training...\")\ntraining_args = DPOConfig(\n output_dir=OUTPUT_DIR,\n max_steps=MAX_STEPS,\n per_device_train_batch_size=BATCH_SIZE,\n learning_rate=5e-7,\n logging_steps=2,\n save_steps=10,\n beta=0.1,\n fp16=torch.cuda.is_available(),\n remove_unused_columns=False,\n report_to=\"none\",\n)\nprint(\"βœ“ Configuration created\")\n\n# Step 5: Train\nprint(\"\\\\n[5/6] Starting DPO training...\")\nprint(\"-\"*80)\ntrainer = DPOTrainer(\n model=model,\n args=training_args,\n train_dataset=dataset,\n tokenizer=tokenizer,\n)\ntrain_result = trainer.train()\nprint(\"-\"*80)\nprint(f\"βœ“ Training completed! Loss: {train_result.training_loss:.4f}\")\n\n# Step 6: Save\nprint(\"\\\\n[6/6] Saving model...\")\ntrainer.save_model(OUTPUT_DIR)\nprint(f\"βœ“ Model saved to {OUTPUT_DIR}\")\n\nprint(\"\\\\n\" + \"=\"*80)\nprint(\"DPO TRAINING COMPLETED SUCCESSFULLY!\")\nprint(\"=\"*80)\nprint(f\"\\\\nOutput: {OUTPUT_DIR}\")\nprint(f\"Steps: {train_result.global_step}\")\nprint(f\"Final loss: {train_result.training_loss:.4f}\")\n'''\n\nimport os\nos.chdir('/Users/akseljoonas/Documents/hf-agent')\nwith open('train_dpo.py', 'w') as f:\n f.write(script_content)\n \nprint(\"βœ“ train_dpo.py created successfully!\")\nprint(f\"Location: {os.path.abspath('train_dpo.py')}\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "wdnesxsmfq", + "source": "# Check if required packages are installed\nimport subprocess\nimport sys\n\npackages = ['torch', 'transformers', 'datasets', 'trl']\n\nprint(\"Checking installed packages...\")\nfor package in packages:\n try:\n __import__(package)\n version = subprocess.run([sys.executable, '-m', 'pip', 'show', package], \n capture_output=True, text=True, check=True)\n version_line = [line for line in version.stdout.split('\\n') if line.startswith('Version:')]\n if version_line:\n print(f\"βœ“ {package}: {version_line[0].split(':')[1].strip()}\")\n else:\n print(f\"βœ“ {package}: installed\")\n except ImportError:\n print(f\"βœ— {package}: NOT INSTALLED\")\n print(f\" Installing {package}...\")\n subprocess.run([sys.executable, '-m', 'pip', 'install', package], check=True)", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "6lxro03b5k", + "source": "# Run the train_dpo.py script\nimport subprocess\nimport os\n\nos.chdir('/Users/akseljoonas/Documents/hf-agent')\n\nprint(\"Starting DPO training script...\")\nprint(\"=\"*80)\n\n# Run the script and capture output in real-time\nprocess = subprocess.Popen(\n ['python', 'train_dpo.py'],\n stdout=subprocess.PIPE,\n stderr=subprocess.STDOUT,\n text=True,\n bufsize=1\n)\n\n# Print output in real-time\nfor line in process.stdout:\n print(line, end='')\n\n# Wait for completion\nreturn_code = process.wait()\n\nprint(\"\\n\" + \"=\"*80)\nif return_code == 0:\n print(\"βœ“ Script completed successfully!\")\nelse:\n print(f\"βœ— Script failed with return code: {return_code}\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "kk03ij6wpx", + "source": "# Alternative: Run the training directly in the notebook for immediate feedback\nimport os\nos.chdir('/Users/akseljoonas/Documents/hf-agent')\n\n# Execute the script\nexec(open('train_dpo.py').read())", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "58ilnz6pedu", + "source": "# Write the file directly\nimport os\nos.chdir('/Users/akseljoonas/Documents/hf-agent')\n\nwith open('train_dpo.py', 'w', encoding='utf-8') as f:\n f.write('\"\"\"DPO Training Script - Complete Implementation\"\"\"\\n')\n f.write('import torch\\n')\n f.write('from datasets import load_dataset\\n')\n f.write('from transformers import AutoModelForCausalLM, AutoTokenizer\\n')\n f.write('from trl import DPOTrainer, DPOConfig\\n\\n')\n f.write('print(\"=\"*80)\\n')\n f.write('print(\"DPO Training - End-to-End Validation\")\\n')\n f.write('print(\"=\"*80)\\n\\n')\n f.write('# Configuration\\n')\n f.write('MODEL_NAME = \"Qwen/Qwen2-0.5B-Instruct\"\\n')\n f.write('DATASET_NAME = \"trl-lib/ultrafeedback_binarized\"\\n')\n f.write('OUTPUT_DIR = \"./dpo_output\"\\n')\n f.write('MAX_STEPS = 10\\n')\n f.write('BATCH_SIZE = 2\\n\\n')\n f.write('print(f\"\\\\n[CONFIG] Model: {MODEL_NAME}\")\\n')\n f.write('print(f\"[CONFIG] Dataset: {DATASET_NAME}\")\\n')\n f.write('print(f\"[CONFIG] Max steps: {MAX_STEPS}\")\\n')\n f.write('print(f\"[CONFIG] Batch size: {BATCH_SIZE}\")\\n\\n')\n f.write('# Step 1: Load tokenizer\\n')\n f.write('print(\"\\\\n[1/6] Loading tokenizer...\")\\n')\n f.write('tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\\n')\n f.write('if tokenizer.pad_token is None:\\n')\n f.write(' tokenizer.pad_token = tokenizer.eos_token\\n')\n f.write('print(f\"βœ“ Tokenizer loaded\")\\n\\n')\n f.write('# Step 2: Load dataset\\n')\n f.write('print(\"\\\\n[2/6] Loading dataset...\")\\n')\n f.write('dataset = load_dataset(DATASET_NAME, split=\"train[:100]\")\\n')\n f.write('print(f\"βœ“ Dataset loaded: {len(dataset)} samples\")\\n\\n')\n f.write('# Step 3: Load model\\n')\n f.write('print(\"\\\\n[3/6] Loading model...\")\\n')\n f.write('model = AutoModelForCausalLM.from_pretrained(\\n')\n f.write(' MODEL_NAME,\\n')\n f.write(' torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,\\n')\n f.write(' device_map=\"auto\",\\n')\n f.write(')\\n')\n f.write('print(f\"βœ“ Model loaded: {model.num_parameters()/1e6:.1f}M parameters\")\\n\\n')\n f.write('# Step 4: Configure training\\n')\n f.write('print(\"\\\\n[4/6] Configuring DPO training...\")\\n')\n f.write('training_args = DPOConfig(\\n')\n f.write(' output_dir=OUTPUT_DIR,\\n')\n f.write(' max_steps=MAX_STEPS,\\n')\n f.write(' per_device_train_batch_size=BATCH_SIZE,\\n')\n f.write(' learning_rate=5e-7,\\n')\n f.write(' logging_steps=2,\\n')\n f.write(' save_steps=10,\\n')\n f.write(' beta=0.1,\\n')\n f.write(' fp16=torch.cuda.is_available(),\\n')\n f.write(' remove_unused_columns=False,\\n')\n f.write(' report_to=\"none\",\\n')\n f.write(')\\n')\n f.write('print(\"βœ“ Configuration created\")\\n\\n')\n f.write('# Step 5: Train\\n')\n f.write('print(\"\\\\n[5/6] Starting DPO training...\")\\n')\n f.write('print(\"-\"*80)\\n')\n f.write('trainer = DPOTrainer(\\n')\n f.write(' model=model,\\n')\n f.write(' args=training_args,\\n')\n f.write(' train_dataset=dataset,\\n')\n f.write(' tokenizer=tokenizer,\\n')\n f.write(')\\n')\n f.write('train_result = trainer.train()\\n')\n f.write('print(\"-\"*80)\\n')\n f.write('print(f\"βœ“ Training completed! Loss: {train_result.training_loss:.4f}\")\\n\\n')\n f.write('# Step 6: Save\\n')\n f.write('print(\"\\\\n[6/6] Saving model...\")\\n')\n f.write('trainer.save_model(OUTPUT_DIR)\\n')\n f.write('print(f\"βœ“ Model saved to {OUTPUT_DIR}\")\\n\\n')\n f.write('print(\"\\\\n\" + \"=\"*80)\\n')\n f.write('print(\"DPO TRAINING COMPLETED SUCCESSFULLY!\")\\n')\n f.write('print(\"=\"*80)\\n')\n f.write('print(f\"\\\\nOutput: {OUTPUT_DIR}\")\\n')\n f.write('print(f\"Steps: {train_result.global_step}\")\\n')\n f.write('print(f\"Final loss: {train_result.training_loss:.4f}\")\\n')\n\nprint(\"βœ“ train_dpo.py created!\")\nprint(f\"File location: {os.path.abspath('train_dpo.py')}\")\n\n# Verify file exists\nif os.path.exists('train_dpo.py'):\n file_size = os.path.getsize('train_dpo.py')\n print(f\"File size: {file_size} bytes\")\nelse:\n print(\"ERROR: File was not created!\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "7qvebak22u2", + "source": "## Next Steps: Run the Training Script\n\nThe `train_dpo.py` file has been created. To run it:\n\n1. **From Terminal/Command Line:**\n ```bash\n cd /Users/akseljoonas/Documents/hf-agent\n python train_dpo.py\n ```\n\n2. **Or run directly in this notebook** by executing the next cell below.", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19f3dd6b", + "metadata": {}, + "outputs": [], + "source": [ + "# Seed examples for task bootstrapping\n", + "tasks_with_difficulty = {\n", + " # lewis\n", + " \"Evaluate models {M_i} on benchmarks {B_i}\": \"Easy\",\n", + " \"Train models {M_i} on datasets {D_i} with benchmarks {B_i}\": \"Medium\",\n", + " \"Run an ablation for hyperparameter P for model M on dataset D\": \"Hard\",\n", + " \"Generate completions with model M on dataset D using engine E\": \"Medium\",\n", + " \"Merge models {M_i} using linear averaging to find the best result on benchmarks {B_i}\": \"Hard\",\n", + " \"Given datasets {D_i}, ablate the best SFT mixture for model M across benchmarks {B_i}\": \"Very hard\",\n", + " \"Decontaminate dataset D against benchmarks {B_i}\": \"Hard\",\n", + " \"Benchmark RL framework F for best throughput on G GPUs\": \"Very hard\",\n", + " \"Implement post-training algorithm A from paper P in framework F. Validate it runs end-to-end\": \"Very hard\",\n", + " \"Implement benchmark B in framework F. Validate it reproduces some published results\": \"Very hard\",\n", + " \"Format dataset D for compatibility with framework F on task T\": \"Easy\",\n", + "\n", + " # abubakar\n", + " \"Remove the background from this image: [image path]\": \"Easy\",\n", + " \"Transcribe all of the audio files in this directory\": \"Easy\",\n", + " \"Transcribe all of the audio files in this directory, choose the model that'll be cheapest and also relatively accurate\": \"Medium (judgment call or interaction needed to figure out what accuracy levels are acceptable)\",\n", + " \"Remove the background music from this audio file\": \"Medium (needs to find Gradio Space and call its API0\",\n", + " \"Change this video track to be from English to Spanish\": \"Medium (needs to link several models together)\",\n", + " \"Translate this flyer from English to Spanish, keeping the layout and images the same\": \"Medium (needs to link several models together)\",\n", + "\n", + " # leandro\n", + " \"What's the best model for X?\": \"Easy\",\n", + " \"What datasets are available for X? (X={domain x task x modality})\": \"Easy\",\n", + " \"Is there a space to do Y?\": \"Easy\",\n", + " \"I have this script and this error - what's the issue?\": \"Medium\",\n", + " \"This space is broken, how can i fix it?\": \"Medium\",\n", + " \"I built a space but it is super slow. What can I do?\": \"Medium\",\n", + " \"How can I run modal X locally?\": \"Medium\",\n", + " \"I want to build a space with model Y to do X?\": \"Hard\",\n", + " \"How can I serve a model with multiple LoRAs?\": \"Hard\",\n", + "\n", + " # claude\n", + " \"What's the best model for sentiment analysis on financial text?\": \"Easy\",\n", + " \"Are there any medical image segmentation datasets on HuggingFace for CT scans?\": \"Easy\",\n", + " \"Which text classification models support 4-bit quantization?\": \"Medium\",\n", + " \"Are there inference endpoints available for Whisper large-v3?\": \"Easy\",\n", + " \"What's the license for the SA-Med2D-20M dataset?\": \"Easy\",\n", + " \"Which vision models fit in 8GB VRAM for image segmentation?\": \"Medium\",\n", + " \"What datasets are available for 3D medical image segmentation?\": \"Medium\",\n", + " \"Is there a space to do text-to-speech with emotion control?\": \"Medium\",\n", + " \"I'm getting \\\"CUDA out of memory\\\" when loading Llama-2-7b even though nvidia-smi shows I have 6GB free - what's the issue?\": \"Medium\",\n", + " \"My Gradio space shows \\\"Connection errored out\\\" after working fine yesterday, no code changes - how can I fix it?\": \"Medium\",\n", + " \"I built a Gradio space for Stable Diffusion but inference takes 5+ minutes on a 4090 - what can I do?\": \"Medium\",\n", + " \"My Whisper model outputs different transcriptions after quantization to int8 - why?\": \"Medium\",\n", + " \"Getting \\\"RuntimeError: CUDA error: out of memory. Tried to allocate 70.00 MiB\\\" but only 2.87 GiB is allocated - what's happening?\": \"Medium\",\n", + " \"My HuggingFace space build fails with \\\"failed to create containerd task\\\" - how to fix?\": \"Medium\",\n", + " \"DistilBERT model gives \\\"you should probably train your model\\\" warning even though it's a pretrained model from the Hub\": \"Easy\",\n", + " \"Space was working fine but now receiving build errors - receiving this error even with a new space\": \"Medium\",\n", + " \"Inference is correct locally but wrong on deployed space\": \"Medium\",\n", + " \"Getting CUDA OOM despite having enough memory according to nvidia-smi\": \"Medium\",\n", + " \"How can I run Mistral-7B-v0.1 locally with multiple LoRA adapters?\": \"Hard\",\n", + " \"How can I serve Llama-2-7b with vLLM and dynamically load multiple LoRA adapters?\": \"Hard\",\n", + " \"How do I batch inference requests in my Gradio space for better throughput?\": \"Medium\",\n", + " \"Can I run Whisper large-v3 with faster-whisper for 4x speedup?\": \"Medium\",\n", + " \"How to run Llama 2 on CPU after fine-tuning with LoRA?\": \"Medium\",\n", + " \"Best way to handle 50+ concurrent requests in a Gradio space without OOM?\": \"Hard\",\n", + " \"How do I add custom stopping criteria for text generation with Transformers?\": \"Hard\",\n", + " \"Can I merge multiple LoRA adapters before inference to reduce latency?\": \"Hard\",\n", + " \"How can I optimize my LLM inference with one base LLM and multiple LoRA adapters?\": \"Hard\",\n", + "}\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c7014bef", + "metadata": {}, + "outputs": [], + "source": [ + "len(tasks_with_difficulty)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3a8bd7ed", + "metadata": {}, + "outputs": [], + "source": [ + "import litellm\n", + "import json\n", + "from pydantic import BaseModel\n", + "from enum import Enum\n", + "\n", + "\n", + "class Difficulty(str, Enum):\n", + " EASY = \"Easy\"\n", + " MEDIUM = \"Medium\"\n", + " HARD = \"Hard\"\n", + " VERY_HARD = \"Very hard\"\n", + "\n", + "\n", + "class Task(BaseModel):\n", + " description: str\n", + " difficulty: Difficulty\n", + "\n", + "\n", + "class GeneratedTasks(BaseModel):\n", + " tasks: list[Task]\n", + "\n", + "\n", + "def build_prompt(tasks_dict: dict[str, str]) -> str:\n", + " task_descriptions = \"\".join(\n", + " [f'- \"{task}\" [{difficulty}]\\n' for task, difficulty in tasks_dict.items()]\n", + " )\n", + "\n", + " return f\"\"\"Given the following examples of tasks (with their estimated difficulty levels in brackets):\n", + "\n", + "{task_descriptions}\n", + "\n", + "Generate exactly 10 new unique tasks with their difficulty levels (Easy, Medium, Hard, or Very hard).\n", + "The new tasks should be bootstrapped by analogy or creative mutation of the provided ones, but not be direct copies.\n", + "Vary the domains, instructions, and scenario details. Write crisp, concrete task phrasing. Preserve variety in both tasks and difficulties.\n", + "Do not repeat any of the input tasks verbatim. Create plausible, meaningful tasks relevant to LLM training, evaluation, dataprocessing, issue handling, tooling, etc.\n", + "\"\"\"\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "85ef3dcb", + "metadata": {}, + "outputs": [], + "source": [ + "model_name = \"gpt-5\"\n", + "\n", + "# Number of iterations to generate tasks (10 tasks per iteration)\n", + "num_iterations = 20\n", + "\n", + "# Copy the seed tasks to avoid modifying the original\n", + "all_tasks = tasks_with_difficulty.copy()\n", + "\n", + "for i in range(num_iterations):\n", + " prompt = build_prompt(all_tasks)\n", + "\n", + " # Query LLM using litellm with structured output\n", + " response = litellm.completion(\n", + " model=model_name,\n", + " messages=[\n", + " {\n", + " \"role\": \"system\",\n", + " \"content\": \"You are an expert at generating diverse ML/AI task instructions using products from HuggingFace and can enumerate them with proper difficulty.\",\n", + " },\n", + " {\"role\": \"user\", \"content\": prompt},\n", + " ],\n", + " response_format=GeneratedTasks,\n", + " )\n", + "\n", + " # Parse the structured output\n", + " generated = GeneratedTasks.model_validate_json(\n", + " response.choices[0].message.content\n", + " )\n", + "\n", + " # Add new tasks to the dictionary\n", + " new_count = 0\n", + " for task in generated.tasks:\n", + " if task.description not in all_tasks:\n", + " all_tasks[task.description] = task.difficulty.value\n", + " new_count += 1\n", + "\n", + " print(f\"Iteration {i + 1}/{num_iterations}: Added {new_count} new tasks. Total: {len(all_tasks)}\")\n", + "\n", + "# Save to disk\n", + "with open(\"generated_tasks_with_difficulty.json\", \"w\") as f:\n", + " json.dump(all_tasks, f, indent=2)\n", + "\n", + "print(f\"\\nFinal task count: {len(all_tasks)}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9c0ad570", + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import Dataset\n", + "\n", + "# Convert dict to proper columns\n", + "questions = list(all_tasks.keys())\n", + "difficulties = list(all_tasks.values())\n", + "data = {\"question\": questions, \"difficulty\": difficulties}\n", + "\n", + "dataset = Dataset.from_dict(data)\n", + "print(f\"\\nDataset: {len(dataset)} rows\")\n", + "print(f\"Sample: {dataset[0]['question']} ({dataset[0]['difficulty']})\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "427a2186", + "metadata": {}, + "outputs": [], + "source": [ + "dataset.push_to_hub(\"akseljoonas/benchmark-tasks\", private=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "204b9760", + "metadata": {}, + "outputs": [], + "source": [ + "all_tasks = json.load(open(\"generated_tasks_with_difficulty.json\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "50e67652", + "metadata": {}, + "outputs": [], + "source": [ + "# Extract variables from each question using LLM\n", + "\n", + "class ExtractedVariables(BaseModel):\n", + " variables: list[str] # List of variable names/placeholders found in the question\n", + "\n", + "\n", + "def extract_variables_prompt(question: str) -> str:\n", + " return f\"\"\"Analyze this task description and list any variables or placeholders that would need to be filled in with specific values. This is a AI/ML/LLM task, so the variables are typically model names, dataset names, hyperparameter names, etc.\n", + "\n", + "Task: \"{question}\"\n", + "\n", + "Variables are typically indicated by:\n", + "- Curly braces like {{M_i}}, {{D_i}}, {{B_i}}\n", + "- Single letters representing placeholders like \"model M\", \"dataset D\", \"hyperparameter P\"\n", + "- Bracketed placeholders like [image path]\n", + "- Generic references like \"X\", \"Y\" that stand for specific values\n", + "\n", + "Examples of tasks with variables:\n", + "\n", + " \"Evaluate models {{M_i}} on benchmarks {{B_i}}\" -> variables: [\"M_i\", \"B_i\"]\n", + " \"Train models {{M_i}} on datasets {{D_i}} with benchmarks {{B_i}}\" -> variables: [\"M_i\", \"D_i\", \"B_i\"]\n", + " \"Run an ablation for hyperparameter P for model M on dataset D\" -> variables: [\"P\", \"M\", \"D\"]\n", + " \"Generate completions with model M on dataset D using engine E\" -> variables: [\"M\", \"D\", \"E\"]\n", + " \"Merge models {{M_i}} using linear averaging to find the best result on benchmarks {{B_i}}\" -> variables: [\"M_i\", \"B_i\"]\n", + " \"Given datasets {{D_i}}, ablate the best SFT mixture for model M across benchmarks {{B_i}}\" -> variables: [\"D_i\", \"M\", \"B_i\"]\n", + " \"Decontaminate dataset D against benchmarks {{B_i}}\" -> variables: [\"D\", \"B_i\"]\n", + " \"Benchmark RL framework F for best throughput on G GPUs\" -> variables: [\"F\", \"G\"]\n", + " \"Implement post-training algorithm A from paper P in framework F. Validate it runs end-to-end\" -> variables: [\"A\", \"P\", \"F\"]\n", + " \"Implement benchmark B in framework F. Validate it reproduces some published results\" -> variables: [\"B\", \"F\"]\n", + " \"Format dataset D for compatibility with framework F on task T\" -> variables: [\"D\", \"F\", \"T\"]\n", + " \"Remove the background from this image: [image path]\" -> variables: [\"[image path]\"]\n", + " \"Are there any medical image segmentation datasets on HuggingFace for CT scans?\" -> variables: []\n", + " \"Build a sharded FAISS IVF-PQ index for 100M embeddings stored on S3; integrate with HF datasets streaming and report recall@10 and QPS\" -> variables: []\n", + "\n", + "\n", + "Return an empty list if the question is fully concrete with no variables.\n", + "Only return the variable names/symbols, not their descriptions.\"\"\"\n", + "\n", + "\n", + "# Run extraction for each question in parallel\n", + "from concurrent.futures import ThreadPoolExecutor, as_completed\n", + "\n", + "variable_model = \"gpt-5-mini\"\n", + "\n", + "\n", + "def extract_variables_for_task(question: str, difficulty: str) -> dict:\n", + " \"\"\"Extract variables for a single task and return the record.\"\"\"\n", + " response = litellm.completion(\n", + " model=variable_model,\n", + " messages=[\n", + " {\n", + " \"role\": \"system\",\n", + " \"content\": \"You are an expert at identifying placeholder variables in task descriptions.\",\n", + " },\n", + " {\"role\": \"user\", \"content\": extract_variables_prompt(question)},\n", + " ],\n", + " response_format=ExtractedVariables,\n", + " )\n", + "\n", + " extracted = ExtractedVariables.model_validate_json(\n", + " response.choices[0].message.content\n", + " )\n", + "\n", + " return {\n", + " \"question\": question,\n", + " \"difficulty\": difficulty,\n", + " \"var_list\": extracted.variables,\n", + " }\n", + "\n", + "\n", + "# Run in parallel with 100 workers\n", + "tasks_with_metadata: list[dict] = []\n", + "all_variables: set[str] = set()\n", + "questions_with_vars: dict[str, list[str]] = {}\n", + "\n", + "with ThreadPoolExecutor(max_workers=100) as executor:\n", + " futures = {\n", + " executor.submit(extract_variables_for_task, q, d): q\n", + " for q, d in all_tasks.items()\n", + " }\n", + "\n", + " for future in as_completed(futures):\n", + " record = future.result()\n", + " tasks_with_metadata.append(record)\n", + "\n", + " if record[\"var_list\"]:\n", + " questions_with_vars[record[\"question\"]] = record[\"var_list\"]\n", + " all_variables.update(record[\"var_list\"])\n", + "\n", + " print(f\"Processed {len(tasks_with_metadata)} tasks\")\n", + "\n", + "# Save to JSONL\n", + "with open(\"tasks_with_variables.jsonl\", \"w\") as f:\n", + " for record in tasks_with_metadata:\n", + " f.write(json.dumps(record) + \"\\n\")\n", + "\n", + "print(f\"Saved {len(tasks_with_metadata)} tasks to tasks_with_variables.jsonl\")\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "548f1bf0", + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"Questions with variables: {len(questions_with_vars)} / {len(all_tasks)}\")\n", + "print(f\"\\nUnique variables found ({len(all_variables)}):\")\n", + "for var in sorted(all_variables):\n", + " print(f\" - {var}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "3cef6645", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded 250 tasks\n", + "Questions with variables: 111 / 250\n", + "\n", + "Unique variables found (29):\n", + " - A\n", + " - A_i\n", + " - B\n", + " - B_i\n", + " - C\n", + " - D\n", + " - D_i\n", + " - E\n", + " - F\n", + " - G\n", + " - M\n", + " - M0\n", + " - M_i\n", + " - N\n", + " - P\n", + " - R\n", + " - R_i\n", + " - S\n", + " - T\n", + " - T_i\n", + " - X\n", + " - Y\n", + " - [audio file]\n", + " - [directory]\n", + " - [image path]\n", + " - baseline\n", + " - domain\n", + " - modality\n", + " - task\n" + ] + } + ], + "source": [ + "# Load verified tasks and print all variables\n", + "with open(\"tasks_with_variables.jsonl\", \"r\") as f:\n", + " verified_tasks = [json.loads(line) for line in f]\n", + "\n", + "all_variables = set()\n", + "questions_with_vars = {}\n", + "\n", + "for task in verified_tasks:\n", + " if task[\"var_list\"]:\n", + " questions_with_vars[task[\"question\"]] = task[\"var_list\"]\n", + " all_variables.update(task[\"var_list\"])\n", + "\n", + "print(f\"Loaded {len(verified_tasks)} tasks\")\n", + "print(f\"Questions with variables: {len(questions_with_vars)} / {len(verified_tasks)}\")\n", + "print(f\"\\nUnique variables found ({len(all_variables)}):\")\n", + "for var in sorted(all_variables):\n", + " print(f\" - {var}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ca774044", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Filling variables: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 250/250 [21:21<00:00, 5.13s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Saved 250 tasks to filled_tasks.jsonl\n", + "Tasks that had variables filled: 111\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "import asyncio\n", + "import os\n", + "from claude_agent_sdk import (\n", + " query,\n", + " ClaudeAgentOptions,\n", + " AssistantMessage,\n", + " ResultMessage,\n", + " TextBlock,\n", + ")\n", + "\n", + "\n", + "def build_fill_prompt(task: dict) -> str:\n", + " vars_str = \", \".join(task[\"var_list\"])\n", + " return f\"\"\"You have access to HuggingFace tools via MCP. Use them to find real, concrete values to fill in the variables in this task.\n", + "\n", + "Task template: \"{task[\"question\"]}\"\n", + "Variables to fill: {vars_str}\n", + "\n", + "Search HuggingFace for real models, datasets, benchmarks, frameworks, etc. that would make this task concrete and executable.\n", + "Pick the most popular, well-known resources (models etc) when possible.\n", + "\n", + "Return ONLY the filled question in the end with variables replaced by concrete values. No JSON, no explanation, just the filled question.\n", + "\n", + "Example:\n", + "Task: \"Evaluate models {{M_i}} on benchmarks {{B_i}}\"\n", + "Variables: M_i, B_i\n", + "Response: Evaluate models Qwen/Qwen3-4B-Instruct-2507, mistralai/Devstral-Small-2-24B-Instruct-2512 on benchmarks hellaswag, google/frames-benchmark\n", + "\"\"\"\n", + "\n", + "\n", + "# Semaphore to limit concurrent processes\n", + "MAX_CONCURRENT = 5\n", + "semaphore = asyncio.Semaphore(MAX_CONCURRENT)\n", + "\n", + "\n", + "async def fill_task_variables(task: dict) -> dict:\n", + " \"\"\"Use Claude Agent SDK to fill in variables for a single task.\"\"\"\n", + " if not task[\"var_list\"]:\n", + " return task.copy()\n", + "\n", + " async with semaphore:\n", + " prompt = build_fill_prompt(task)\n", + " filled_question = None\n", + " all_messages = []\n", + "\n", + " async for message in query(\n", + " prompt=prompt,\n", + " options=ClaudeAgentOptions(\n", + " cwd=os.getcwd(),\n", + " permission_mode=\"bypassPermissions\",\n", + " disallowed_tools=[\n", + " \"Write\", \"Edit\", \"Bash\", \"Glob\", \"Grep\"\n", + " \n", + " ],\n", + " ),\n", + " ):\n", + " all_messages.append(message)\n", + "\n", + " # Extract text from assistant messages\n", + " if isinstance(message, AssistantMessage):\n", + " for block in message.content:\n", + " if isinstance(block, TextBlock):\n", + " filled_question = block.text\n", + " # Check for result messages\n", + " elif isinstance(message, ResultMessage):\n", + " if message.is_error:\n", + " print(\"\\n\" + \"=\" * 80)\n", + " print(f\"ERROR for task: {task['question']}\")\n", + " print(f\"Error subtype: {message.subtype}\")\n", + " print(\"\\nFull messages:\")\n", + " for msg in all_messages:\n", + " print(f\" {msg}\")\n", + " print(\"=\" * 80)\n", + " raise RuntimeError(f\"Agent error: {message.subtype}\")\n", + " elif message.result:\n", + " filled_question = message.result\n", + "\n", + " # Use filled question or fall back to original\n", + " if filled_question:\n", + " filled_question = filled_question.strip()\n", + " else:\n", + " filled_question = task[\"question\"]\n", + "\n", + " return {\n", + " \"question\": filled_question,\n", + " \"difficulty\": task[\"difficulty\"],\n", + " \"var_list\": task[\"var_list\"],\n", + " }\n", + "\n", + "\n", + "# Run all tasks in parallel with tqdm progress\n", + "from tqdm.asyncio import tqdm_asyncio\n", + "\n", + "\n", + "async def fill_all_tasks_parallel(tasks: list[dict]) -> list[dict]:\n", + " \"\"\"Fill all tasks with limited concurrency and progress bar.\"\"\"\n", + " coros = [fill_task_variables(t) for t in tasks]\n", + " return await tqdm_asyncio.gather(*coros, desc=\"Filling variables\")\n", + "\n", + "\n", + "# Process all tasks (with and without variables)\n", + "filled_tasks = await fill_all_tasks_parallel(verified_tasks)\n", + "\n", + "# Save to JSONL (same structure: question, difficulty, var_list)\n", + "with open(\"filled_tasks.jsonl\", \"w\") as f:\n", + " for task in filled_tasks:\n", + " f.write(json.dumps(task) + \"\\n\")\n", + "\n", + "tasks_with_vars_count = sum(1 for t in verified_tasks if t[\"var_list\"])\n", + "print(f\"Saved {len(filled_tasks)} tasks to filled_tasks.jsonl\")\n", + "print(f\"Tasks that had variables filled: {tasks_with_vars_count}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "44c4e671", + "metadata": {}, + "outputs": [], + "source": "from pathlib import Path\n\nfuse_lora_content = r'''#!/usr/bin/env python3\n\"\"\"\nLoRA Fusion and Verification Script\n\nThis script:\n1. Loads a base model (Llama-2-7b-hf) and LoRA adapter (alpaca-lora-7b)\n2. Merges/fuses the LoRA weights into the base model\n3. Exports the fused model as safetensors format\n4. Verifies logits parity between on-the-fly LoRA and fused model\n5. Reports detailed metrics (MSE, max absolute difference, relative error)\n\"\"\"\n\nimport os\nimport torch\nimport numpy as np\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\nfrom peft import PeftModel\nimport gc\n\n\ndef print_section(title):\n \"\"\"Print a formatted section header\"\"\"\n print(\"\\n\" + \"=\"*80)\n print(f\" {title}\")\n print(\"=\"*80 + \"\\n\")\n\n\ndef free_memory():\n \"\"\"Free up GPU memory\"\"\"\n gc.collect()\n torch.cuda.empty_cache()\n\n\ndef load_models(base_model_name, lora_adapter_name):\n \"\"\"\n Load base model and LoRA adapter model\n \n Args:\n base_model_name: HuggingFace model ID for base model\n lora_adapter_name: HuggingFace model ID for LoRA adapter\n \n Returns:\n tuple: (lora_model, tokenizer)\n \"\"\"\n print_section(\"Loading Base Model and LoRA Adapter\")\n \n print(f\"Loading base model: {base_model_name}\")\n print(\"Using torch.float16 for memory efficiency...\")\n \n base_model = AutoModelForCausalLM.from_pretrained(\n base_model_name,\n torch_dtype=torch.float16,\n device_map=\"auto\",\n trust_remote_code=True\n )\n \n print(f\"Base model loaded successfully\")\n print(f\" - Model type: {type(base_model).__name__}\")\n print(f\" - Device map: {base_model.hf_device_map}\")\n \n print(f\"\\nLoading LoRA adapter: {lora_adapter_name}\")\n \n lora_model = PeftModel.from_pretrained(\n base_model,\n lora_adapter_name,\n torch_dtype=torch.float16,\n )\n \n print(f\"LoRA adapter loaded successfully\")\n print(f\" - Adapter type: {type(lora_model).__name__}\")\n \n print(f\"\\nLoading tokenizer from: {base_model_name}\")\n tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)\n \n # Set pad token if not present\n if tokenizer.pad_token is None:\n tokenizer.pad_token = tokenizer.eos_token\n print(\" - Set pad_token to eos_token\")\n \n print(f\"Tokenizer loaded successfully\")\n \n return lora_model, tokenizer\n\n\ndef merge_and_export(lora_model, output_dir):\n \"\"\"\n Merge LoRA weights into base model and export as safetensors\n \n Args:\n lora_model: PEFT model with LoRA adapter\n output_dir: Directory to save the fused model\n \n Returns:\n merged_model: The fused model\n \"\"\"\n print_section(\"Merging LoRA Weights into Base Model\")\n \n print(\"Calling merge_and_unload()...\")\n merged_model = lora_model.merge_and_unload()\n \n print(\"LoRA weights successfully merged into base model\")\n print(f\" - Merged model type: {type(merged_model).__name__}\")\n \n print(f\"\\nExporting fused model to: {output_dir}\")\n print(\"Format: safetensors (safe_serialization=True)\")\n \n # Create output directory if it doesn't exist\n os.makedirs(output_dir, exist_ok=True)\n \n # Save the merged model\n merged_model.save_pretrained(\n output_dir,\n safe_serialization=True,\n max_shard_size=\"5GB\"\n )\n \n print(f\"Model successfully saved to {output_dir}\")\n \n # Also save the tokenizer\n tokenizer = lora_model.tokenizer if hasattr(lora_model, 'tokenizer') else None\n if tokenizer:\n tokenizer.save_pretrained(output_dir)\n print(f\"Tokenizer also saved to {output_dir}\")\n \n return merged_model\n\n\ndef generate_logits(model, tokenizer, prompt, max_length=50):\n \"\"\"\n Generate logits for a given prompt\n \n Args:\n model: The model to use for generation\n tokenizer: Tokenizer for encoding the prompt\n prompt: Text prompt\n max_length: Maximum sequence length\n \n Returns:\n torch.Tensor: Logits from the model\n \"\"\"\n # Tokenize input\n inputs = tokenizer(prompt, return_tensors=\"pt\", padding=True, truncation=True, max_length=max_length)\n \n # Move inputs to the same device as model\n device = next(model.parameters()).device\n inputs = {k: v.to(device) for k, v in inputs.items()}\n \n # Generate logits\n with torch.no_grad():\n outputs = model(**inputs)\n logits = outputs.logits\n \n return logits\n\n\ndef calculate_metrics(logits1, logits2):\n \"\"\"\n Calculate metrics between two sets of logits\n \n Args:\n logits1: First set of logits\n logits2: Second set of logits\n \n Returns:\n dict: Dictionary containing various metrics\n \"\"\"\n # Convert to numpy for easier computation\n logits1_np = logits1.cpu().float().numpy()\n logits2_np = logits2.cpu().float().numpy()\n \n # Calculate metrics\n mse = np.mean((logits1_np - logits2_np) ** 2)\n mae = np.mean(np.abs(logits1_np - logits2_np))\n max_abs_diff = np.max(np.abs(logits1_np - logits2_np))\n \n # Relative error (avoid division by zero)\n epsilon = 1e-8\n relative_error = np.mean(np.abs(logits1_np - logits2_np) / (np.abs(logits1_np) + epsilon))\n \n # Cosine similarity (flatten the tensors)\n flat1 = logits1_np.flatten()\n flat2 = logits2_np.flatten()\n cosine_sim = np.dot(flat1, flat2) / (np.linalg.norm(flat1) * np.linalg.norm(flat2))\n \n return {\n 'mse': mse,\n 'mae': mae,\n 'max_abs_diff': max_abs_diff,\n 'relative_error': relative_error,\n 'cosine_similarity': cosine_sim\n }\n\n\ndef verify_logits_parity(lora_model, fused_model, tokenizer, test_prompts):\n \"\"\"\n Verify that logits from LoRA model match fused model\n \n Args:\n lora_model: Model with LoRA adapter applied on-the-fly\n fused_model: Model with merged LoRA weights\n tokenizer: Tokenizer for encoding prompts\n test_prompts: List of test prompts\n \n Returns:\n bool: True if all tests pass (MSE < 1e-5)\n \"\"\"\n print_section(\"Verifying Logits Parity\")\n \n all_passed = True\n results = []\n \n for i, prompt in enumerate(test_prompts, 1):\n print(f\"\\nTest {i}/{len(test_prompts)}\")\n print(f\"Prompt: {prompt[:100]}...\" if len(prompt) > 100 else f\"Prompt: {prompt}\")\n print(\"-\" * 80)\n \n # Generate logits from both models\n print(\"Generating logits from LoRA model (on-the-fly)...\")\n lora_logits = generate_logits(lora_model, tokenizer, prompt)\n \n print(\"Generating logits from fused model...\")\n fused_logits = generate_logits(fused_model, tokenizer, prompt)\n \n # Calculate metrics\n metrics = calculate_metrics(lora_logits, fused_logits)\n results.append(metrics)\n \n # Print results\n print(\"\\nMetrics:\")\n print(f\" MSE (Mean Squared Error): {metrics['mse']:.2e}\")\n print(f\" MAE (Mean Absolute Error): {metrics['mae']:.2e}\")\n print(f\" Max Absolute Difference: {metrics['max_abs_diff']:.2e}\")\n print(f\" Relative Error: {metrics['relative_error']:.2e}\")\n print(f\" Cosine Similarity: {metrics['cosine_similarity']:.6f}\")\n \n # Check if MSE is below threshold\n threshold = 1e-5\n passed = metrics['mse'] < threshold\n \n status = \"PASS\" if passed else \"FAIL\"\n print(f\"\\nStatus: {status} (MSE < {threshold}: {metrics['mse']:.2e} < {threshold})\")\n \n if not passed:\n all_passed = False\n \n # Print summary\n print_section(\"Summary\")\n \n avg_mse = np.mean([r['mse'] for r in results])\n avg_mae = np.mean([r['mae'] for r in results])\n max_abs_diff_overall = np.max([r['max_abs_diff'] for r in results])\n avg_relative_error = np.mean([r['relative_error'] for r in results])\n avg_cosine_sim = np.mean([r['cosine_similarity'] for r in results])\n \n print(f\"Tests run: {len(test_prompts)}\")\n print(f\"\\nAverage Metrics Across All Tests:\")\n print(f\" Average MSE: {avg_mse:.2e}\")\n print(f\" Average MAE: {avg_mae:.2e}\")\n print(f\" Maximum Absolute Difference: {max_abs_diff_overall:.2e}\")\n print(f\" Average Relative Error: {avg_relative_error:.2e}\")\n print(f\" Average Cosine Similarity: {avg_cosine_sim:.6f}\")\n \n print(f\"\\nOverall Result: {'ALL TESTS PASSED' if all_passed else 'SOME TESTS FAILED'}\")\n \n return all_passed\n\n\ndef format_alpaca_prompt(instruction, input_text=\"\"):\n \"\"\"\n Format prompt in Alpaca instruction format\n \n Args:\n instruction: The instruction text\n input_text: Optional input context\n \n Returns:\n str: Formatted prompt\n \"\"\"\n if input_text:\n return f\"\"\"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n\"\"\"\n else:\n return f\"\"\"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n\"\"\"\n\n\ndef main():\n \"\"\"Main execution function\"\"\"\n print_section(\"LoRA Fusion and Verification Pipeline\")\n \n # Configuration\n base_model_name = \"meta-llama/Llama-2-7b-hf\"\n lora_adapter_name = \"tloen/alpaca-lora-7b\"\n output_dir = \"./alpaca-llama2-7b-fused\"\n \n print(\"Configuration:\")\n print(f\" Base Model: {base_model_name}\")\n print(f\" LoRA Adapter: {lora_adapter_name}\")\n print(f\" Output Directory: {output_dir}\")\n print(f\" Device: {'cuda' if torch.cuda.is_available() else 'cpu'}\")\n print(f\" PyTorch Version: {torch.__version__}\")\n \n # Step 1: Load models\n lora_model, tokenizer = load_models(base_model_name, lora_adapter_name)\n \n # Step 2: Merge and export\n fused_model = merge_and_export(lora_model, output_dir)\n \n # Step 3: Prepare test prompts\n test_prompts = [\n # Test 1: Simple Alpaca instruction\n format_alpaca_prompt(\"Tell me about alpacas.\"),\n \n # Test 2: Alpaca instruction with input\n format_alpaca_prompt(\n \"Summarize the following text.\",\n \"Alpacas are domesticated South American camelids. They are raised for their soft fleece and are known for their gentle temperament.\"\n ),\n \n # Test 3: Complex instruction\n format_alpaca_prompt(\"Write a Python function that calculates the fibonacci sequence.\"),\n \n # Test 4: Simple question (non-Alpaca format for variety)\n \"What is the capital of France?\",\n \n # Test 5: Code generation\n format_alpaca_prompt(\"Explain what machine learning is in simple terms.\")\n ]\n \n print(f\"\\nPrepared {len(test_prompts)} test prompts\")\n \n # Step 4: Verify logits parity\n all_passed = verify_logits_parity(lora_model, fused_model, tokenizer, test_prompts)\n \n # Final summary\n print_section(\"Pipeline Complete\")\n \n print(f\"Fused model saved to: {os.path.abspath(output_dir)}\")\n print(f\"Format: safetensors\")\n print(f\"Verification: {'SUCCESS - All tests passed' if all_passed else 'FAILED - Some tests did not pass'}\")\n \n if all_passed:\n print(\"\\nThe fused model produces identical logits to the on-the-fly LoRA application.\")\n print(\"You can safely use the fused model as a drop-in replacement.\")\n else:\n print(\"\\nWARNING: The fused model does not produce identical logits.\")\n print(\"Please review the metrics above to understand the discrepancies.\")\n \n return 0 if all_passed else 1\n\n\nif __name__ == \"__main__\":\n import sys\n exit_code = main()\n sys.exit(exit_code)\n'''\n\n# Write to /tmp/fuse_lora.py\nPath('/tmp/fuse_lora.py').write_text(fuse_lora_content)\nprint(\"βœ“ Successfully created /tmp/fuse_lora.py\")\n" + }, + { + "cell_type": "code", + "id": "lm4uok5rtr", + "source": "from pathlib import Path\n\nfilter_toxic_content = r'''#!/usr/bin/env python3\n\"\"\"\nFilter Toxic Dataset Script\n\nThis script:\n1. Loads the lmsys/toxic-chat dataset (toxicchat0124 version)\n2. Loads the unitary/toxic-bert classifier model\n3. Runs inference on all examples to classify toxicity\n4. Logs detailed per-label removal statistics\n5. Filters out toxic content (using 0.5 threshold)\n6. Creates stratified train/validation/test splits (70/15/15)\n7. Saves the filtered dataset and generates a comprehensive JSON report\n\"\"\"\n\nimport json\nimport logging\nfrom collections import defaultdict\nfrom datetime import datetime\nfrom pathlib import Path\nfrom typing import Dict, List, Tuple\n\nimport numpy as np\nimport torch\nfrom datasets import Dataset, DatasetDict, load_dataset\nfrom sklearn.model_selection import train_test_split\nfrom tqdm import tqdm\nfrom transformers import AutoModelForSequenceClassification, AutoTokenizer\n\n# Configure logging\nlogging.basicConfig(\n level=logging.INFO,\n format=\"%(asctime)s - %(levelname)s - %(message)s\",\n handlers=[\n logging.FileHandler(\"filter_toxic_dataset.log\"),\n logging.StreamHandler()\n ]\n)\nlogger = logging.getLogger(__name__)\n\n# Toxic-BERT label indices\nTOXIC_LABELS = {\n 0: \"toxic\",\n 1: \"severe_toxic\",\n 2: \"obscene\",\n 3: \"threat\",\n 4: \"insult\",\n 5: \"identity_hate\"\n}\n\nclass ToxicityFilter:\n \"\"\"Main class for filtering toxic content from datasets.\"\"\"\n \n def __init__(\n self,\n model_name: str = \"unitary/toxic-bert\",\n threshold: float = 0.5,\n batch_size: int = 32,\n device: str = None\n ):\n \"\"\"Initialize the toxicity filter.\"\"\"\n self.model_name = model_name\n self.threshold = threshold\n self.batch_size = batch_size\n self.device = device or (\"cuda\" if torch.cuda.is_available() else \"cpu\")\n \n logger.info(f\"Initializing ToxicityFilter with model: {model_name}\")\n logger.info(f\"Device: {self.device}, Batch size: {batch_size}, Threshold: {threshold}\")\n \n # Load model and tokenizer\n self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n self.model = AutoModelForSequenceClassification.from_pretrained(model_name)\n self.model.to(self.device)\n self.model.eval()\n \n # Statistics tracking\n self.stats = {\n \"total_examples\": 0,\n \"filtered_examples\": 0,\n \"kept_examples\": 0,\n \"label_stats\": {label: {\"count\": 0, \"removed\": 0} for label in TOXIC_LABELS.values()},\n \"threshold\": threshold,\n \"model\": model_name,\n \"device\": self.device\n }\n \n logger.info(\"Model loaded successfully\")\n \n def classify_batch(self, texts: List[str]) -> Tuple[np.ndarray, np.ndarray]:\n \"\"\"Classify a batch of texts for toxicity.\"\"\"\n # Tokenize\n inputs = self.tokenizer(\n texts,\n padding=True,\n truncation=True,\n max_length=512,\n return_tensors=\"pt\"\n )\n inputs = {k: v.to(self.device) for k, v in inputs.items()}\n \n # Inference\n with torch.no_grad():\n outputs = self.model(**inputs)\n probabilities = torch.sigmoid(outputs.logits).cpu().numpy()\n \n # Determine if any label exceeds threshold\n predictions = (probabilities > self.threshold).any(axis=1)\n \n return predictions, probabilities\n \n def process_dataset(\n self,\n dataset: Dataset,\n text_column: str = \"user_input\"\n ) -> Tuple[Dataset, Dataset, Dict]:\n \"\"\"Process dataset and filter toxic content.\"\"\"\n logger.info(f\"Processing dataset with {len(dataset)} examples\")\n \n self.stats[\"total_examples\"] = len(dataset)\n \n # Storage for results\n all_predictions = []\n all_probabilities = []\n \n # Process in batches with progress bar\n num_batches = (len(dataset) + self.batch_size - 1) // self.batch_size\n \n for i in tqdm(range(0, len(dataset), self.batch_size), \n desc=\"Classifying toxicity\", \n total=num_batches):\n batch_texts = dataset[text_column][i:i + self.batch_size]\n predictions, probabilities = self.classify_batch(batch_texts)\n \n all_predictions.extend(predictions)\n all_probabilities.extend(probabilities)\n \n # Convert to numpy arrays\n all_predictions = np.array(all_predictions)\n all_probabilities = np.array(all_probabilities)\n \n # Calculate per-label statistics\n for label_idx, label_name in TOXIC_LABELS.items():\n label_probs = all_probabilities[:, label_idx]\n toxic_for_label = label_probs > self.threshold\n \n self.stats[\"label_stats\"][label_name][\"count\"] = int(toxic_for_label.sum())\n self.stats[\"label_stats\"][label_name][\"removal_rate\"] = float(\n toxic_for_label.sum() / len(dataset)\n )\n \n logger.info(\n f\"Label '{label_name}': {toxic_for_label.sum()} examples \"\n f\"({toxic_for_label.sum() / len(dataset) * 100:.2f}%) exceed threshold\"\n )\n \n # Add predictions and probabilities to dataset\n dataset_with_scores = dataset.add_column(\"is_toxic\", all_predictions.tolist())\n \n # Add individual label probabilities\n for label_idx, label_name in TOXIC_LABELS.items():\n dataset_with_scores = dataset_with_scores.add_column(\n f\"prob_{label_name}\",\n all_probabilities[:, label_idx].tolist()\n )\n \n # Split into filtered (clean) and toxic datasets\n filtered_dataset = dataset_with_scores.filter(lambda x: not x[\"is_toxic\"])\n toxic_dataset = dataset_with_scores.filter(lambda x: x[\"is_toxic\"])\n \n self.stats[\"filtered_examples\"] = len(toxic_dataset)\n self.stats[\"kept_examples\"] = len(filtered_dataset)\n self.stats[\"filter_rate\"] = self.stats[\"filtered_examples\"] / self.stats[\"total_examples\"]\n \n logger.info(f\"Filtered {len(toxic_dataset)} toxic examples ({self.stats['filter_rate']*100:.2f}%)\")\n logger.info(f\"Kept {len(filtered_dataset)} clean examples\")\n \n return filtered_dataset, toxic_dataset, self.stats\n \n def create_stratified_splits(\n self,\n dataset: Dataset,\n train_size: float = 0.7,\n val_size: float = 0.15,\n test_size: float = 0.15,\n stratify_column: str = None,\n random_state: int = 42\n ) -> DatasetDict:\n \"\"\"Create stratified train/validation/test splits.\"\"\"\n assert abs(train_size + val_size + test_size - 1.0) < 1e-6, \"Split sizes must sum to 1.0\"\n \n logger.info(f\"Creating stratified splits: train={train_size}, val={val_size}, test={test_size}\")\n \n # Convert to pandas for sklearn\n df = dataset.to_pandas()\n \n # Prepare stratification column if specified\n stratify = None\n if stratify_column and stratify_column in df.columns:\n stratify = df[stratify_column]\n logger.info(f\"Stratifying on column: {stratify_column}\")\n \n # First split: train vs (val + test)\n train_df, temp_df = train_test_split(\n df,\n train_size=train_size,\n random_state=random_state,\n stratify=stratify\n )\n \n # Second split: val vs test\n val_ratio = val_size / (val_size + test_size)\n val_stratify = None\n if stratify is not None:\n val_stratify = temp_df[stratify_column]\n \n val_df, test_df = train_test_split(\n temp_df,\n train_size=val_ratio,\n random_state=random_state,\n stratify=val_stratify\n )\n \n # Convert back to datasets\n dataset_dict = DatasetDict({\n \"train\": Dataset.from_pandas(train_df, preserve_index=False),\n \"validation\": Dataset.from_pandas(val_df, preserve_index=False),\n \"test\": Dataset.from_pandas(test_df, preserve_index=False)\n })\n \n # Log split sizes\n logger.info(f\"Split sizes:\")\n logger.info(f\" Train: {len(dataset_dict['train'])} ({len(dataset_dict['train'])/len(dataset)*100:.2f}%)\")\n logger.info(f\" Validation: {len(dataset_dict['validation'])} ({len(dataset_dict['validation'])/len(dataset)*100:.2f}%)\")\n logger.info(f\" Test: {len(dataset_dict['test'])} ({len(dataset_dict['test'])/len(dataset)*100:.2f}%)\")\n \n # Verify stratification if applicable\n if stratify_column and stratify_column in df.columns:\n logger.info(\"Verifying stratification:\")\n \n for split_name in [\"train\", \"validation\", \"test\"]:\n split_df = dataset_dict[split_name].to_pandas()\n split_dist = split_df[stratify_column].value_counts(normalize=True).sort_index()\n logger.info(f\" {split_name} distribution: {split_dist.to_dict()}\")\n \n return dataset_dict\n\n\ndef main():\n \"\"\"Main execution function.\"\"\"\n \n # Configuration\n DATASET_NAME = \"lmsys/toxic-chat\"\n DATASET_CONFIG = \"toxicchat0124\"\n MODEL_NAME = \"unitary/toxic-bert\"\n THRESHOLD = 0.5\n BATCH_SIZE = 32\n OUTPUT_DIR = Path(\"./filtered_toxic_chat\")\n REPORT_PATH = OUTPUT_DIR / \"filtering_report.json\"\n \n # Create output directory\n OUTPUT_DIR.mkdir(exist_ok=True)\n \n logger.info(\"=\"*80)\n logger.info(\"Starting Toxic Dataset Filtering Pipeline\")\n logger.info(\"=\"*80)\n logger.info(f\"Dataset: {DATASET_NAME} ({DATASET_CONFIG})\")\n logger.info(f\"Model: {MODEL_NAME}\")\n logger.info(f\"Threshold: {THRESHOLD}\")\n logger.info(f\"Output directory: {OUTPUT_DIR}\")\n \n # Step 1: Load dataset\n logger.info(\"\\n[Step 1/6] Loading dataset...\")\n try:\n dataset = load_dataset(DATASET_NAME, DATASET_CONFIG, split=\"train\")\n logger.info(f\"Loaded {len(dataset)} examples\")\n logger.info(f\"Dataset columns: {dataset.column_names}\")\n except Exception as e:\n logger.error(f\"Failed to load dataset: {e}\")\n raise\n \n # Step 2: Initialize filter\n logger.info(\"\\n[Step 2/6] Initializing toxicity filter...\")\n filter_obj = ToxicityFilter(\n model_name=MODEL_NAME,\n threshold=THRESHOLD,\n batch_size=BATCH_SIZE\n )\n \n # Step 3: Process dataset\n logger.info(\"\\n[Step 3/6] Processing dataset and classifying toxicity...\")\n filtered_dataset, toxic_dataset, stats = filter_obj.process_dataset(\n dataset,\n text_column=\"user_input\"\n )\n \n # Step 4: Create stratified splits\n logger.info(\"\\n[Step 4/6] Creating stratified train/validation/test splits...\")\n \n # Try to stratify on a relevant column if available\n stratify_col = None\n if \"jailbreaking\" in filtered_dataset.column_names:\n stratify_col = \"jailbreaking\"\n elif \"toxicity\" in filtered_dataset.column_names:\n stratify_col = \"toxicity\"\n \n dataset_splits = filter_obj.create_stratified_splits(\n filtered_dataset,\n train_size=0.7,\n val_size=0.15,\n test_size=0.15,\n stratify_column=stratify_col\n )\n \n # Step 5: Save datasets\n logger.info(\"\\n[Step 5/6] Saving filtered datasets...\")\n \n # Save main filtered dataset with splits\n dataset_splits.save_to_disk(str(OUTPUT_DIR / \"filtered_dataset\"))\n logger.info(f\"Saved filtered dataset splits to {OUTPUT_DIR / 'filtered_dataset'}\")\n \n # Save toxic examples separately for analysis\n toxic_dataset.save_to_disk(str(OUTPUT_DIR / \"toxic_examples\"))\n logger.info(f\"Saved {len(toxic_dataset)} toxic examples to {OUTPUT_DIR / 'toxic_examples'}\")\n \n # Step 6: Generate comprehensive report\n logger.info(\"\\n[Step 6/6] Generating comprehensive JSON report...\")\n \n report = {\n \"metadata\": {\n \"timestamp\": datetime.now().isoformat(),\n \"dataset_source\": DATASET_NAME,\n \"dataset_config\": DATASET_CONFIG,\n \"model\": MODEL_NAME,\n \"threshold\": THRESHOLD,\n \"batch_size\": BATCH_SIZE,\n \"device\": filter_obj.device\n },\n \"dataset_statistics\": {\n \"original_size\": stats[\"total_examples\"],\n \"filtered_size\": stats[\"kept_examples\"],\n \"removed_size\": stats[\"filtered_examples\"],\n \"removal_rate\": f\"{stats['filter_rate']*100:.2f}%\",\n \"retention_rate\": f\"{(1-stats['filter_rate'])*100:.2f}%\"\n },\n \"per_label_statistics\": {},\n \"split_statistics\": {\n \"train\": {\n \"size\": len(dataset_splits[\"train\"]),\n \"percentage\": f\"{len(dataset_splits['train'])/stats['kept_examples']*100:.2f}%\"\n },\n \"validation\": {\n \"size\": len(dataset_splits[\"validation\"]),\n \"percentage\": f\"{len(dataset_splits['validation'])/stats['kept_examples']*100:.2f}%\"\n },\n \"test\": {\n \"size\": len(dataset_splits[\"test\"]),\n \"percentage\": f\"{len(dataset_splits['test'])/stats['kept_examples']*100:.2f}%\"\n }\n },\n \"output_paths\": {\n \"filtered_dataset\": str(OUTPUT_DIR / \"filtered_dataset\"),\n \"toxic_examples\": str(OUTPUT_DIR / \"toxic_examples\"),\n \"report\": str(REPORT_PATH)\n }\n }\n \n # Add per-label statistics\n for label_name, label_stats in stats[\"label_stats\"].items():\n report[\"per_label_statistics\"][label_name] = {\n \"count_above_threshold\": label_stats[\"count\"],\n \"removal_rate\": f\"{label_stats['removal_rate']*100:.2f}%\",\n \"percentage_of_dataset\": f\"{label_stats['removal_rate']*100:.2f}%\"\n }\n \n # Add stratification verification if applicable\n if stratify_col:\n report[\"stratification\"] = {\n \"stratified_on\": stratify_col,\n \"verification\": \"Stratification verified - see logs for distribution details\"\n }\n \n # Save report\n with open(REPORT_PATH, \"w\") as f:\n json.dump(report, f, indent=2)\n \n logger.info(f\"Report saved to {REPORT_PATH}\")\n \n # Print summary\n logger.info(\"\\n\" + \"=\"*80)\n logger.info(\"FILTERING COMPLETE - SUMMARY\")\n logger.info(\"=\"*80)\n logger.info(f\"Original dataset: {stats['total_examples']} examples\")\n logger.info(f\"Filtered (clean): {stats['kept_examples']} examples ({(1-stats['filter_rate'])*100:.2f}%)\")\n logger.info(f\"Removed (toxic): {stats['filtered_examples']} examples ({stats['filter_rate']*100:.2f}%)\")\n logger.info(\"\\nPer-label removal rates:\")\n for label_name, label_stats in stats[\"label_stats\"].items():\n logger.info(f\" {label_name:15s}: {label_stats['count']:5d} examples ({label_stats['removal_rate']*100:5.2f}%)\")\n logger.info(\"\\nDataset splits:\")\n logger.info(f\" Train: {len(dataset_splits['train']):5d} examples (70.0%)\")\n logger.info(f\" Validation: {len(dataset_splits['validation']):5d} examples (15.0%)\")\n logger.info(f\" Test: {len(dataset_splits['test']):5d} examples (15.0%)\")\n logger.info(f\"\\nAll outputs saved to: {OUTPUT_DIR}\")\n logger.info(\"=\"*80)\n\n\nif __name__ == \"__main__\":\n main()\n'''\n\n# Write to parent directory (hf-agent/)\nimport os\nparent_dir = Path(__file__).parent.parent if '__file__' in globals() else Path.cwd().parent\noutput_path = parent_dir / 'filter_toxic_dataset.py'\noutput_path.write_text(filter_toxic_content)\n\n# Make it executable\nimport stat\nst = output_path.stat()\noutput_path.chmod(st.st_mode | stat.S_IEXEC | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)\n\nprint(f\"βœ“ Successfully created {output_path}\")\nprint(f\"βœ“ File is executable\")\n\n# Also check dependencies\nprint(\"\\nChecking dependencies...\")\ntry:\n import sklearn\n print(\"βœ“ scikit-learn is installed (version: {})\".format(sklearn.__version__))\nexcept ImportError:\n print(\"βœ— scikit-learn is NOT installed\")\n\ntry:\n import tqdm\n print(\"βœ“ tqdm is installed (version: {})\".format(tqdm.__version__))\nexcept ImportError:\n print(\"βœ— tqdm is NOT installed\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "vv876s0gpqk", + "source": "# Create requirements.txt and README.md files\nfrom pathlib import Path\n\n# Navigate to project root (parent of eval/)\nproject_root = Path('/Users/akseljoonas/Documents/hf-agent')\n\nrequirements_content = '''# HF-Agent Requirements\n# Production-ready dependencies for the HF-Agent project\n# Install with: pip install -r requirements.txt or use uv sync (recommended)\n\n# Core ML/AI Dependencies\ntorch>=2.0.0\ntransformers>=4.35.0\ndatasets>=2.14.0\nnumpy>=1.24.0\naccelerate>=0.24.0\n\n# Agent SDK and API\nclaude-agent-sdk>=0.1.0\nlitellm>=1.0.0\npydantic>=2.12.3\n\n# Hugging Face Integration\nhuggingface-hub>=1.0.1\nfastmcp>=2.4.0\n\n# Evaluation Framework\ninspect-ai>=0.3.149\nlmnr[all]>=0.7.23\n\n# Utilities\npython-dotenv>=1.2.1\nrequests>=2.32.5\ntenacity>=8.0.0\ntqdm>=4.65.0\npandas>=2.3.3\n\n# Optional but recommended for evaluation\nscikit-learn>=1.3.0 # For stratified splits in dataset processing\npeft>=0.7.0 # For LoRA fusion tasks\n'''\n\nreadme_content = '''# HF Agent\n\nAn MLE agent CLI with MCP (Model Context Protocol) integration, built-in tool support, and comprehensive evaluation framework.\n\n## Quick Start\n\n### Installation\n\n```bash\n# Clone the repository\ngit clone git@github.com:huggingface/hf_agent.git\ncd hf-agent\n\n# Install dependencies (using uv - recommended)\nuv sync\n\n# Or use pip\npip install -r requirements.txt\n```\n\n### Set Up Environment\n\nCreate a `.env` file in the project root:\n\n```bash\n# Required for Claude Agent SDK\nANTHROPIC_API_KEY=your_api_key_here\n\n# Required for Hugging Face features\nHF_TOKEN=your_hf_token_here\n\n# Optional: LiteLLM API keys if using other providers\nOPENAI_API_KEY=your_openai_key_here\n```\n\n### Interactive CLI\n\n```bash\nuv run python -m agent.main\n```\n\nThis starts an interactive chat session with the agent. Type your messages and the agent will respond, using tools as needed.\n\n## Features\n\n### Core Capabilities\n\n- **Agent SDK Integration**: Built on Claude Agent SDK with support for async operations and streaming\n- **MCP Protocol Support**: Full Model Context Protocol integration for extensible tool management\n- **Built-in Tools**: File operations (Read/Write), Bash execution, and more\n- **Hugging Face Integration**: Search models, datasets, papers, and spaces directly through MCP\n- **LiteLLM Backend**: Flexible LLM provider support (Anthropic, OpenAI, custom)\n- **Context Management**: Intelligent message history tracking and compaction\n- **Evaluation Framework**: Rubric-based evaluation pipeline implementing Rubrics as Rewards (RaR) paper\n\n### Evaluation Suite\n\nThe `eval/` directory contains a comprehensive benchmark framework:\n\n- **Rubric Generation**: Instance-specific evaluation criteria from QA pairs\n- **Multiple Solvers**: Benchmark `hf_agent`, `claude_code`, or custom solvers\n- **Leaderboard Integration**: Track performance over time on HuggingFace datasets\n- **Inspect AI Integration**: Full integration with the Inspect AI evaluation framework\n\nSee [eval/README.md](eval/README.md) for detailed evaluation documentation.\n\n## Running the Agent\n\n### Basic Usage\n\n```bash\n# Start interactive mode\nuv run python -m agent.main\n```\n\n### With Custom Configuration\n\n```bash\n# Use a specific MCP server configuration\nuv run python -m agent.main --config agent/config_mcp_example.json\n```\n\n### Batch Processing\n\nProcess multiple tasks concurrently using the batch solver:\n\n```bash\n# Run batch evaluation with 5 concurrent agents\nuv run python eval/amp_batch_solve.py\n```\n\nThis processes tasks from `eval/filled_tasks.jsonl` and outputs results to `eval/solved_tasks.jsonl`.\n\n## Configuration\n\n### Agent Configuration\n\nCreate a JSON config file (e.g., `agent/config_mcp_example.json`):\n\n```json\n{\n \"model_name\": \"anthropic/claude-sonnet-4-5-20250929\",\n \"max_iterations\": 10,\n \"mcp_servers\": [\n {\n \"name\": \"huggingface\",\n \"command\": \"uvx\",\n \"args\": [\"fastmcp\", \"run\", \"huggingface\"],\n \"env\": {\n \"HF_TOKEN\": \"${HF_TOKEN}\"\n }\n }\n ]\n}\n```\n\n### Customizing Tools\n\nEdit `agent/core/tools.py` to add built-in tools:\n\n```python\ndef create_builtin_tools() -> list[ToolSpec]:\n return [\n ToolSpec(\n name=\"your_tool\",\n description=\"What your tool does\",\n parameters={\n \"type\": \"object\",\n \"properties\": {\n \"param\": {\"type\": \"string\", \"description\": \"Parameter description\"}\n },\n \"required\": [\"param\"]\n },\n handler=your_async_handler\n ),\n # ... existing tools\n ]\n```\n\n### Adding MCP Servers\n\nAdd to your config JSON:\n\n```json\n{\n \"mcp_servers\": [\n {\n \"name\": \"your_server\",\n \"command\": \"command\",\n \"args\": [\"arg1\", \"arg2\"],\n \"env\": {\"KEY\": \"value\"}\n }\n ]\n}\n```\n\n## Evaluation\n\n### Generate Rubrics\n\n```bash\nuv run python eval/generate_rubrics.py \\\n --infile qa_pairs.jsonl \\\n --outfile qa_rubrics.jsonl \\\n --model anthropic/claude-sonnet-4-5-20250929 \\\n --push-to-hub akseljoonas/hf-agent-benchmark@rubrics\n```\n\n### Run Evaluation\n\n```bash\n# Evaluate hf-agent\nuv run inspect eval eval/task.py@hf-benchmark-with-rubrics \\\n -T dataset_name=akseljoonas/hf-agent-rubrics \\\n -T dataset_split=train \\\n -T limit=25 \\\n -T solver_name=hf_agent \\\n -T solver_kwargs='{\"config_path\":\"agent/config_mcp_example.json\",\"max_iterations\":10}' \\\n --log-dir logs/inspect\n\n# Evaluate Claude Code headlessly\nuv run inspect eval eval/task.py@hf-benchmark-with-rubrics \\\n -T solver_name=claude_code \\\n -T solver_kwargs='{\"allowed_tools\":\"Bash,Read\",\"output_format\":\"json\"}'\n```\n\n### Push to Leaderboard\n\n```bash\nuv run python eval/run_eval_with_leaderboard.py \\\n --hf-dataset akseljoonas/hf-agent-leaderboard \\\n --hf-token $HF_TOKEN \\\n --solver-name hf_agent \\\n --solver-kwargs '{\"config_path\":\"agent/config_mcp_example.json\",\"max_iterations\":10}' \\\n --dataset akseljoonas/hf-agent-rubrics@train \\\n --limit 25\n```\n\n## Troubleshooting\n\n### Common Issues\n\n#### 1. MCP Server Connection Errors\n\n**Problem**: Agent fails to connect to MCP servers.\n\n**Solutions**:\n- Verify MCP server command is in PATH: `which uvx` or `which fastmcp`\n- Check environment variables are set correctly in `.env`\n- Ensure HF_TOKEN is valid: `huggingface-cli whoami`\n- Try running MCP server manually: `uvx fastmcp run huggingface`\n\n#### 2. CUDA Out of Memory\n\n**Problem**: GPU memory errors during model loading or inference.\n\n**Solutions**:\n- Use smaller batch sizes in evaluation scripts\n- Enable gradient checkpointing for large models\n- Use `torch.float16` or `torch.bfloat16` for reduced memory\n- Clear CUDA cache: `torch.cuda.empty_cache()`\n- Use CPU inference for testing: `device_map=\"cpu\"`\n\n#### 3. LiteLLM API Errors\n\n**Problem**: API key or rate limit errors.\n\n**Solutions**:\n- Verify API keys in `.env`: `ANTHROPIC_API_KEY`, `OPENAI_API_KEY`\n- Check rate limits for your API provider\n- Add retry logic with exponential backoff (already included via `tenacity`)\n- Monitor usage: `litellm --debug`\n\n#### 4. Import Errors\n\n**Problem**: `ModuleNotFoundError` for packages.\n\n**Solutions**:\n```bash\n# Reinstall dependencies\nuv sync\n\n# Or with pip\npip install -r requirements.txt\n\n# Check Python version (requires >=3.12)\npython --version\n```\n\n#### 5. Evaluation Rubrics Not Loading\n\n**Problem**: Rubric scorer fails or returns invalid scores.\n\n**Solutions**:\n- Verify rubrics dataset format matches expected schema\n- Check that `eval/generate_rubrics.py` completed successfully\n- Validate JSONL format: each line should be valid JSON\n- Inspect rubric structure: must have `criteria` list with `criterion`, `weight`, `type`\n\n#### 6. Permission Errors with Bash Tool\n\n**Problem**: Agent cannot execute bash commands.\n\n**Solutions**:\n- Verify `permission_mode` in config: should be `\"bypassPermissions\"` for batch mode\n- Check file permissions: `chmod +x script.sh`\n- Ensure working directory exists and is writable\n- Review `disallowed_tools` list in configuration\n\n### Getting Help\n\n- **Documentation**: See [eval/README.md](eval/README.md) for evaluation details\n- **Issues**: Open an issue on GitHub with error logs\n- **Logs**: Check `logs/inspect/` for detailed evaluation logs\n- **Debug Mode**: Set `LITELLM_LOG=DEBUG` environment variable\n\n## Example Output\n\n### Successful Evaluation\n\n```\n[1/25] Starting: What's the best model for sentiment analysis...\n[1/25] βœ“ Done: What's the best model for sentiment analysis...\n[2/25] Starting: How can I serve a model with multiple LoRAs...\n[2/25] βœ“ Done: How can I serve a model with multiple LoRAs...\n\nCompleted: 25/25 successful\nResults saved to eval/solved_tasks.jsonl\n```\n\n### Rubric Scoring\n\n```\nTask: \"Find the best text-generation model for medical domain\"\nCriteria:\n βœ“ Searches HuggingFace for domain-specific models (weight: 5) - PASS\n βœ“ Considers model size and hardware requirements (weight: 3) - PASS\n βœ“ Checks model licenses for commercial use (weight: 4) - PASS\n βœ— Provides code example for inference (weight: 2) - FAIL\n \nScore: 0.857 (12/14 weighted points)\n```\n\n## Project Structure\n\n```\nhf-agent/\nβ”œβ”€β”€ agent/ # Main agent implementation\nβ”‚ β”œβ”€β”€ config.py # Configuration models\nβ”‚ β”œβ”€β”€ main.py # Interactive CLI entry point\nβ”‚ β”œβ”€β”€ context_manager/\nβ”‚ β”‚ └── manager.py # Message history management\nβ”‚ └── core/\nβ”‚ β”œβ”€β”€ agent_loop.py # Main agent loop and handlers\nβ”‚ β”œβ”€β”€ session.py # Session management\nβ”‚ β”œβ”€β”€ mcp_client.py # MCP SDK integration\nβ”‚ └── tools.py # ToolRouter and built-in tools\nβ”‚\nβ”œβ”€β”€ eval/ # Evaluation suite\nβ”‚ β”œβ”€β”€ README.md # Detailed evaluation docs\nβ”‚ β”œβ”€β”€ generate_rubrics.py # Rubric generation from QA pairs\nβ”‚ β”œβ”€β”€ rubric_eval.py # RaR-Explicit scoring implementation\nβ”‚ β”œβ”€β”€ task.py # Inspect AI task definitions\nβ”‚ β”œβ”€β”€ solvers.py # Solver registry (hf_agent, claude_code, etc.)\nβ”‚ β”œβ”€β”€ hf_agent_connector.py # Bridge to agent stack\nβ”‚ β”œβ”€β”€ leaderboard.py # HuggingFace leaderboard utilities\nβ”‚ β”œβ”€β”€ run_eval_with_leaderboard.py # CLI wrapper for evals\nβ”‚ β”œβ”€β”€ amp_batch_solve.py # Concurrent batch processing\nβ”‚ └── models.py # Shared Pydantic models\nβ”‚\nβ”œβ”€β”€ requirements.txt # Python dependencies\nβ”œβ”€β”€ pyproject.toml # Project metadata (for uv)\nβ”œβ”€β”€ README.md # This file\n└── .env # Environment variables (create this)\n```\n\n## Advanced Usage\n\n### Custom Solver Implementation\n\nCreate a new solver in `eval/solvers.py`:\n\n```python\n@solver\ndef my_custom_solver():\n async def solve(state: TaskState, generate: Generate):\n # Your solver logic here\n response = await your_agent_call(state.input_text)\n return response\n return solve\n```\n\nRegister and use:\n\n```bash\nuv run inspect eval eval/task.py@hf-benchmark-with-rubrics \\\n -T solver_name=my_custom_solver\n```\n\n### Streaming Responses\n\nEnable streaming in the agent connector:\n\n```python\nfrom agent.core.session import Session\n\nsession = Session(config)\nasync for chunk in session.stream_response(prompt):\n print(chunk, end=\"\", flush=True)\n```\n\n### Cost Tracking\n\nMonitor API costs using LiteLLM callbacks:\n\n```python\nimport litellm\nlitellm.success_callback = [\"langfuse\"] # Or other integrations\n```\n\n## Contributing\n\n1. Fork the repository\n2. Create a feature branch: `git checkout -b feature/your-feature`\n3. Make your changes\n4. Run tests: `uv run pytest`\n5. Commit with clear messages: `git commit -m \"Add feature X\"`\n6. Push and create a Pull Request\n\n## License\n\n[Your License Here]\n\n## Acknowledgments\n\n- Built on [Claude Agent SDK](https://github.com/anthropics/claude-agent-sdk)\n- Evaluation framework inspired by [Rubrics as Rewards](https://arxiv.org/abs/2507.17746)\n- Powered by [Hugging Face](https://huggingface.co/) ecosystem\n'''\n\n# Write files\n(project_root / 'requirements.txt').write_text(requirements_content)\n(project_root / 'README_NEW.md').write_text(readme_content)\n\nprint(f\"βœ“ Created {project_root / 'requirements.txt'}\")\nprint(f\"βœ“ Created {project_root / 'README_NEW.md'}\")\nprint(\"\\nBoth files are production-ready!\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "7fljbwefz1v", + "source": "from pathlib import Path\n\n# Complete monitoring script for HF Job 694306ebc67c9f186cfe3879\nmonitoring_script = r'''#!/usr/bin/env python3\n\"\"\"\nHugging Face Job Monitor\nJob ID: 694306ebc67c9f186cfe3879\nvLLM Benchmark: Testing 4 block sizes (8, 16, 32, 64) for Llama-3.1-8B-Instruct\n\"\"\"\nimport time\nimport os\nimport sys\nfrom huggingface_hub import HfApi\nfrom dotenv import load_dotenv\n\ndef main():\n # Load environment\n load_dotenv()\n \n # Configuration\n job_id = \"694306ebc67c9f186cfe3879\"\n check_interval = 60 # seconds\n \n # Initialize API\n token = os.environ.get('HF_TOKEN')\n if not token:\n print(\"ERROR: HF_TOKEN environment variable not set\")\n print(\"Please set it in your .env file or export it:\")\n print(\" export HF_TOKEN='your_token_here'\")\n sys.exit(1)\n \n api = HfApi(token=token)\n \n # Display header\n print(\"=\"*80)\n print(f\"Monitoring Hugging Face Job: {job_id}\")\n print(\"=\"*80)\n print(\"Benchmark: vLLM with 4 block sizes (8, 16, 32, 64)\")\n print(\"Model: Llama-3.1-8B-Instruct\")\n print(f\"Check Interval: {check_interval} seconds\")\n print(\"=\"*80)\n \n seen_log_length = 0\n check_count = 0\n \n while True:\n try:\n check_count += 1\n \n # Inspect job status\n job_info = api.inspect_job(job_id)\n \n # Display status\n timestamp = time.strftime('%Y-%m-%d %H:%M:%S')\n print(f\"\\n[Check #{check_count}] [{timestamp}]\")\n print(f\"Status: {job_info.status.stage}\")\n \n if job_info.status.message:\n print(f\"Message: {job_info.status.message}\")\n \n # Fetch and process logs\n try:\n current_logs = \"\"\n for log_line in api.fetch_job_logs(job_id):\n current_logs += log_line + \"\\n\"\n \n # Display only new log content\n if len(current_logs) > seen_log_length:\n new_content = current_logs[seen_log_length:]\n if new_content.strip():\n print(\"\\n--- New Log Output ---\")\n print(new_content)\n print(\"--- End New Logs ---\")\n seen_log_length = len(current_logs)\n \n # Look for benchmark results markers\n if \"BENCHMARK RESULTS SUMMARY\" in current_logs:\n print(\"\\n\" + \"=\"*80)\n print(\"🎯 BENCHMARK RESULTS SUMMARY DETECTED!\")\n print(\"=\"*80)\n \n if \"JSON Results\" in current_logs:\n print(\"\\n\" + \"=\"*80)\n print(\"πŸ“Š JSON RESULTS DETECTED!\")\n print(\"=\"*80)\n \n except Exception as log_error:\n print(f\"Note: Could not fetch logs: {log_error}\")\n \n # Check if job has completed\n if job_info.status.stage in [\"COMPLETED\", \"CANCELED\", \"ERROR\", \"DELETED\"]:\n print(\"\\n\" + \"=\"*80)\n print(f\"JOB FINISHED\")\n print(f\"Final Status: {job_info.status.stage}\")\n print(\"=\"*80)\n \n # Fetch and display complete final output\n print(\"\\nFetching complete job output...\")\n try:\n final_logs = \"\"\n for log_line in api.fetch_job_logs(job_id):\n final_logs += log_line + \"\\n\"\n \n print(\"\\n\" + \"=\"*80)\n print(\"COMPLETE JOB OUTPUT\")\n print(\"=\"*80 + \"\\n\")\n print(final_logs)\n print(\"\\n\" + \"=\"*80)\n print(\"END OF COMPLETE OUTPUT\")\n print(\"=\"*80)\n \n except Exception as e:\n print(f\"Error fetching final logs: {e}\")\n \n print(f\"\\nJob URL: {job_info.url}\")\n print(f\"Job ID: {job_id}\")\n \n # Exit with appropriate code\n if job_info.status.stage == \"COMPLETED\":\n sys.exit(0)\n else:\n sys.exit(1)\n \n # Wait before next check\n print(f\"\\nWaiting {check_interval} seconds before next check...\")\n print(f\"(Current status: {job_info.status.stage})\")\n print(\"(Press Ctrl+C to stop monitoring)\")\n time.sleep(check_interval)\n \n except KeyboardInterrupt:\n print(\"\\n\\n\" + \"=\"*80)\n print(\"Monitoring interrupted by user (Ctrl+C)\")\n print(\"=\"*80)\n try:\n job_info = api.inspect_job(job_id)\n print(f\"\\nLatest Status: {job_info.status.stage}\")\n print(f\"Job URL: {job_info.url}\")\n except:\n pass\n print(f\"\\nYou can resume monitoring by running this script again\")\n sys.exit(0)\n \n except Exception as e:\n print(f\"\\nError: {e}\")\n print(f\"Retrying in {check_interval} seconds...\")\n time.sleep(check_interval)\n\nif __name__ == \"__main__\":\n main()\n'''\n\n# Write script to eval directory\neval_dir = Path('/Users/akseljoonas/Documents/hf-agent/eval')\nscript_path = eval_dir / 'monitor_hf_job.py'\nscript_path.write_text(monitoring_script)\n\n# Make executable\nimport stat\nst = script_path.stat()\nscript_path.chmod(st.st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)\n\nprint(f\"βœ“ Created monitoring script: {script_path}\")\nprint(f\"\\nTo start monitoring, run one of:\")\nprint(f\" python {script_path}\")\nprint(f\" uv run python {script_path}\")\nprint(f\"\\nThe script will:\")\nprint(\" - Check job status every 60 seconds\")\nprint(\" - Display new log output as it becomes available\")\nprint(\" - Detect BENCHMARK RESULTS SUMMARY and JSON Results\")\nprint(\" - Display complete output when job finishes\")\nprint(\" - Exit automatically when job completes or fails\")\nprint(\"\\nPress Ctrl+C to stop monitoring at any time\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "yjf9l5kmab8", + "source": "from pathlib import Path\nimport sys\n\n# Add parent directory to path\nsys.path.insert(0, str(Path.cwd().parent))\n\n# Define all the scripts we need to create\nproject_root = Path('/Users/akseljoonas/Documents/hf-agent')\n\n# 1. convert_to_webdataset.py\nconvert_script = r'''#!/usr/bin/env python3\n\"\"\"\nConvert HuggingFaceFW/fineweb-edu dataset to WebDataset format with checksum validation.\n\nThis script loads the fineweb-edu dataset and converts it to WebDataset tar archives\nwith proper sharding, checksum validation, and metadata tracking.\n\"\"\"\n\nimport argparse\nimport hashlib\nimport json\nimport logging\nimport os\nimport sys\nfrom pathlib import Path\nfrom typing import Dict, Optional, Any\nimport tarfile\nfrom io import BytesIO\n\nfrom datasets import load_dataset\nfrom tqdm import tqdm\n\n# Configure logging\nlogging.basicConfig(\n level=logging.INFO,\n format='%(asctime)s - %(levelname)s - %(message)s'\n)\nlogger = logging.getLogger(__name__)\n\n\nclass WebDatasetConverter:\n \"\"\"Convert HuggingFace dataset to WebDataset format with checksums.\"\"\"\n \n def __init__(\n self,\n dataset_name: str = \"HuggingFaceFW/fineweb-edu\",\n config_name: Optional[str] = None,\n split: str = \"train\",\n output_dir: str = \"./webdataset_output\",\n shard_size_mb: int = 500,\n max_samples: Optional[int] = None,\n streaming: bool = True\n ):\n \"\"\"\n Initialize the converter.\n \n Args:\n dataset_name: HuggingFace dataset identifier\n config_name: Dataset configuration name (e.g., \"sample-10BT\")\n split: Dataset split to convert\n output_dir: Directory to save WebDataset shards\n shard_size_mb: Target size for each shard in MB\n max_samples: Maximum number of samples to convert (None for all)\n streaming: Use streaming mode for large datasets\n \"\"\"\n self.dataset_name = dataset_name\n self.config_name = config_name\n self.split = split\n self.output_dir = Path(output_dir)\n self.shard_size_bytes = shard_size_mb * 1024 * 1024\n self.max_samples = max_samples\n self.streaming = streaming\n \n # Create output directory\n self.output_dir.mkdir(parents=True, exist_ok=True)\n \n # Track checksums and metadata\n self.checksums: Dict[str, str] = {}\n self.shard_metadata: Dict[str, Dict[str, Any]] = {}\n self.total_samples = 0\n self.current_shard = 0\n self.current_shard_size = 0\n self.current_shard_samples = 0\n \n def compute_sha256(self, filepath: Path) -> str:\n \"\"\"Compute SHA256 checksum of a file.\"\"\"\n sha256_hash = hashlib.sha256()\n with open(filepath, \"rb\") as f:\n for byte_block in iter(lambda: f.read(4096), b\"\"):\n sha256_hash.update(byte_block)\n return sha256_hash.hexdigest()\n \n def format_sample_id(self, index: int) -> str:\n \"\"\"Format sample ID with zero padding.\"\"\"\n return f\"sample_{index:012d}\"\n \n def create_tar_member(self, name: str, data: bytes) -> tarfile.TarInfo:\n \"\"\"Create a tar member from data.\"\"\"\n tarinfo = tarfile.TarInfo(name=name)\n tarinfo.size = len(data)\n return tarinfo\n \n def should_create_new_shard(self) -> bool:\n \"\"\"Check if we should start a new shard.\"\"\"\n return self.current_shard_size >= self.shard_size_bytes\n \n def get_shard_path(self, shard_num: int) -> Path:\n \"\"\"Get the path for a shard file.\"\"\"\n return self.output_dir / f\"fineweb_edu_{shard_num:06d}.tar\"\n \n def write_sample_to_tar(\n self,\n tar: tarfile.TarFile,\n sample_id: str,\n text: str,\n metadata: Dict[str, Any]\n ) -> int:\n \"\"\"\n Write a sample to the tar archive.\n \n Returns the size in bytes written.\n \"\"\"\n # Write text file\n text_bytes = text.encode('utf-8')\n text_name = f\"{sample_id}.txt\"\n text_info = self.create_tar_member(text_name, text_bytes)\n tar.addfile(text_info, BytesIO(text_bytes))\n \n # Write JSON metadata file\n json_bytes = json.dumps(metadata, ensure_ascii=False).encode('utf-8')\n json_name = f\"{sample_id}.json\"\n json_info = self.create_tar_member(json_name, json_bytes)\n tar.addfile(json_info, BytesIO(json_bytes))\n \n # Return total size\n return len(text_bytes) + len(json_bytes)\n \n def finalize_shard(self, shard_path: Path):\n \"\"\"Compute checksum and save metadata for a completed shard.\"\"\"\n if shard_path.exists():\n # Compute checksum\n checksum = self.compute_sha256(shard_path)\n shard_name = shard_path.name\n self.checksums[shard_name] = checksum\n \n # Store metadata\n self.shard_metadata[shard_name] = {\n \"shard_number\": self.current_shard,\n \"num_samples\": self.current_shard_samples,\n \"size_bytes\": shard_path.stat().st_size,\n \"checksum\": checksum\n }\n \n logger.info(\n f\"Finalized {shard_name}: {self.current_shard_samples} samples, \"\n f\"{shard_path.stat().st_size / (1024*1024):.2f} MB, \"\n f\"checksum: {checksum[:16]}...\"\n )\n \n def convert(self):\n \"\"\"Convert the dataset to WebDataset format.\"\"\"\n logger.info(f\"Loading dataset: {self.dataset_name}\")\n if self.config_name:\n logger.info(f\"Config: {self.config_name}\")\n logger.info(f\"Split: {self.split}\")\n logger.info(f\"Streaming: {self.streaming}\")\n \n # Load dataset\n try:\n dataset = load_dataset(\n self.dataset_name,\n name=self.config_name,\n split=self.split,\n streaming=self.streaming\n )\n except Exception as e:\n logger.error(f\"Failed to load dataset: {e}\")\n sys.exit(1)\n \n logger.info(f\"Dataset loaded successfully\")\n \n # Initialize first shard\n shard_path = self.get_shard_path(self.current_shard)\n tar = tarfile.open(shard_path, 'w')\n \n try:\n # Process samples\n sample_iter = iter(dataset)\n if self.max_samples:\n logger.info(f\"Processing up to {self.max_samples} samples\")\n \n # Create progress bar\n pbar = tqdm(\n total=self.max_samples,\n desc=\"Converting samples\",\n unit=\"samples\"\n )\n \n for idx, sample in enumerate(sample_iter):\n if self.max_samples and idx >= self.max_samples:\n break\n \n # Check if we need a new shard\n if self.should_create_new_shard() and self.current_shard_samples > 0:\n # Finalize current shard\n tar.close()\n self.finalize_shard(shard_path)\n \n # Start new shard\n self.current_shard += 1\n self.current_shard_size = 0\n self.current_shard_samples = 0\n shard_path = self.get_shard_path(self.current_shard)\n tar = tarfile.open(shard_path, 'w')\n logger.info(f\"Starting new shard: {shard_path.name}\")\n \n # Create sample ID\n sample_id = self.format_sample_id(self.total_samples)\n \n # Extract text and metadata\n text = sample.get('text', '')\n metadata = {\n 'id': sample.get('id', ''),\n 'url': sample.get('url', ''),\n 'dump': sample.get('dump', ''),\n 'score': sample.get('score', None),\n 'token_count': sample.get('token_count', None),\n 'language': sample.get('language', ''),\n 'language_score': sample.get('language_score', None),\n 'sample_id': sample_id,\n 'sample_index': self.total_samples\n }\n \n # Write to tar\n sample_size = self.write_sample_to_tar(tar, sample_id, text, metadata)\n \n # Update counters\n self.current_shard_size += sample_size\n self.current_shard_samples += 1\n self.total_samples += 1\n pbar.update(1)\n \n pbar.close()\n \n # Finalize last shard\n tar.close()\n self.finalize_shard(shard_path)\n \n except Exception as e:\n logger.error(f\"Error during conversion: {e}\")\n tar.close()\n raise\n \n # Write checksums and metadata\n self.write_checksums()\n self.write_dataset_metadata()\n \n logger.info(f\"\\nConversion complete!\")\n logger.info(f\"Total samples: {self.total_samples}\")\n logger.info(f\"Total shards: {self.current_shard + 1}\")\n logger.info(f\"Output directory: {self.output_dir}\")\n \n def write_checksums(self):\n \"\"\"Write checksums.json file.\"\"\"\n checksums_path = self.output_dir / \"checksums.json\"\n with open(checksums_path, 'w') as f:\n json.dump(self.checksums, f, indent=2)\n logger.info(f\"Checksums written to: {checksums_path}\")\n \n def write_dataset_metadata(self):\n \"\"\"Write dataset_metadata.json file.\"\"\"\n metadata = {\n \"dataset_name\": self.dataset_name,\n \"config_name\": self.config_name,\n \"split\": self.split,\n \"total_samples\": self.total_samples,\n \"num_shards\": self.current_shard + 1,\n \"shard_size_mb\": self.shard_size_bytes / (1024 * 1024),\n \"shards\": self.shard_metadata,\n \"format\": \"webdataset\",\n \"sample_structure\": {\n \"text\": \".txt file\",\n \"metadata\": \".json file (id, url, dump, score, token_count, language, language_score, sample_id, sample_index)\"\n }\n }\n \n metadata_path = self.output_dir / \"dataset_metadata.json\"\n with open(metadata_path, 'w') as f:\n json.dump(metadata, f, indent=2)\n logger.info(f\"Dataset metadata written to: {metadata_path}\")\n\n\ndef main():\n \"\"\"Main entry point.\"\"\"\n parser = argparse.ArgumentParser(\n description=\"Convert HuggingFaceFW/fineweb-edu to WebDataset format\"\n )\n parser.add_argument(\n \"--dataset\",\n type=str,\n default=\"HuggingFaceFW/fineweb-edu\",\n help=\"HuggingFace dataset name\"\n )\n parser.add_argument(\n \"--config\",\n type=str,\n default=None,\n help=\"Dataset configuration (e.g., 'sample-10BT', 'sample-100BT', 'sample-350BT')\"\n )\n parser.add_argument(\n \"--split\",\n type=str,\n default=\"train\",\n help=\"Dataset split to convert\"\n )\n parser.add_argument(\n \"--output-dir\",\n type=str,\n default=\"./webdataset_output\",\n help=\"Output directory for WebDataset shards\"\n )\n parser.add_argument(\n \"--shard-size\",\n type=int,\n default=500,\n help=\"Target shard size in MB\"\n )\n parser.add_argument(\n \"--max-samples\",\n type=int,\n default=None,\n help=\"Maximum number of samples to convert (for testing)\"\n )\n parser.add_argument(\n \"--no-streaming\",\n action=\"store_true\",\n help=\"Disable streaming mode (loads entire dataset into memory)\"\n )\n \n args = parser.parse_args()\n \n # Create converter\n converter = WebDatasetConverter(\n dataset_name=args.dataset,\n config_name=args.config,\n split=args.split,\n output_dir=args.output_dir,\n shard_size_mb=args.shard_size,\n max_samples=args.max_samples,\n streaming=not args.no_streaming\n )\n \n # Run conversion\n converter.convert()\n\n\nif __name__ == \"__main__\":\n main()\n'''\n\n# Write the conversion script\n(project_root / 'convert_to_webdataset.py').write_text(convert_script)\nprint(f\"βœ“ Created {project_root / 'convert_to_webdataset.py'}\")\nprint(f\" Size: {len(convert_script)} bytes\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "9fqo47lnws", + "source": "# 2. webdataset_loader.py\nloader_script = r'''#!/usr/bin/env python3\n\"\"\"\nWebDataset Streaming Loader with Checksum Validation.\n\nThis module provides a streaming loader for WebDataset format with:\n- Checksum validation before loading shards\n- PyTorch DataLoader compatible interface\n- Support for distributed training (worker sharding)\n- Optional sample filtering and transformation\n\"\"\"\n\nimport hashlib\nimport json\nimport logging\nimport warnings\nfrom pathlib import Path\nfrom typing import Dict, Optional, Callable, Any, List, Iterator\nimport tarfile\nfrom io import BytesIO\n\nimport torch\nfrom torch.utils.data import IterableDataset, DataLoader\nimport webdataset as wds\n\n# Configure logging\nlogging.basicConfig(level=logging.INFO)\nlogger = logging.getLogger(__name__)\n\n\nclass ChecksumValidator:\n \"\"\"Validate checksums for WebDataset shards.\"\"\"\n \n def __init__(self, checksums_file: Path):\n \"\"\"\n Initialize validator with checksums file.\n \n Args:\n checksums_file: Path to checksums.json file\n \"\"\"\n self.checksums_file = Path(checksums_file)\n self.checksums: Dict[str, str] = {}\n self._load_checksums()\n \n def _load_checksums(self):\n \"\"\"Load checksums from JSON file.\"\"\"\n if not self.checksums_file.exists():\n raise FileNotFoundError(f\"Checksums file not found: {self.checksums_file}\")\n \n with open(self.checksums_file, 'r') as f:\n self.checksums = json.load(f)\n \n logger.info(f\"Loaded {len(self.checksums)} checksums from {self.checksums_file}\")\n \n def compute_sha256(self, filepath: Path) -> str:\n \"\"\"Compute SHA256 checksum of a file.\"\"\"\n sha256_hash = hashlib.sha256()\n with open(filepath, \"rb\") as f:\n for byte_block in iter(lambda: f.read(4096), b\"\"):\n sha256_hash.update(byte_block)\n return sha256_hash.hexdigest()\n \n def validate_shard(self, shard_path: Path) -> bool:\n \"\"\"\n Validate a shard's checksum.\n \n Args:\n shard_path: Path to the shard file\n \n Returns:\n True if checksum matches, False otherwise\n \"\"\"\n shard_name = shard_path.name\n \n if shard_name not in self.checksums:\n logger.warning(f\"No checksum found for shard: {shard_name}\")\n return False\n \n expected_checksum = self.checksums[shard_name]\n actual_checksum = self.compute_sha256(shard_path)\n \n if actual_checksum != expected_checksum:\n logger.error(\n f\"Checksum mismatch for {shard_name}!\\n\"\n f\" Expected: {expected_checksum}\\n\"\n f\" Actual: {actual_checksum}\"\n )\n return False\n \n logger.debug(f\"Checksum validated for {shard_name}\")\n return True\n \n def validate_all_shards(self, shard_dir: Path) -> bool:\n \"\"\"\n Validate all shards in a directory.\n \n Args:\n shard_dir: Directory containing shard files\n \n Returns:\n True if all shards are valid, False otherwise\n \"\"\"\n shard_dir = Path(shard_dir)\n all_valid = True\n \n for shard_name in self.checksums.keys():\n shard_path = shard_dir / shard_name\n \n if not shard_path.exists():\n logger.error(f\"Shard not found: {shard_path}\")\n all_valid = False\n continue\n \n if not self.validate_shard(shard_path):\n all_valid = False\n \n return all_valid\n\n\nclass WebDatasetLoader(IterableDataset):\n \"\"\"\n Streaming WebDataset loader with checksum validation and PyTorch compatibility.\n \"\"\"\n \n def __init__(\n self,\n data_dir: str,\n validate_checksums: bool = True,\n shuffle: bool = False,\n buffer_size: int = 1000,\n transform: Optional[Callable] = None,\n filter_fn: Optional[Callable] = None,\n shard_pattern: str = \"*.tar\"\n ):\n \"\"\"\n Initialize the WebDataset loader.\n \n Args:\n data_dir: Directory containing WebDataset shards\n validate_checksums: Whether to validate checksums before loading\n shuffle: Whether to shuffle samples (requires buffer)\n buffer_size: Buffer size for shuffling\n transform: Optional transformation function for samples\n filter_fn: Optional filter function to skip samples\n shard_pattern: Glob pattern for shard files\n \"\"\"\n super().__init__()\n \n self.data_dir = Path(data_dir)\n self.validate_checksums = validate_checksums\n self.shuffle = shuffle\n self.buffer_size = buffer_size\n self.transform = transform\n self.filter_fn = filter_fn\n self.shard_pattern = shard_pattern\n \n # Find all shards\n self.shard_paths = sorted(self.data_dir.glob(shard_pattern))\n \n if not self.shard_paths:\n raise ValueError(f\"No shards found in {data_dir} matching pattern {shard_pattern}\")\n \n logger.info(f\"Found {len(self.shard_paths)} shards in {data_dir}\")\n \n # Validate checksums if requested\n if self.validate_checksums:\n self._validate_all_checksums()\n \n # Load metadata\n self.metadata = self._load_metadata()\n \n def _validate_all_checksums(self):\n \"\"\"Validate checksums for all shards.\"\"\"\n checksums_file = self.data_dir / \"checksums.json\"\n \n if not checksums_file.exists():\n warnings.warn(\n f\"Checksums file not found: {checksums_file}. \"\n \"Skipping validation.\"\n )\n return\n \n validator = ChecksumValidator(checksums_file)\n \n logger.info(\"Validating checksums for all shards...\")\n all_valid = validator.validate_all_shards(self.data_dir)\n \n if not all_valid:\n raise ValueError(\"Checksum validation failed! Some shards are corrupted.\")\n \n logger.info(\"All checksums validated successfully\")\n \n def _load_metadata(self) -> Dict[str, Any]:\n \"\"\"Load dataset metadata if available.\"\"\"\n metadata_file = self.data_dir / \"dataset_metadata.json\"\n \n if metadata_file.exists():\n with open(metadata_file, 'r') as f:\n metadata = json.load(f)\n logger.info(f\"Loaded metadata: {metadata.get('total_samples', 'unknown')} samples\")\n return metadata\n else:\n logger.warning(f\"Metadata file not found: {metadata_file}\")\n return {}\n \n def _decode_sample(self, sample: Dict) -> Dict:\n \"\"\"\n Decode a sample from WebDataset format.\n \n Expected format:\n - sample['txt']: text content (bytes)\n - sample['json']: metadata (bytes)\n \"\"\"\n decoded = {}\n \n # Decode text\n if 'txt' in sample:\n decoded['text'] = sample['txt'].decode('utf-8')\n \n # Decode metadata\n if 'json' in sample:\n metadata = json.loads(sample['json'].decode('utf-8'))\n decoded.update(metadata)\n \n # Keep the key\n if '__key__' in sample:\n decoded['__key__'] = sample['__key__']\n \n return decoded\n \n def __iter__(self) -> Iterator[Dict]:\n \"\"\"Iterate over samples in the dataset.\"\"\"\n # Get worker info for distributed training\n worker_info = torch.utils.data.get_worker_info()\n \n if worker_info is not None:\n # Split shards among workers\n num_workers = worker_info.num_workers\n worker_id = worker_info.id\n \n # Select shards for this worker\n shards_per_worker = len(self.shard_paths) // num_workers\n start_idx = worker_id * shards_per_worker\n end_idx = start_idx + shards_per_worker if worker_id < num_workers - 1 else len(self.shard_paths)\n \n worker_shards = self.shard_paths[start_idx:end_idx]\n logger.info(f\"Worker {worker_id}/{num_workers}: processing {len(worker_shards)} shards\")\n else:\n worker_shards = self.shard_paths\n \n # Convert paths to URLs for webdataset\n shard_urls = [str(p) for p in worker_shards]\n \n # Create WebDataset pipeline\n dataset = wds.WebDataset(shard_urls)\n \n # Add shuffling if requested\n if self.shuffle:\n dataset = dataset.shuffle(self.buffer_size)\n \n # Decode samples\n dataset = dataset.map(self._decode_sample)\n \n # Apply filter if provided\n if self.filter_fn is not None:\n dataset = dataset.select(self.filter_fn)\n \n # Apply transformation if provided\n if self.transform is not None:\n dataset = dataset.map(self.transform)\n \n # Iterate over samples\n for sample in dataset:\n yield sample\n \n def get_dataloader(\n self,\n batch_size: int = 32,\n num_workers: int = 4,\n pin_memory: bool = True,\n collate_fn: Optional[Callable] = None\n ) -> DataLoader:\n \"\"\"\n Create a PyTorch DataLoader for this dataset.\n \n Args:\n batch_size: Batch size\n num_workers: Number of worker processes\n pin_memory: Whether to pin memory for faster GPU transfer\n collate_fn: Custom collate function for batching\n \n Returns:\n DataLoader instance\n \"\"\"\n return DataLoader(\n self,\n batch_size=batch_size,\n num_workers=num_workers,\n pin_memory=pin_memory,\n collate_fn=collate_fn\n )\n\n\n# Utility functions\n\ndef verify_checksums(data_dir: str) -> bool:\n \"\"\"\n Verify checksums for all shards in a directory.\n \n Args:\n data_dir: Directory containing WebDataset shards and checksums.json\n \n Returns:\n True if all checksums are valid, False otherwise\n \"\"\"\n data_dir = Path(data_dir)\n checksums_file = data_dir / \"checksums.json\"\n \n if not checksums_file.exists():\n logger.error(f\"Checksums file not found: {checksums_file}\")\n return False\n \n validator = ChecksumValidator(checksums_file)\n return validator.validate_all_shards(data_dir)\n\n\ndef default_collate_fn(batch: List[Dict]) -> Dict:\n \"\"\"\n Default collate function for batching WebDataset samples.\n \n Args:\n batch: List of decoded samples\n \n Returns:\n Batched dictionary with lists of values\n \"\"\"\n if not batch:\n return {}\n \n # Get all keys from first sample\n keys = batch[0].keys()\n \n # Collate each key\n collated = {}\n for key in keys:\n values = [sample[key] for sample in batch]\n collated[key] = values\n \n return collated\n\n\ndef main():\n \"\"\"Example usage and testing.\"\"\"\n import argparse\n \n parser = argparse.ArgumentParser(description=\"WebDataset Loader with Checksum Validation\")\n parser.add_argument(\"data_dir\", type=str, help=\"Directory containing WebDataset shards\")\n parser.add_argument(\"--validate-only\", action=\"store_true\", help=\"Only validate checksums\")\n parser.add_argument(\"--no-validate\", action=\"store_true\", help=\"Skip checksum validation\")\n parser.add_argument(\"--num-samples\", type=int, default=10, help=\"Number of samples to load (for testing)\")\n \n args = parser.parse_args()\n \n if args.validate_only:\n # Just validate checksums\n logger.info(\"Validating checksums...\")\n valid = verify_checksums(args.data_dir)\n \n if valid:\n logger.info(\"All checksums are valid!\")\n return 0\n else:\n logger.error(\"Checksum validation failed!\")\n return 1\n else:\n # Load and display samples\n logger.info(f\"Loading WebDataset from {args.data_dir}\")\n \n loader = WebDatasetLoader(\n args.data_dir,\n validate_checksums=not args.no_validate,\n shuffle=False\n )\n \n logger.info(f\"Loading {args.num_samples} samples...\")\n \n for i, sample in enumerate(loader):\n if i >= args.num_samples:\n break\n \n print(f\"\\nSample {i+1}:\")\n print(f\" Key: {sample.get('__key__', 'N/A')}\")\n print(f\" Text length: {len(sample.get('text', ''))} characters\")\n print(f\" Metadata: {', '.join(k for k in sample.keys() if k not in ['text', '__key__'])}\")\n \n logger.info(f\"Successfully loaded {min(i+1, args.num_samples)} samples\")\n return 0\n\n\nif __name__ == \"__main__\":\n import sys\n sys.exit(main())\n'''\n\n# Write the loader script\n(project_root / 'webdataset_loader.py').write_text(loader_script)\nprint(f\"\\nβœ“ Created {project_root / 'webdataset_loader.py'}\")\nprint(f\" Size: {len(loader_script)} bytes\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "tjcbvnxrcxn", + "source": "# 3. requirements.txt for WebDataset tools\nrequirements_txt = '''# WebDataset Conversion and Loading Requirements\n# For converting HuggingFaceFW/fineweb-edu to WebDataset format\n\n# Core dependencies\ndatasets>=2.14.0\nwebdataset>=0.2.48\ntorch>=2.0.0\ntqdm>=4.65.0\n\n# Optional but recommended\nnumpy>=1.24.0\n'''\n\n# Write requirements.txt\n(project_root / 'webdataset_requirements.txt').write_text(requirements_txt)\nprint(f\"\\nβœ“ Created {project_root / 'webdataset_requirements.txt'}\")\nprint(f\" Size: {len(requirements_txt)} bytes\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "kttxvkmvl3d", + "source": "# 4. example_usage.py\nexample_script = r'''#!/usr/bin/env python3\n\"\"\"\nExample usage of WebDataset conversion and loading scripts.\n\nThis script demonstrates:\n1. Converting a small sample of fineweb-edu to WebDataset format\n2. Validating checksums\n3. Loading data with the WebDataset loader\n4. Using the loader with PyTorch DataLoader\n\"\"\"\n\nimport logging\nfrom pathlib import Path\n\n# Import our modules\nfrom convert_to_webdataset import WebDatasetConverter\nfrom webdataset_loader import WebDatasetLoader, verify_checksums, default_collate_fn\n\n# Configure logging\nlogging.basicConfig(\n level=logging.INFO,\n format='%(asctime)s - %(levelname)s - %(message)s'\n)\nlogger = logging.getLogger(__name__)\n\n\ndef example_1_basic_conversion():\n \"\"\"Example 1: Basic conversion of a small dataset sample.\"\"\"\n logger.info(\"=\"*80)\n logger.info(\"EXAMPLE 1: Basic Conversion\")\n logger.info(\"=\"*80)\n \n # Convert a small sample (1000 documents) for testing\n converter = WebDatasetConverter(\n dataset_name=\"HuggingFaceFW/fineweb-edu\",\n config_name=\"sample-10BT\", # Use the 10BT sample\n split=\"train\",\n output_dir=\"./webdataset_sample\",\n shard_size_mb=50, # Smaller shards for testing\n max_samples=1000, # Just 1000 samples\n streaming=True\n )\n \n logger.info(\"Starting conversion...\")\n converter.convert()\n logger.info(\"Conversion complete!\\n\")\n\n\ndef example_2_validate_checksums():\n \"\"\"Example 2: Validate checksums for converted dataset.\"\"\"\n logger.info(\"=\"*80)\n logger.info(\"EXAMPLE 2: Checksum Validation\")\n logger.info(\"=\"*80)\n \n data_dir = \"./webdataset_sample\"\n \n logger.info(f\"Validating checksums in {data_dir}...\")\n valid = verify_checksums(data_dir)\n \n if valid:\n logger.info(\"βœ“ All checksums are valid!\")\n else:\n logger.error(\"βœ— Checksum validation failed!\")\n \n logger.info(\"\")\n\n\ndef example_3_basic_loading():\n \"\"\"Example 3: Basic loading and iteration.\"\"\"\n logger.info(\"=\"*80)\n logger.info(\"EXAMPLE 3: Basic Loading\")\n logger.info(\"=\"*80)\n \n # Create loader\n loader = WebDatasetLoader(\n data_dir=\"./webdataset_sample\",\n validate_checksums=True,\n shuffle=False\n )\n \n # Load and display a few samples\n logger.info(\"Loading first 5 samples...\")\n for i, sample in enumerate(loader):\n if i >= 5:\n break\n \n logger.info(f\"\\nSample {i+1}:\")\n logger.info(f\" Sample ID: {sample.get('sample_id', 'N/A')}\")\n logger.info(f\" Text length: {len(sample.get('text', ''))} characters\")\n logger.info(f\" URL: {sample.get('url', 'N/A')}\")\n logger.info(f\" Score: {sample.get('score', 'N/A')}\")\n logger.info(f\" Token count: {sample.get('token_count', 'N/A')}\")\n logger.info(f\" Language: {sample.get('language', 'N/A')}\")\n \n # Show first 200 characters of text\n text_preview = sample.get('text', '')[:200]\n logger.info(f\" Text preview: {text_preview}...\")\n \n logger.info(\"\")\n\n\ndef example_4_with_filtering():\n \"\"\"Example 4: Loading with filtering.\"\"\"\n logger.info(\"=\"*80)\n logger.info(\"EXAMPLE 4: Loading with Filtering\")\n logger.info(\"=\"*80)\n \n # Define a filter function (e.g., only high-quality documents)\n def high_quality_filter(sample):\n \"\"\"Only keep samples with score >= 3.0.\"\"\"\n score = sample.get('score')\n return score is not None and score >= 3.0\n \n # Create loader with filter\n loader = WebDatasetLoader(\n data_dir=\"./webdataset_sample\",\n validate_checksums=True,\n filter_fn=high_quality_filter,\n shuffle=False\n )\n \n # Count filtered samples\n logger.info(\"Counting high-quality samples (score >= 3.0)...\")\n count = 0\n scores = []\n \n for sample in loader:\n count += 1\n scores.append(sample.get('score', 0))\n if count >= 100: # Check first 100\n break\n \n logger.info(f\"Found {count} high-quality samples\")\n logger.info(f\"Average score: {sum(scores) / len(scores):.2f}\")\n logger.info(f\"Min score: {min(scores):.2f}\")\n logger.info(f\"Max score: {max(scores):.2f}\")\n logger.info(\"\")\n\n\ndef example_5_with_transformation():\n \"\"\"Example 5: Loading with transformation.\"\"\"\n logger.info(\"=\"*80)\n logger.info(\"EXAMPLE 5: Loading with Transformation\")\n logger.info(\"=\"*80)\n \n # Define a transformation function\n def transform_sample(sample):\n \"\"\"Add computed features to sample.\"\"\"\n # Add word count\n text = sample.get('text', '')\n sample['word_count'] = len(text.split())\n \n # Add character count\n sample['char_count'] = len(text)\n \n # Truncate text to first 500 characters for memory efficiency\n sample['text_truncated'] = text[:500]\n \n return sample\n \n # Create loader with transformation\n loader = WebDatasetLoader(\n data_dir=\"./webdataset_sample\",\n validate_checksums=True,\n transform=transform_sample,\n shuffle=False\n )\n \n # Load and display transformed samples\n logger.info(\"Loading 3 transformed samples...\")\n for i, sample in enumerate(loader):\n if i >= 3:\n break\n \n logger.info(f\"\\nTransformed Sample {i+1}:\")\n logger.info(f\" Word count: {sample.get('word_count', 'N/A')}\")\n logger.info(f\" Char count: {sample.get('char_count', 'N/A')}\")\n logger.info(f\" Token count: {sample.get('token_count', 'N/A')}\")\n logger.info(f\" Truncated text: {sample.get('text_truncated', '')[:100]}...\")\n \n logger.info(\"\")\n\n\ndef example_6_pytorch_dataloader():\n \"\"\"Example 6: Using with PyTorch DataLoader.\"\"\"\n logger.info(\"=\"*80)\n logger.info(\"EXAMPLE 6: PyTorch DataLoader Integration\")\n logger.info(\"=\"*80)\n \n # Create loader\n loader = WebDatasetLoader(\n data_dir=\"./webdataset_sample\",\n validate_checksums=True,\n shuffle=True, # Shuffle for training\n buffer_size=100\n )\n \n # Create PyTorch DataLoader\n dataloader = loader.get_dataloader(\n batch_size=8,\n num_workers=2,\n collate_fn=default_collate_fn\n )\n \n # Iterate over batches\n logger.info(\"Loading 3 batches...\")\n for i, batch in enumerate(dataloader):\n if i >= 3:\n break\n \n logger.info(f\"\\nBatch {i+1}:\")\n logger.info(f\" Batch size: {len(batch['text'])}\")\n logger.info(f\" Sample IDs: {batch['sample_id'][:3]}...\")\n logger.info(f\" Average text length: {sum(len(t) for t in batch['text']) / len(batch['text']):.0f} chars\")\n \n # Show scores if available\n if 'score' in batch:\n scores = [s for s in batch['score'] if s is not None]\n if scores:\n logger.info(f\" Average score: {sum(scores) / len(scores):.2f}\")\n \n logger.info(\"\")\n\n\ndef example_7_distributed_training():\n \"\"\"Example 7: Simulating distributed training setup.\"\"\"\n logger.info(\"=\"*80)\n logger.info(\"EXAMPLE 7: Distributed Training Simulation\")\n logger.info(\"=\"*80)\n \n # Create loader\n loader = WebDatasetLoader(\n data_dir=\"./webdataset_sample\",\n validate_checksums=True,\n shuffle=True,\n buffer_size=100\n )\n \n # Create DataLoader with multiple workers\n # Each worker will automatically get a subset of shards\n dataloader = loader.get_dataloader(\n batch_size=4,\n num_workers=4, # 4 workers will split shards among themselves\n collate_fn=default_collate_fn\n )\n \n logger.info(\"DataLoader with 4 workers created\")\n logger.info(\"Each worker will process a subset of shards\")\n logger.info(\"Loading first batch...\")\n \n # Load one batch to verify it works\n batch = next(iter(dataloader))\n logger.info(f\"Successfully loaded batch with {len(batch['text'])} samples\")\n logger.info(\"\")\n\n\ndef main():\n \"\"\"Run all examples.\"\"\"\n logger.info(\"\\n\" + \"=\"*80)\n logger.info(\"WebDataset Conversion and Loading Examples\")\n logger.info(\"=\"*80 + \"\\n\")\n \n try:\n # Example 1: Convert dataset\n example_1_basic_conversion()\n \n # Example 2: Validate checksums\n example_2_validate_checksums()\n \n # Example 3: Basic loading\n example_3_basic_loading()\n \n # Example 4: With filtering\n example_4_with_filtering()\n \n # Example 5: With transformation\n example_5_with_transformation()\n \n # Example 6: PyTorch DataLoader\n example_6_pytorch_dataloader()\n \n # Example 7: Distributed training\n example_7_distributed_training()\n \n logger.info(\"=\"*80)\n logger.info(\"All examples completed successfully!\")\n logger.info(\"=\"*80)\n \n except Exception as e:\n logger.error(f\"Error running examples: {e}\", exc_info=True)\n return 1\n \n return 0\n\n\nif __name__ == \"__main__\":\n import sys\n sys.exit(main())\n'''\n\n# Write the example script\n(project_root / 'example_usage.py').write_text(example_script)\nprint(f\"\\nβœ“ Created {project_root / 'example_usage.py'}\")\nprint(f\" Size: {len(example_script)} bytes\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "m1o9mhn826q", + "source": "# 5. README.md - Complete documentation\nreadme_content = '''# WebDataset Conversion for FineWeb-Edu\n\nComplete toolkit for converting HuggingFaceFW/fineweb-edu dataset to WebDataset format with checksum validation and streaming support.\n\n## Overview\n\nThis project provides production-ready tools for:\n- Converting HuggingFace datasets to WebDataset tar format\n- Validating data integrity with SHA256 checksums\n- Streaming large datasets efficiently\n- PyTorch DataLoader integration\n- Distributed training support\n\n## Features\n\n### Conversion (`convert_to_webdataset.py`)\n- βœ… Streaming mode for memory-efficient processing\n- βœ… Configurable shard sizes (~500MB default)\n- βœ… SHA256 checksum generation per shard\n- βœ… Comprehensive metadata tracking\n- βœ… Progress bars and detailed logging\n- βœ… Support for all fineweb-edu configurations\n\n### Loading (`webdataset_loader.py`)\n- βœ… Checksum validation before loading\n- βœ… PyTorch `IterableDataset` interface\n- βœ… Automatic worker-based shard distribution\n- βœ… Optional shuffling with configurable buffer\n- βœ… Sample filtering and transformation\n- βœ… Compatible with PyTorch DataLoader\n\n## Installation\n\n### Basic Installation\n\n```bash\npip install -r webdataset_requirements.txt\n```\n\n### Using uv (Recommended)\n\n```bash\n# If you have uv installed\nuv pip install -r webdataset_requirements.txt\n```\n\n### Dependencies\n\n- `datasets>=2.14.0` - HuggingFace datasets library\n- `webdataset>=0.2.48` - WebDataset format support\n- `torch>=2.0.0` - PyTorch for DataLoader\n- `tqdm>=4.65.0` - Progress bars\n- `numpy>=1.24.0` - Numerical operations\n\n## Quick Start\n\n### 1. Convert Dataset\n\nConvert a small sample for testing:\n\n```bash\npython convert_to_webdataset.py \\\\\n --config sample-10BT \\\\\n --output-dir ./webdataset_output \\\\\n --shard-size 500 \\\\\n --max-samples 10000\n```\n\nConvert the full dataset:\n\n```bash\npython convert_to_webdataset.py \\\\\n --config sample-350BT \\\\\n --output-dir ./webdataset_full \\\\\n --shard-size 500\n```\n\n### 2. Validate Checksums\n\n```bash\npython webdataset_loader.py ./webdataset_output --validate-only\n```\n\n### 3. Load and Use Data\n\n```python\nfrom webdataset_loader import WebDatasetLoader\n\n# Create loader\nloader = WebDatasetLoader(\n data_dir=\"./webdataset_output\",\n validate_checksums=True,\n shuffle=True,\n buffer_size=1000\n)\n\n# Iterate over samples\nfor sample in loader:\n text = sample['text']\n metadata = sample['id'], sample['url'], sample['score']\n # ... process sample\n```\n\n### 4. Use with PyTorch DataLoader\n\n```python\nfrom webdataset_loader import WebDatasetLoader, default_collate_fn\n\nloader = WebDatasetLoader(\n data_dir=\"./webdataset_output\",\n validate_checksums=True,\n shuffle=True\n)\n\ndataloader = loader.get_dataloader(\n batch_size=32,\n num_workers=4,\n collate_fn=default_collate_fn\n)\n\nfor batch in dataloader:\n texts = batch['text'] # List of strings\n scores = batch['score'] # List of floats\n # ... train your model\n```\n\n## Detailed Usage\n\n### Conversion Script\n\n#### Command-Line Arguments\n\n```bash\npython convert_to_webdataset.py [OPTIONS]\n\nOptions:\n --dataset TEXT HuggingFace dataset name\n [default: HuggingFaceFW/fineweb-edu]\n \n --config TEXT Dataset configuration\n Options: sample-10BT, sample-100BT, sample-350BT\n [default: None]\n \n --split TEXT Dataset split to convert\n [default: train]\n \n --output-dir TEXT Output directory for shards\n [default: ./webdataset_output]\n \n --shard-size INT Target shard size in MB\n [default: 500]\n \n --max-samples INT Maximum samples to convert (for testing)\n [default: None (all samples)]\n \n --no-streaming Disable streaming mode\n [default: streaming enabled]\n```\n\n#### Python API\n\n```python\nfrom convert_to_webdataset import WebDatasetConverter\n\nconverter = WebDatasetConverter(\n dataset_name=\"HuggingFaceFW/fineweb-edu\",\n config_name=\"sample-10BT\",\n split=\"train\",\n output_dir=\"./my_dataset\",\n shard_size_mb=500,\n max_samples=None, # Convert all samples\n streaming=True\n)\n\nconverter.convert()\n```\n\n#### Output Structure\n\n```\nwebdataset_output/\nβ”œβ”€β”€ fineweb_edu_000000.tar # Shard 0 (~500MB)\nβ”œβ”€β”€ fineweb_edu_000001.tar # Shard 1 (~500MB)\nβ”œβ”€β”€ ...\nβ”œβ”€β”€ checksums.json # SHA256 checksums\n└── dataset_metadata.json # Dataset info\n```\n\n#### Sample Format in Tar Files\n\nEach sample consists of two files:\n- `sample_000000000000.txt` - Plain text content\n- `sample_000000000000.json` - Metadata with fields:\n - `id`: Document ID\n - `url`: Source URL\n - `dump`: Dump identifier\n - `score`: Quality score\n - `token_count`: Number of tokens\n - `language`: Language code\n - `language_score`: Language detection confidence\n - `sample_id`: WebDataset sample ID\n - `sample_index`: Index in original dataset\n\n### Loading Script\n\n#### Command-Line Usage\n\n```bash\n# Validate checksums only\npython webdataset_loader.py ./webdataset_output --validate-only\n\n# Load and display samples\npython webdataset_loader.py ./webdataset_output --num-samples 10\n\n# Skip validation (faster, but risky)\npython webdataset_loader.py ./webdataset_output --no-validate\n```\n\n#### Python API - Basic Usage\n\n```python\nfrom webdataset_loader import WebDatasetLoader\n\nloader = WebDatasetLoader(\n data_dir=\"./webdataset_output\",\n validate_checksums=True, # Validate before loading\n shuffle=False, # Don't shuffle\n buffer_size=1000, # Buffer size for shuffling\n transform=None, # No transformation\n filter_fn=None, # No filtering\n shard_pattern=\"*.tar\" # Glob pattern for shards\n)\n\n# Iterate over samples\nfor sample in loader:\n print(sample['text'])\n print(sample['score'])\n```\n\n#### Python API - With Filtering\n\n```python\ndef high_quality_filter(sample):\n \"\"\"Only keep high-quality documents.\"\"\"\n return sample.get('score', 0) >= 3.0\n\nloader = WebDatasetLoader(\n data_dir=\"./webdataset_output\",\n validate_checksums=True,\n filter_fn=high_quality_filter\n)\n\nfor sample in loader:\n # All samples have score >= 3.0\n process(sample)\n```\n\n#### Python API - With Transformation\n\n```python\ndef add_features(sample):\n \"\"\"Add computed features.\"\"\"\n text = sample['text']\n sample['word_count'] = len(text.split())\n sample['char_count'] = len(text)\n return sample\n\nloader = WebDatasetLoader(\n data_dir=\"./webdataset_output\",\n validate_checksums=True,\n transform=add_features\n)\n\nfor sample in loader:\n print(f\"Words: {sample['word_count']}\")\n```\n\n#### Python API - PyTorch DataLoader\n\n```python\nfrom webdataset_loader import WebDatasetLoader, default_collate_fn\nimport torch\n\nloader = WebDatasetLoader(\n data_dir=\"./webdataset_output\",\n validate_checksums=True,\n shuffle=True,\n buffer_size=10000\n)\n\n# Create DataLoader\ndataloader = loader.get_dataloader(\n batch_size=32,\n num_workers=4,\n pin_memory=True,\n collate_fn=default_collate_fn\n)\n\n# Training loop\nfor epoch in range(10):\n for batch in dataloader:\n texts = batch['text'] # List of 32 strings\n scores = batch['score'] # List of 32 floats\n \n # Your training code here\n loss = model(texts, scores)\n loss.backward()\n optimizer.step()\n```\n\n#### Distributed Training\n\nThe loader automatically handles worker-based shard distribution:\n\n```python\n# Each worker gets a subset of shards\ndataloader = loader.get_dataloader(\n batch_size=32,\n num_workers=8, # 8 workers split shards among themselves\n pin_memory=True\n)\n\n# No additional code needed - sharding is automatic!\n```\n\n### Example Usage Script\n\nRun all examples:\n\n```bash\npython example_usage.py\n```\n\nThis demonstrates:\n1. Basic conversion\n2. Checksum validation\n3. Basic loading\n4. Loading with filtering\n5. Loading with transformation\n6. PyTorch DataLoader integration\n7. Distributed training simulation\n\n## Advanced Usage\n\n### Custom Collate Function\n\nCreate a custom collate function for batching:\n\n```python\nimport torch\n\ndef custom_collate_fn(batch):\n \"\"\"Custom batching with tokenization.\"\"\"\n from transformers import AutoTokenizer\n \n tokenizer = AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n \n # Extract texts\n texts = [sample['text'] for sample in batch]\n \n # Tokenize\n encoded = tokenizer(\n texts,\n padding=True,\n truncation=True,\n max_length=512,\n return_tensors='pt'\n )\n \n return {\n 'input_ids': encoded['input_ids'],\n 'attention_mask': encoded['attention_mask'],\n 'scores': torch.tensor([s['score'] for s in batch])\n }\n\ndataloader = loader.get_dataloader(\n batch_size=32,\n num_workers=4,\n collate_fn=custom_collate_fn\n)\n```\n\n### Multi-GPU Training\n\n```python\nimport torch\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel as DDP\n\n# Initialize distributed training\ndist.init_process_group(\"nccl\")\nrank = dist.get_rank()\nworld_size = dist.get_world_size()\n\n# Create loader (same on all processes)\nloader = WebDatasetLoader(\n data_dir=\"./webdataset_output\",\n validate_checksums=True,\n shuffle=True\n)\n\n# Create DataLoader with appropriate workers\ndataloader = loader.get_dataloader(\n batch_size=32,\n num_workers=4\n)\n\n# Wrap model with DDP\nmodel = DDP(model, device_ids=[rank])\n\n# Training loop (each GPU processes different shards)\nfor batch in dataloader:\n # ... training code\n```\n\n### Checksum Validation Utilities\n\n```python\nfrom webdataset_loader import ChecksumValidator, verify_checksums\n\n# Method 1: Simple validation\nvalid = verify_checksums(\"./webdataset_output\")\nprint(f\"Checksums valid: {valid}\")\n\n# Method 2: Detailed validation\nfrom pathlib import Path\n\nvalidator = ChecksumValidator(\n Path(\"./webdataset_output/checksums.json\")\n)\n\n# Validate specific shard\nshard_path = Path(\"./webdataset_output/fineweb_edu_000000.tar\")\nis_valid = validator.validate_shard(shard_path)\n\n# Validate all shards\nall_valid = validator.validate_all_shards(\n Path(\"./webdataset_output\")\n)\n```\n\n## Configuration Examples\n\n### Small Test Dataset\n\n```bash\npython convert_to_webdataset.py \\\\\n --config sample-10BT \\\\\n --output-dir ./test_dataset \\\\\n --shard-size 50 \\\\\n --max-samples 1000\n```\n\nOutput: ~1000 samples in small shards for quick testing\n\n### Medium Dataset\n\n```bash\npython convert_to_webdataset.py \\\\\n --config sample-100BT \\\\\n --output-dir ./medium_dataset \\\\\n --shard-size 500\n```\n\nOutput: ~100B tokens in 500MB shards\n\n### Full Dataset\n\n```bash\npython convert_to_webdataset.py \\\\\n --config sample-350BT \\\\\n --output-dir ./full_dataset \\\\\n --shard-size 500\n```\n\nOutput: ~350B tokens in 500MB shards\n\n### Custom Dataset\n\n```python\nfrom convert_to_webdataset import WebDatasetConverter\n\n# Convert any HuggingFace dataset\nconverter = WebDatasetConverter(\n dataset_name=\"your-org/your-dataset\",\n config_name=\"your-config\",\n split=\"train\",\n output_dir=\"./custom_dataset\",\n shard_size_mb=500,\n streaming=True\n)\n\nconverter.convert()\n```\n\n## Performance Tips\n\n### Conversion Performance\n\n1. **Use streaming mode** (default) for large datasets\n2. **Adjust shard size** based on your storage:\n - Smaller shards (100MB): More files, faster per-shard processing\n - Larger shards (1GB): Fewer files, better for slow filesystems\n3. **Set max_samples** for testing before full conversion\n\n### Loading Performance\n\n1. **Use multiple workers**: `num_workers=4-8` for DataLoader\n2. **Enable pin_memory**: `pin_memory=True` for GPU training\n3. **Tune buffer_size**: Larger = better shuffling, more memory\n4. **Skip validation** after first check: `validate_checksums=False`\n\n### Memory Usage\n\n- Streaming mode: O(1) memory during conversion\n- Loading: O(buffer_size) for shuffling\n- Workers: Each worker loads one shard at a time\n\n## Troubleshooting\n\n### Issue: Checksum validation fails\n\n**Cause**: Corrupted shard or interrupted download\n\n**Solution**:\n```bash\n# Re-validate to identify corrupt shards\npython webdataset_loader.py ./webdataset_output --validate-only\n\n# Re-convert if needed\npython convert_to_webdataset.py --config sample-10BT --output-dir ./webdataset_output\n```\n\n### Issue: Out of memory during conversion\n\n**Cause**: Not using streaming mode\n\n**Solution**:\n```bash\n# Ensure streaming is enabled (default)\npython convert_to_webdataset.py --config sample-10BT\n```\n\n### Issue: Slow data loading\n\n**Cause**: Not using enough workers\n\n**Solution**:\n```python\ndataloader = loader.get_dataloader(\n batch_size=32,\n num_workers=8, # Increase workers\n pin_memory=True\n)\n```\n\n### Issue: Workers getting same data\n\n**Cause**: Not using `IterableDataset` correctly\n\n**Solution**: The WebDatasetLoader automatically handles worker sharding. Make sure you're using PyTorch >= 2.0.\n\n### Issue: Shards not found\n\n**Cause**: Wrong directory or glob pattern\n\n**Solution**:\n```python\n# Check the directory\nimport os\nprint(os.listdir(\"./webdataset_output\"))\n\n# Adjust shard_pattern if needed\nloader = WebDatasetLoader(\n data_dir=\"./webdataset_output\",\n shard_pattern=\"fineweb_edu_*.tar\" # More specific pattern\n)\n```\n\n## File Structure\n\n```\n.\nβ”œβ”€β”€ convert_to_webdataset.py # Conversion script\nβ”œβ”€β”€ webdataset_loader.py # Loading script\nβ”œβ”€β”€ example_usage.py # Usage examples\nβ”œβ”€β”€ webdataset_requirements.txt # Dependencies\n└── README.md # This file\n\n# After conversion:\nwebdataset_output/\nβ”œβ”€β”€ fineweb_edu_000000.tar # Shard 0\nβ”œβ”€β”€ fineweb_edu_000001.tar # Shard 1\nβ”œβ”€β”€ ...\nβ”œβ”€β”€ checksums.json # Checksums\n└── dataset_metadata.json # Metadata\n```\n\n## Dataset Information\n\n### HuggingFaceFW/fineweb-edu\n\nFineWeb-Edu is a high-quality educational subset of the FineWeb dataset:\n- **Size**: Up to 1.3T tokens (full version)\n- **Quality**: Filtered for educational content\n- **Language**: Primarily English\n- **Source**: Common Crawl\n- **License**: ODC-By 1.0\n\n### Configurations\n\n- `sample-10BT`: 10B token sample (~10M documents)\n- `sample-100BT`: 100B token sample (~100M documents)\n- `sample-350BT`: 350B token sample (~350M documents)\n- Full dataset: 1.3T tokens\n\n## API Reference\n\n### `WebDatasetConverter`\n\nMain class for converting HuggingFace datasets to WebDataset format.\n\n```python\nclass WebDatasetConverter:\n def __init__(\n self,\n dataset_name: str = \"HuggingFaceFW/fineweb-edu\",\n config_name: Optional[str] = None,\n split: str = \"train\",\n output_dir: str = \"./webdataset_output\",\n shard_size_mb: int = 500,\n max_samples: Optional[int] = None,\n streaming: bool = True\n )\n \n def convert(self) -> None:\n \"\"\"Run the conversion.\"\"\"\n \n def compute_sha256(self, filepath: Path) -> str:\n \"\"\"Compute SHA256 checksum.\"\"\"\n```\n\n### `WebDatasetLoader`\n\nMain class for loading WebDataset with validation.\n\n```python\nclass WebDatasetLoader(IterableDataset):\n def __init__(\n self,\n data_dir: str,\n validate_checksums: bool = True,\n shuffle: bool = False,\n buffer_size: int = 1000,\n transform: Optional[Callable] = None,\n filter_fn: Optional[Callable] = None,\n shard_pattern: str = \"*.tar\"\n )\n \n def __iter__(self) -> Iterator[Dict]:\n \"\"\"Iterate over samples.\"\"\"\n \n def get_dataloader(\n self,\n batch_size: int = 32,\n num_workers: int = 4,\n pin_memory: bool = True,\n collate_fn: Optional[Callable] = None\n ) -> DataLoader:\n \"\"\"Create PyTorch DataLoader.\"\"\"\n```\n\n### Utility Functions\n\n```python\ndef verify_checksums(data_dir: str) -> bool:\n \"\"\"Verify all checksums in directory.\"\"\"\n\ndef default_collate_fn(batch: List[Dict]) -> Dict:\n \"\"\"Default batching function.\"\"\"\n```\n\n## Contributing\n\nContributions are welcome! Please:\n1. Fork the repository\n2. Create a feature branch\n3. Make your changes\n4. Add tests if applicable\n5. Submit a pull request\n\n## License\n\nThis code is provided as-is for working with the FineWeb-Edu dataset.\n\nFineWeb-Edu dataset license: ODC-By 1.0\n\n## Acknowledgments\n\n- HuggingFace for the `datasets` library and FineWeb-Edu dataset\n- WebDataset project for the tar-based format\n- PyTorch team for excellent data loading primitives\n\n## References\n\n- [FineWeb-Edu Dataset](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu)\n- [WebDataset Documentation](https://webdataset.github.io/webdataset/)\n- [PyTorch DataLoader](https://pytorch.org/docs/stable/data.html)\n- [HuggingFace Datasets](https://huggingface.co/docs/datasets/)\n\n## Support\n\nFor issues or questions:\n1. Check this README\n2. Run `example_usage.py` for working examples\n3. Validate your setup with small samples first\n4. Check PyTorch and HuggingFace documentation\n\nHappy dataset processing!\n'''\n\n# Write the README\n(project_root / 'WEBDATASET_README.md').write_text(readme_content)\nprint(f\"\\nβœ“ Created {project_root / 'WEBDATASET_README.md'}\")\nprint(f\" Size: {len(readme_content)} bytes\")\n\nprint(\"\\n\" + \"=\"*80)\nprint(\"ALL FILES CREATED SUCCESSFULLY!\")\nprint(\"=\"*80)\nprint(f\"\\nCreated files in {project_root}:\")\nprint(\" 1. convert_to_webdataset.py - Main conversion script\")\nprint(\" 2. webdataset_loader.py - Streaming loader with validation\")\nprint(\" 3. webdataset_requirements.txt - Dependencies\")\nprint(\" 4. example_usage.py - Complete usage examples\")\nprint(\" 5. WEBDATASET_README.md - Comprehensive documentation\")\nprint(\"\\nNext steps:\")\nprint(\" 1. Install dependencies: pip install -r webdataset_requirements.txt\")\nprint(\" 2. Test conversion: python convert_to_webdataset.py --config sample-10BT --max-samples 100\")\nprint(\" 3. Validate: python webdataset_loader.py ./webdataset_output --validate-only\")\nprint(\" 4. Run examples: python example_usage.py\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "oj722n8665e", + "source": "# Execute all the file creation code above\n# This cell runs all the code to actually create the files\n\nprint(\"Starting file creation process...\")\nprint(\"=\"*80)\n\n# The cells above already contain the code to write the files\n# They should have executed when the notebook cells were added\n\n# Let's verify the files exist\nimport os\nfrom pathlib import Path\n\nproject_root = Path('/Users/akseljoonas/Documents/hf-agent')\n\nfiles_to_check = [\n 'convert_to_webdataset.py',\n 'webdataset_loader.py',\n 'webdataset_requirements.txt',\n 'example_usage.py',\n 'WEBDATASET_README.md'\n]\n\nprint(\"\\nChecking created files:\")\nfor filename in files_to_check:\n filepath = project_root / filename\n if filepath.exists():\n size = filepath.stat().st_size\n print(f\" βœ“ {filename} ({size:,} bytes)\")\n else:\n print(f\" βœ— {filename} - NOT FOUND\")\n\nprint(\"\\n\" + \"=\"*80)\nprint(\"File creation verification complete!\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/eval/generate_rubrics.py b/eval/generate_rubrics.py new file mode 100644 index 0000000000000000000000000000000000000000..081d2db1be3a775b3a3d5de96328296094f5de8c --- /dev/null +++ b/eval/generate_rubrics.py @@ -0,0 +1,403 @@ +#!/usr/bin/env env python3 +""" +Rubric Generation Script for HF-Agent Benchmark + +Generates instance-specific evaluation rubrics following the "Rubrics as Rewards" paper. +Uses LiteLLM to call LLM models for rubric synthesis with expert grounding via reference answers. +""" + +import argparse +import json +import os +import sys +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path +from typing import Any, Dict, List + +import litellm +import pandas as pd +from dotenv import load_dotenv +from pydantic import BaseModel + +from eval.hf_io import df_to_hub + + +class Rubric(BaseModel): + title: str + description: str + weight: int + + +class RubricList(BaseModel): + rubrics: List[Rubric] + + +# Load environment variables +load_dotenv() + +# Rubric generation prompt template based on RaR paper + + +PROMPT_TEMPLATE = """You are an expert rubric writer. Your job is to generate a self-contained set of evaluation criteria ("rubrics") for judging how good, helpful and complete an agent's trajectory is to a given user question/request. + +Rubrics can cover aspects of a response such as, but not limited to, factual correctness, helpfulness, completeness, harmlessness, correctness of using Hugging Face best practices (based on HF documentation), depth of +reasoning, contextual relevance and usefulness. Each item must be self-contained – non expert readers should not need to +infer anything or consult external information. Begin each description with its category: "Essential Criteria: . . . ", "Important +Criteria: . . . ", "Optional Criteria: . . . ", or "Pitfall Criteria: Does not mention . . . ". + + +Inputs: +- question: <<<{question}>>> +- example_solution (NOT ground truth - just an okay attempt): <<<{example_solution}>>> +- example_trace (NOT ground truth - just an okay attempt showing what tool usage might look like): <<<{example_trace}>>> + +IMPORTANT: The example_solution and example_trace provided are NOT ground truth or ideal solutions. They represent +an attempt at solving the task - they give you a general idea of the shape of the problem and what tool usage +might look like, but they contain mistakes and incomplete solutions, suboptimal approaches, or incomplete answers. Your rubrics MUST be designed to fairly grade a PERFECT solution. The perfect solution is complete in all aspects of solving the task and verifing it's correctness before giving the final answer. It tells the user what was done and why, and provides the final answer clearly answering the user's question. + +Total items: +β€’ Choose 7–20 rubric items based on the complexity of the question. + +Each rubric item: +β€’ title (2–4 words). +β€’ description: One sentence starting with its category prefix that explicitly states exactly what to look for. For example: +– Essential Criteria: Writes a up-to-date, correct, complete and working training loop using the latest Hugging Face best practices. Launches the training with hf-jobs. +– Pitfall Criteria: Deprecated launcher usage. Uses python -m torch.distributed.launch instead of torchrun / accelerate. +– Important Criteria: Explains common DDP knobs. Mentions ddp_find_unused_parameters=False for models with conditional branches; optional ddp_timeout; brief note on when they matter and why. +– Optional Criteria: Briefly notes --deepspeed ds_config.json as an alternative scaler when models get big (but stays on DDP for this Q). +β€’ weight: For Essential/Important/Optional, use 1–5 (5 = most important); for Pitfall, use –1 or –2. + +Category guidance: +β€’ Essential: Critical actions to answer/complete the user's question/request; if missing, the response is invalid and useless (weight 5). +β€’ Important: Key reasoning, completeness, or clarity; strongly affects quality and usefulness (weight 3–4). +β€’ Optional: Helpfulness in educating the user or providing extra depth; nice to have but not deal-breaking (weight 1–2). +β€’ Pitfall: Common mistakes or omissions specific to this promptβ€”identify things a respondent often forgets or misstates. +Each Pitfall description must begin with "Pitfall Criteria: Does not mention . . . " or "Pitfall Criteria: Recommends . . . " +and use weight –1 or –2. + +To ensure self-contained guidance: +β€’ When referring to answer choices, explicitly say "Identifies (A)", "Identifies (B)", etc., rather than vague phrasing. +β€’ If the format requires an action like calling a tool or launching a training run, include a rubric item such as: +– Essential Criteria: Includes a clear statement "Launches the training with hf-jobs.". +β€’ If reasoning should precede the answer, include a rubric like: +– Important Criteria: Presents the explanation and reasoning before stating the final answer. +β€’ If brevity is valued, include a rubric like: +– Optional Criteria: Remains concise and avoids unnecessary detail. +β€’ If the question context demands mention of specific findings/best practices, include that explicitly (e.g., "Essential Criteria: Mentions +that training data must be in "messages" column for LLM training"). + +Output: Provide a JSON array of rubric objects. Each object must contain exactly three keysβ€”title, description, and weight. +Do not copy large blocks of the question or example_solution into the text. Each description must begin with its category +prefix, and no extra keys are allowed. + +Remember: The example_solution and example_trace are NOT ideal answers - they are just rough attempts to show the +general approach. Design rubrics that can fairly evaluate any solution, including ones that are better than the example.""" + + +def build_prompt( + question: str, + example_solution: str, + example_trace: List[Dict[str, Any]], +) -> List[Dict[str, str]]: + """ + Build the messages list for LiteLLM completion. + + Args: + question: The question/task to evaluate + difficulty: The difficulty level of the task + example_solution: An example solution attempt (not ground truth) + example_trace: The agent's message trace showing tool usage + + Returns: + List of message dicts for LiteLLM + """ + # Format the trace for readability - only include key parts + formatted_trace = format_trace_for_prompt(example_trace) + + prompt = PROMPT_TEMPLATE.format( + question=question, + example_solution=example_solution, + example_trace=formatted_trace, + ) + + return [{"role": "user", "content": prompt}] + + +def format_trace_for_prompt(messages: List[Dict[str, Any]]) -> str: + """ + Format the agent message trace for inclusion in the prompt. + Extracts key information while keeping it readable. + """ + if not messages: + return "(No trace available)" + + formatted_parts = [] + for msg in messages: + role = msg.get("role", "unknown") + content = msg.get("content", "") + + # Skip system messages + if role == "system": + continue + + # Handle tool calls + if "tool_calls" in msg and msg["tool_calls"]: + tool_info = [] + for tc in msg["tool_calls"]: + if isinstance(tc, dict) and "function" in tc: + func = tc["function"] + tool_name = func.get("name", "unknown_tool") + tool_info.append(f" - Called: {tool_name}") + if tool_info: + formatted_parts.append( + "[Assistant Tool Calls]\n" + "\n".join(tool_info) + ) + + # Handle regular content + if content: + # Truncate very long content + if len(content) > 500: + content = content[:500] + "... (truncated)" + formatted_parts.append(f"[{role.title()}]\n{content}") + + return "\n\n".join(formatted_parts) if formatted_parts else "(Empty trace)" + + +def validate_rubric(rubric_list: List[Dict[str, Any]]) -> bool: + """ + Validate that rubric meets basic requirements. + + Args: + rubric_list: List of rubric items to validate + + Returns: + True if valid, False otherwise + """ + # Check count + if not (7 <= len(rubric_list) <= 20): + return False + + # Check each item + category_prefixes = [ + "Essential Criteria:", + "Important Criteria:", + "Optional Criteria:", + "Pitfall Criteria:", + ] + + for item in rubric_list: + # Check keys + if set(item.keys()) != {"title", "description", "weight"}: + return False + + # Check description starts with category prefix + if not any( + item["description"].startswith(prefix) for prefix in category_prefixes + ): + return False + + return True + + +def generate_rubric(row: pd.Series, model: str, timeout: int = 120) -> Dict[str, Any]: + """ + Generate rubric for a single question using LiteLLM. + + Args: + row: DataFrame row containing question, difficulty, solution, and messages + model: Model name for LiteLLM + timeout: Request timeout in seconds + + Returns: + Dict with rubric_list and rubric_count, or None on failure + """ + + messages = build_prompt( + question=row["question"], + example_solution=row["solution"], + example_trace=row.get("messages", []), + ) + + try: + response = litellm.completion( + model=model, + messages=messages, + timeout=timeout, + response_format=RubricList, + ) + + # Parse structured output + rubric_list: RubricList = RubricList.model_validate_json( + response.choices[0].message.content + ) + + return rubric_list.model_dump_json() + except Exception as e: + print(f"Error generating rubric: {e}", file=sys.stderr) + return None + + +def load_input_data(infile: str) -> pd.DataFrame: + """ + Load input data from CSV or JSONL file. + + Args: + infile: Path to input file + + Returns: + DataFrame with loaded data + """ + path = Path(infile) + + if not path.exists(): + raise FileNotFoundError(f"Input file not found: {infile}") + + if path.suffix == ".csv": + # Try to auto-detect delimiter (comma or semicolon) + df = pd.read_csv(infile, sep=None, engine="python") + elif path.suffix == ".jsonl": + df = pd.read_json(infile, lines=True) + else: + raise ValueError(f"Unsupported file format: {path.suffix}. Use .csv or .jsonl") + + # Validate required columns + required_cols = [ + "question", + "solution", + ] + optional_cols = ["difficulty", "messages", "error"] + missing_cols = [col for col in required_cols if col not in df.columns] + + if missing_cols: + raise ValueError(f"Missing required columns: {missing_cols}") + + # Log available optional columns + available_optional = [col for col in optional_cols if col in df.columns] + print(f"Found optional columns: {available_optional}") + + return df + + +def main(): + parser = argparse.ArgumentParser( + description="Generate rubrics for HF-agent benchmark evaluation" + ) + parser.add_argument( + "--infile", type=str, required=True, help="Input file path (.csv or .jsonl)" + ) + parser.add_argument( + "--outfile", type=str, required=True, help="Output JSONL file path" + ) + parser.add_argument( + "--model", + type=str, + default="anthropic/claude-sonnet-4-5-20250929", + help="LiteLLM model name (default: from LITELLM_MODEL env or gpt-4o-mini)", + ) + parser.add_argument( + "--timeout", + type=int, + default=120, + help="Request timeout in seconds (default: 120)", + ) + parser.add_argument( + "--max-concurrent", + type=int, + default=30, + help="Maximum number of concurrent workers (default: 30)", + ) + parser.add_argument( + "--push-to-hub", + type=str, + default=None, + help="Push to HuggingFace dataset (e.g., username/dataset@rubrics)", + ) + + args = parser.parse_args() + + # Determine model + model = args.model or os.getenv("LITELLM_MODEL", "gpt-4o-mini") + print(f"Using model: {model}") + + # Load input data + print(f"Loading data from {args.infile}...") + df = load_input_data(args.infile) + print(f"Loaded {len(df)} examples") + + # Run rubric generation in parallel using ThreadPoolExecutor + print(f"Running generation with {args.max_concurrent} parallel workers...") + + with ThreadPoolExecutor(max_workers=args.max_concurrent) as executor: + # Submit all tasks + future_to_idx = {} + for idx, row in df.iterrows(): + future = executor.submit( + generate_rubric, + row=row, + model=model, + timeout=args.timeout, + ) + future_to_idx[future] = idx + + # Collect results in order + results = [None] * len(df) + completed = 0 + for future in as_completed(future_to_idx): + idx = future_to_idx[future] + results[idx] = future.result() + completed += 1 + print(f"Completed: {completed}/{len(df)}", end="\r") + + print() # New line after progress + + # Prepare results DataFrame + print("Preparing results...") + output_rows = [] + success_count = 0 + failure_count = 0 + + for idx, (_, row) in enumerate(df.iterrows()): + rubric_result = results[idx] + + if rubric_result is None: + failure_count += 1 + continue + + # Merge with original data + output_row = row.to_dict() + output_row["messages"] = json.dumps(output_row["messages"]) + output_row["rubric"] = rubric_result + output_rows.append(output_row) + success_count += 1 + + # Create DataFrame with results + results_df = pd.DataFrame(output_rows) + + # Upload to HuggingFace if specified (before saving JSONL) + if args.push_to_hub: + print(f"\nUploading to HuggingFace: {args.push_to_hub}") + upload_success = df_to_hub( + df=results_df, + dataset_spec=args.push_to_hub, + split="train", + private=False, + ) + if not upload_success: + print("Warning: HuggingFace push failed, but continuing to save JSONL...") + + # Write results to JSONL file + print(f"\nWriting results to {args.outfile}...") + with open(args.outfile, "w") as outf: + for output_row in output_rows: + outf.write(json.dumps(output_row, default=str) + "\n") + + print("\nComplete!") + print(f"Success: {success_count}/{len(df)}") + print(f"Failures: {failure_count}/{len(df)}") + print(f"Output written to: {args.outfile}") + if args.push_to_hub and upload_success: + print(f"Pushed to: {args.push_to_hub}") + + +if __name__ == "__main__": + main() diff --git a/eval/generated_tasks_with_difficulty.json b/eval/generated_tasks_with_difficulty.json new file mode 100644 index 0000000000000000000000000000000000000000..344347fe1e25398dc00a7f7c2ebb7f50e02660d5 --- /dev/null +++ b/eval/generated_tasks_with_difficulty.json @@ -0,0 +1,255 @@ +{ + "Evaluate models {M_i} on benchmarks {B_i}": "Easy", + "Train models {M_i} on datasets {D_i} with benchmarks {B_i}": "Medium", + "Run an ablation for hyperparameter P for model M on dataset D": "Hard", + "Generate completions with model M on dataset D using engine E": "Medium", + "Merge models {M_i} using linear averaging to find the best result on benchmarks {B_i}": "Hard", + "Given datasets {D_i}, ablate the best SFT mixture for model M across benchmarks {B_i}": "Very hard", + "Decontaminate dataset D against benchmarks {B_i}": "Hard", + "Benchmark RL framework F for best throughput on G GPUs": "Very hard", + "Implement post-training algorithm A from paper P in framework F. Validate it runs end-to-end": "Very hard", + "Implement benchmark B in framework F. Validate it reproduces some published results": "Very hard", + "Format dataset D for compatibility with framework F on task T": "Easy", + "Remove the background from this image: [image path]": "Easy", + "Transcribe all of the audio files in this directory": "Easy", + "Transcribe all of the audio files in this directory, choose the model that'll be cheapest and also relatively accurate": "Medium (judgment call or interaction needed to figure out what accuracy levels are acceptable)", + "Remove the background music from this audio file": "Medium (needs to find Gradio Space and call its API0", + "Change this video track to be from English to Spanish": "Medium (needs to link several models together)", + "Translate this flyer from English to Spanish, keeping the layout and images the same": "Medium (needs to link several models together)", + "What's the best model for X?": "Easy", + "What datasets are available for X? (X={domain x task x modality})": "Easy", + "Is there a space to do Y?": "Easy", + "I have this script and this error - what's the issue?": "Medium", + "This space is broken, how can i fix it?": "Medium", + "I built a space but it is super slow. What can I do?": "Medium", + "How can I run modal X locally?": "Medium", + "I want to build a space with model Y to do X?": "Hard", + "How can I serve a model with multiple LoRAs?": "Hard", + "What's the best model for sentiment analysis on financial text?": "Easy", + "Are there any medical image segmentation datasets on HuggingFace for CT scans?": "Easy", + "Which text classification models support 4-bit quantization?": "Medium", + "Are there inference endpoints available for Whisper large-v3?": "Easy", + "What's the license for the SA-Med2D-20M dataset?": "Easy", + "Which vision models fit in 8GB VRAM for image segmentation?": "Medium", + "What datasets are available for 3D medical image segmentation?": "Medium", + "Is there a space to do text-to-speech with emotion control?": "Medium", + "I'm getting \"CUDA out of memory\" when loading Llama-2-7b even though nvidia-smi shows I have 6GB free - what's the issue?": "Medium", + "My Gradio space shows \"Connection errored out\" after working fine yesterday, no code changes - how can I fix it?": "Medium", + "I built a Gradio space for Stable Diffusion but inference takes 5+ minutes on a 4090 - what can I do?": "Medium", + "My Whisper model outputs different transcriptions after quantization to int8 - why?": "Medium", + "Getting \"RuntimeError: CUDA error: out of memory. Tried to allocate 70.00 MiB\" but only 2.87 GiB is allocated - what's happening?": "Medium", + "My HuggingFace space build fails with \"failed to create containerd task\" - how to fix?": "Medium", + "DistilBERT model gives \"you should probably train your model\" warning even though it's a pretrained model from the Hub": "Easy", + "Space was working fine but now receiving build errors - receiving this error even with a new space": "Medium", + "Inference is correct locally but wrong on deployed space": "Medium", + "Getting CUDA OOM despite having enough memory according to nvidia-smi": "Medium", + "How can I run Mistral-7B-v0.1 locally with multiple LoRA adapters?": "Hard", + "How can I serve Llama-2-7b with vLLM and dynamically load multiple LoRA adapters?": "Hard", + "How do I batch inference requests in my Gradio space for better throughput?": "Medium", + "Can I run Whisper large-v3 with faster-whisper for 4x speedup?": "Medium", + "How to run Llama 2 on CPU after fine-tuning with LoRA?": "Medium", + "Best way to handle 50+ concurrent requests in a Gradio space without OOM?": "Hard", + "How do I add custom stopping criteria for text generation with Transformers?": "Hard", + "Can I merge multiple LoRA adapters before inference to reduce latency?": "Hard", + "How can I optimize my LLM inference with one base LLM and multiple LoRA adapters?": "Hard", + "Compare tokenizers {T_i} for model M on tasks {classification, QA}; report accuracy and average sequence length per task": "Medium", + "Run a LoRA rank sweep (r in {4, 8, 16, 32}) for model M on dataset D; plot validation perplexity vs VRAM usage and select Pareto-optimal settings": "Hard", + "Build a streaming dataloader from Parquet on S3 with deterministic shuffling across N workers; validate epoch reproducibility": "Very hard", + "Find three open-source TTS models with emotion control and list their sample rates and licenses": "Easy", + "Create a retrieval-augmented QA pipeline: index corpus C with FAISS, connect to model M, and benchmark top-1 accuracy and p95 latency": "Hard", + "Diagnose a Space where memory grows per request; add no-grad guards, free caches, and demonstrate stable RSS over 10,000 calls": "Hard", + "Deduplicate dataset D using MinHash LSH at Jaccard >= 0.9 and publish a cleaned HF dataset with provenance columns": "Medium", + "Add special tokens to tokenizer T and resize model M embeddings; resume pretraining for 10k steps without loss spikes": "Hard", + "Create a HuggingFace Dataset from CSV file data.csv and push to repo username/my_dataset": "Easy", + "Build a real-time Whisper transcription Space with VAD and chunked decoding; keep end-to-end latency under 200 ms": "Hard", + "Quantize model M to 4-bit (bnb.int4) with bitsandbytes; compare perplexity and p95 latency to 8-bit on dataset D; select config with <1% perplexity increase": "Medium", + "Fuse LoRA adapter A into base model M and export a single safetensors checkpoint; verify logits parity (<1e-5 MSE) vs on-the-fly LoRA": "Hard", + "Redact PII from dataset D using a transformer NER pipeline; produce a cleaned HuggingFace Dataset with per-entity removal stats and provenance": "Medium", + "Train a SentencePiece tokenizer (vocab=64k, byte fallback) on corpus C; compare tokenization speed, unknown-token rate, and bytes/token vs tokenizer T": "Hard", + "Build a sharded FAISS IVF-PQ index for 100M embeddings stored on S3; integrate with HF datasets streaming and report recall@10 and QPS": "Very hard", + "Fine-tune model M with QLoRA using TRL PPO on dataset D; log KL, reward, and throughput; validate no divergence on a held-out eval": "Hard", + "Resolve HfHubHTTPError 401 when pushing dataset repo R: diagnose token scopes, git-lfs config, and large file thresholds; document the fix": "Medium", + "Implement a custom Transformers LogitsProcessor that bans repeated bigrams; add unit tests and benchmark generation quality (BLEU) on dataset D": "Hard", + "List and download all Hub models tagged 'text-classification' with Apache-2.0 license and size <500MB; save model ids and downloads to CSV": "Easy", + "Enable speculative decoding in vLLM with draft model D for base model M; benchmark tokens/sec speedup at batch sizes {1,4,16} and max_new_tokens {64,256}": "Very hard", + "Profile model M under torch.compile modes {reduce-overhead, max-autotune} on GPU G; report tokens/sec, peak VRAM, and compile overhead": "Medium", + "Detect and remove near-duplicate images in dataset D using CLIP ViT-L/14 embeddings at cosine >= 0.95; publish a cleaned dataset with duplicate_group ids": "Medium", + "Convert a TensorFlow SavedModel of T5-base to Transformers PyTorch format; verify logits parity (MSE < 1e-4) on 1,000 random prompts": "Hard", + "Enable FlashAttention-2 in a Transformers training loop for model M; benchmark step time and confirm loss parity over 2,000 steps vs baseline": "Hard", + "Deploy vLLM for model M with hot-swappable LoRA adapters {A_i}; provide an API to switch adapters and demonstrate <200 ms switch latency under load": "Very hard", + "Implement a custom Trainer callback to log gradient norms, activation histograms, and learning rate; diagnose periodic loss spikes and propose a fix": "Hard", + "Build a bilingual RAG pipeline indexing corpora {en, es} with FAISS HNSW; evaluate exact match@1 on dataset D and report p95 latency": "Hard", + "Run a mixed-precision sweep (fp16 vs bf16) for model M on A100 and RTX 3090; compare convergence, throughput, and numerical stability issues": "Medium", + "Create a Gradio Space that batches Whisper-large-v3 transcription via queue + chunked decoding; maintain real-time factor <= 0.5 on a T4": "Hard", + "List five OCR datasets on the Hub with line-level annotations; include licenses and approximate image counts": "Easy", + "List models on the Hub tagged 'summarization' that offer safetensors weights and 4-bit quantization; output model ids": "Easy", + "Evaluate safety filters of models {M_i} on red-team prompt set R; report jailbreak rate and false positive rate": "Medium", + "Run a prompt template ablation for chat model M on dataset D; compare {alpaca, chatml, llama2} formats and report exact match and average output length": "Hard", + "Implement tensor parallelism for model M in framework F and show linear scaling across 2\u20138 GPUs with <=10% gap from ideal": "Very hard", + "Convert and shard dataset D into WebDataset tar files (~500MB/shard); build a streaming loader with checksum validation": "Medium", + "Deploy a Spaces app serving Stable Diffusion XL with ControlNet; add output caching and keep p95 latency <1s for 20 concurrent users": "Hard", + "Diagnose and fix 'shape mismatch' when loading LoRA into model M after tokenizer resize; provide minimal repro and patch": "Medium", + "Add a detailed model card to repo username/model_M with training data, intended use, limitations, and evaluation results": "Easy", + "Enable KV cache quantization (int8) in Transformers for model M; compare tokens/sec and ROUGE-L on dataset D vs fp16 cache": "Hard", + "Detect and redact license-incompatible samples in dataset D by matching SPDX identifiers and source domains; publish a compliance report": "Medium", + "Profile vLLM serving of model M with paged attention; tune block_size to maximize tokens/sec and report p50/p95 latency and peak VRAM": "Medium", + "Filter dataset D for toxic content using classifier C; log per-label removal rates and recreate stratified train/valid/test splits": "Medium", + "Train a unigram tokenizer (vocab=80k) on corpora {en, fr}; fine-tune T5-small and compare BLEU vs a BPE baseline; report tokenization speed and OOV rate": "Hard", + "Run distributed evaluation of models {M_i} on benchmark B across 4 GPUs with DeepSpeed-Inference; ensure identical metrics across 3 seeds": "Hard", + "Find three open-source ASR models that provide word-level timestamps; record licenses and expected WER on LibriSpeech": "Easy", + "Diagnose intermittent 'Address already in use' crashes in a FastAPI Space; add graceful shutdown and port probing, verifying stability over 1,000 restart cycles": "Medium", + "Export a LoRA-finetuned Llama checkpoint to GGUF for llama.cpp; validate perplexity parity (<=1% drift) on WikiText-2": "Hard", + "Construct a streaming RAG pipeline over S3-stored corpus C with Chroma; index ~1B tokens, implement shard rebalancing, and benchmark recall@5 and QPS": "Very hard", + "List Hub datasets tagged 'speech-emotion-recognition' with CC-BY or CC-BY-SA licenses and >=10k utterances; write dataset ids and sizes to JSON": "Easy", + "Train a summarization reward model via pairwise ranking on dataset D; apply DPO to model M and report ROUGE-L and human win rate": "Hard", + "Find four open-source OCR models that output line- or paragraph-level text and provide ONNX or TensorRT exports; list their licenses and maximum input resolutions": "Easy", + "Verify tokenizer special tokens for model M are preserved after adding new tokens; write a unit test that asserts CLS/SEP/PAD ids are unchanged before and after resize": "Medium", + "Implement a constrained decoder for model M that enforces a JSON schema via a custom Transformers LogitsProcessor; add unit tests and benchmark latency on dataset D": "Hard", + "Build a multilingual RAG index for 50M documents using mDPR with sharded storage on S3; support hot index reloads and report recall@10 and p95 latency at 100 QPS": "Very hard", + "Quantize T5-base to 8-bit with bitsandbytes (LLM.int8) and compare ROUGE-L and tokens/sec to fp16 on CNN/DailyMail; keep ROUGE-L drop <=1%": "Medium", + "Diagnose VRAM growth in a vLLM server at batch size 32; add profiling, fix cache eviction behavior, and demonstrate flat memory over 10,000 requests": "Hard", + "Convert a HuggingFace TokenizerFast to a SentencePiece model; verify >=99.9% token-level agreement on 10,000 sentences and measure tokenization speed delta": "Medium", + "Train a multi-task adapter stack for {summarization, QA, NLI} on model M; implement routing by prompt prefix and report per-task metrics and cross-task interference": "Very hard", + "Assess license compatibility between model M (Apache-2.0) and dataset D (CC-BY-SA); produce a one-paragraph verdict with rationale and reference links": "Easy", + "Enable FSDP with activation checkpointing for a 13B model across 2\u00d7A100 GPUs; achieve <=10% throughput loss vs baseline and verify loss parity over 1,000 steps": "Hard", + "List three datasets for code summarization with permissive licenses; output their dataset ids and license names": "Easy", + "Set up nightly continuous evaluation of model M on benchmarks {B_i}; log metrics to Weights & Biases and alert on >2% regression vs last 7-day rolling mean": "Medium", + "Implement streaming text generation in a Gradio Space for model M using server-sent events; cap median token emission delay at <50 ms": "Hard", + "Scale out training of a 7B model with FSDP + ZeRO across 8 GPUs; demonstrate checkpoint save/restore and achieve throughput within 15% of ideal linear scaling": "Very hard", + "Export a mixture-of-experts PyTorch model to ONNX and run with TensorRT; verify top-1 accuracy within 0.5% of PyTorch on dataset D": "Medium", + "Identify whether model M supports FlashAttention-2 from its config or source; provide supporting repo links and a yes/no compatibility flag": "Easy", + "Build an audio deduplication pipeline for dataset D using embedding model E with cosine similarity >= 0.98; publish grouped duplicate ids and a cleaned manifest": "Hard", + "Diagnose slow tokenization in a Transformers pipeline; profile, switch to a fast tokenizer, and demonstrate 2\u00d7 end-to-end speedup on 1M lines": "Medium", + "Implement a contrastive preference learning loss in TRL; train model M on dataset D and compare KL, reward variance, and human win rate vs a PPO baseline": "Hard", + "Build an elastic RAG service with Ray that autoscales FAISS shards on S3, supports live corpus updates, and maintains p95 latency <500 ms at 200 QPS": "Very hard", + "List five chat-optimized LLMs on the Hub that include a tokenizer chat_template and safetensors weights; output model ids": "Easy", + "Find three biomedical NER datasets with Apache-2.0 or MIT licenses; return dataset ids and license names": "Easy", + "Create a dataset viewer Space that streams Parquet shards from the Hub using datasets streaming; implement server-side filtering and pagination": "Medium", + "Enable gradient checkpointing and optimizer state offloading for model M with Accelerate; report step time and peak VRAM vs baseline on a single A100": "Medium", + "Diagnose and fix 'size mismatch for position_embeddings' after increasing max_position_embeddings; provide a minimal repro and a migration script": "Medium", + "Implement a regex-constrained Transformers LogitsProcessor that enforces ISO-8601 timestamps; add unit tests and report generation latency overhead on dataset D": "Hard", + "Train language-specific LoRA adapters for {en, es, de} on model M; add an automatic language router and report per-language BLEU and cross-language interference": "Hard", + "Build a speaker diarization + ASR Gradio Space using pyannote and Whisper-large-v3; achieve DER <= 12% and real-time factor <= 0.75 on a T4": "Hard", + "Implement multi-draft speculative decoding with dynamic draft-model selection per prompt; integrate with vLLM and benchmark tokens/sec speedup at batch sizes {1,8,32}": "Very hard", + "Convert a TensorFlow DistilBERT SavedModel to ONNX (opset 17) and validate logits parity (MSE < 1e-4) on 1,000 random inputs; measure CPU inference speedup vs TensorFlow": "Medium", + "Evaluate alignment drift after SFT: compare model M vs base M0 on prompt set P; report win rate, refusal rate, and average output length": "Medium", + "Enable KV cache int4 quantization in vLLM for model M; benchmark tokens/sec and exact match on dataset D vs fp16 cache": "Hard", + "Implement variable-length packing in a HF Datasets + Transformers training loop; ensure epoch-level sample coverage matches baseline and no truncation beyond max_length": "Medium", + "Build a multi-tenant LoRA router over vLLM: on-demand load adapters from the Hub with LRU eviction; sustain 100 tenants and <300 ms adapter swap latency under load": "Very hard", + "Audit generations for PII leakage on prompt set P using detector C; compute precision, recall, and false positive rate; redact before logging and publish a compliance summary": "Medium", + "Merge a stack of PEFT adapters {A_i} into base model M to produce a single FP16 checkpoint; validate perplexity drift <=0.5% on dataset D and export safetensors": "Hard", + "Find three Spaces that demonstrate constrained JSON generation; return Space ids and URLs": "Easy", + "Deploy a cross-lingual vector search service with multilingual-e5-large; shard FAISS across 3 nodes and measure mAP@10 and p95 latency at 500 QPS": "Very hard", + "Quantize attention and MLP projections only with bitsandbytes (selective 8-bit); compare peak VRAM, tokens/sec, and ROUGE-L vs full-model 8-bit on dataset D": "Hard", + "Fix \"Token indices sequence length is longer than the specified maximum\" after tokenizer resize; add truncation with stride and update generation config; verify no validation metric regression": "Medium", + "Identify splits for dataset D and output split names with sample counts": "Easy", + "Find five multilingual sentence-embedding models on the Hub with Apache-2.0 license; return model ids": "Easy", + "Set up CI to run evaluation suite E for model M nightly; fail the job if any metric drops >1% vs 7-day rolling mean": "Medium", + "Add length normalization to beam search for model M; compare vs baseline on dataset D and report ROUGE-L and average output length": "Medium", + "Detect per-sample language for dataset D; add a 'lang' column and recreate train/valid/test splits preserving language proportions": "Medium", + "Benchmark vLLM KV-cache eviction strategies (e.g., LRU vs TTL) for model M at batch sizes {1,8,32}; report tokens/sec and peak VRAM": "Medium", + "Implement a custom DataCollator that packs multiple documents for summarization with separator tokens; add unit tests to prevent cross-sample leakage": "Hard", + "Build a PDF-to-dataset pipeline: OCR pages with model Donut, store word-level bboxes, and publish a HuggingFace Dataset with a viewer Space": "Hard", + "Train a ColBERT reranker on corpus C + pairs dataset D; integrate into a RAG search service and report recall@10 and p95 latency delta": "Hard", + "Deploy vLLM for model M with multi-GPU tensor-parallel inference across 2 nodes using NCCL; demonstrate near-linear throughput scaling and deterministic outputs across 3 seeds": "Very hard", + "List four Hub models tagged 'named-entity-recognition' that declare bitsandbytes 8-bit support in their README; output model ids": "Easy", + "Find three Spaces that provide real-time TTS streaming demos; return Space ids and reported sample rates": "Easy", + "Create a Spaces app that visualizes transformer attention maps for a ViT model using Captum; keep heatmap rendering under 200 ms for 224x224 images": "Medium", + "Set up datasets streaming with resumable downloads and exponential backoff for S3-hosted Parquet shards; verify checksum integrity after killing and resuming the job": "Medium", + "Build a tokenizer migration tool to convert a SentencePiece model to a HuggingFace tokenizers JSON with byte-fallback; assert >=99.95% token-level agreement on 20k sentences and report speed delta": "Medium", + "Implement a custom DataCollator for span masking with variable block sizes for byte-level BPE; add unit tests and demonstrate MLM loss parity over 10k steps on WikiText-103": "Hard", + "Add speculative decoding with a small draft model to a Transformers-based text-generation server; expose a per-request flag and benchmark tokens/sec speedup at batch sizes {1,8,32}": "Hard", + "Train an online knowledge-distillation SFT: teacher M0 -> student M on dataset D; log KL divergence, token agreement, and throughput; cap metric drop at <=2% vs teacher": "Hard", + "Deploy a multi-region vLLM service on Kubernetes with adaptive batching and hot LoRA adapter loading; sustain 200 QPS with p95 latency <300 ms and zero-downtime rollouts": "Very hard", + "Build a sharded cross-encoder reranking service with Ray: distribute ColBERT scoring across nodes, integrate with FAISS retrieval, and maintain recall@10 within 1% of single-node baseline at 500 QPS": "Very hard", + "List four Spaces that perform multilingual OCR with layout extraction; return Space ids and supported languages": "Easy", + "Find five Hub datasets for code generation evaluation with permissive licenses; output dataset ids and license names": "Easy", + "Add gradient accumulation and gradient clipping to a Transformers Trainer finetune of model M; report step time, peak VRAM, and validation metric vs baseline": "Medium", + "Implement document chunking with sliding windows and overlap in a Datasets map pipeline; add doc_id and span indices and verify no segment exceeds max_length": "Medium", + "Export a fine-tuned BERT model to TorchScript and ONNX; verify logits parity (MSE < 1e-4) on 1,000 samples and compare CPU throughput": "Medium", + "Diagnose 'pad_token_id is not set' warnings during generation; add a PAD token, resize embeddings, and write a unit test asserting identical logits pre/post fix on 200 prompts": "Medium", + "Implement diverse beam search (group_beam_search) for model M; evaluate on dataset D and report ROUGE-L, distinct-n, and average output length vs standard beam search": "Hard", + "Build a multi-modal RAG demo that indexes image captions with CLIP and uses LLM M to answer visual questions; report top-1 accuracy and p95 latency": "Hard", + "Profile activation and KV-cache memory during generation for model M; log per-layer footprints and reduce peak usage via attention slicing; show tokens/sec and VRAM deltas": "Hard", + "Construct a 200M-document FAISS hybrid (IVF-PQ + HNSW) index with memory-mapped shards on S3; support live add/delete and benchmark recall@10 and QPS at 300 QPS": "Very hard", + "List five Hub datasets tagged 'topic-modeling' with MIT or Apache-2.0 licenses; output dataset ids": "Easy", + "Find three Spaces that offer real-time grammar correction with streaming tokens; return Space ids and URLs": "Easy", + "Convert a spaCy en_core_web_trf NER model to ONNX and wrap it in a Transformers TokenClassification pipeline; verify entity text/label/span parity on 1,000 sentences": "Medium", + "Set up a GitHub Actions workflow that snapshots tokenizer T weekly and fails if vocab or special token ids drift vs the last snapshot; upload a diff artifact": "Medium", + "Profile a Datasets map pipeline on corpus C; refactor to use batched=True, num_proc>1, and caching; achieve >=2\u00d7 speedup while preserving deterministic ordering across runs": "Medium", + "Implement a custom Transformers StoppingCriteria that halts when JSON braces are balanced or max nesting depth is reached; add unit tests and benchmark latency overhead on dataset D": "Hard", + "Build a visual-and-tabular RAG pipeline: index images with CLIP and CSV tables with TAPAS; answer mixed queries using LLM M; report EM@1 and p95 latency at 50 QPS": "Hard", + "Enable KV-cache int4 quantization during generation in Transformers for model M; compare tokens/sec and exact match vs fp16 cache on dataset D; keep metric drop <=1%": "Hard", + "Implement a hot-reloadable sharded FAISS IVF-PQ index for multilingual-e5-base with live add/delete and background re-training; sustain 200 QPS with p95 latency <400 ms across 3 nodes": "Very hard", + "Deploy a geo-distributed vLLM + LoRA adapter gateway across two regions with consistent hashing and zero-downtime adapter updates; ensure identical outputs across 3 seeds and report cross-region p95 latency": "Very hard", + "List five Hub LLM repos that disclose training token counts in their model cards; output model ids and token totals": "Easy", + "Find two ready-to-use Spaces for speaker diarization compatible with Whisper; return Space ids and URLs": "Easy", + "Create a hashing-based dataset splitter using column 'doc_id' to produce reproducible train/valid/test; verify identical splits across two machines and Python versions": "Medium", + "Resolve HTTP 403 when creating an organization dataset via the Hub API; diagnose token scopes and org permissions; provide a minimal repro script and the fix": "Medium", + "Export a PEFT LoRA adapter from a fine-tuned Llama checkpoint as standalone safetensors with a correct adapter_config.json; push to the Hub and verify PEFT.from_pretrained loads it": "Medium", + "Enable multi-query attention in model M within Transformers; benchmark tokens/sec and peak VRAM vs multi-head attention and verify perplexity parity over 2,000 steps": "Hard", + "Audit code dataset D for contamination against {HumanEval, MBPP} using exact substring and 3-gram Jaccard >= 0.9; publish per-source contamination rates and a cleaned dataset": "Hard", + "Implement contrastive search decoding for model M with tunable alpha; compare ROUGE-L, distinct-n, and latency vs nucleus sampling on dataset D": "Hard", + "Implement pipeline parallelism for model M across 4 GPUs with Accelerate; achieve near-linear scaling (<=15% gap), support checkpoint save/restore, and ensure deterministic outputs across 3 seeds": "Very hard", + "Deploy a Spaces app that serves two ASR models with automatic language ID routing; maintain real-time factor <= 0.6 on a single T4 and log per-language latency": "Hard", + "Benchmark JSON-constrained decoding across models {M_i}; report JSON validity rate, exact match on dataset D, and p95 latency under streaming": "Hard", + "Filter a multilingual dataset D to non-English using fastText language ID; recreate stratified splits and report per-language retention and drop rates": "Medium", + "Enable paged attention in a custom Transformers generation loop for model M; verify token-level parity on 500 prompts and measure peak VRAM change": "Hard", + "Shard a 1B-token text corpus into deterministic HF Datasets processing across 16 workers; validate byte-for-byte identical outputs across two runs": "Very hard", + "Compare LoRA vs QLoRA fine-tunes of Mistral-7B on GSM8K; track loss, exact match, and throughput; select the lowest-VRAM config within 2% EM of best": "Hard", + "Deploy a quantized T5 encoder-decoder on Triton Inference Server via a Python backend; add token streaming and achieve >=1.5x throughput vs PyTorch baseline": "Hard", + "Find three Spaces that perform audio source separation (vocals/music); return Space ids and reported sample rates": "Easy", + "Merge a PEFT IA3 adapter stack into Llama-3-8B base weights; verify perplexity drift <=0.3% on WikiText-103 and export safetensors": "Hard", + "Resolve DeepSpeed ZeRO-3 stalls during S3 checkpointing; implement async multipart uploads and show stable 5-minute checkpoint cadence over 2 hours": "Very hard", + "Set up CI to run contamination checks on dataset R against {TruthfulQA, SQuAD} using 4-gram overlap; fail if rate >0.5% and attach offending ids as artifacts": "Medium", + "List four Hub datasets for sarcasm detection in English; return dataset ids and license tags": "Easy", + "Identify whether tokenizer T enables byte_fallback in tokenizer.json; output true/false and the file path": "Easy", + "Find three Spaces that showcase streaming chat with token-by-token updates; return Space ids and whether they use SSE or websockets": "Easy", + "Create a Datasets loader that parses Praat TextGrid files into word-level timestamps aligned with audio; publish a dataset with an 'audio' column and validate 100 sample alignments": "Medium", + "Set up a GitHub Actions workflow that lints model cards for repos {R_i} to require intended use, training data, and limitations; fail PRs and post a summary comment on violations": "Medium", + "Containerize a Gradio Space with optional FlashAttention build: detect GPU capability at startup, compile kernels if supported, and fall back gracefully on unsupported GPUs; test on T4 and A100": "Medium", + "Evaluate long-context retrieval via needle-in-a-haystack for models {M_i} at context lengths {8k, 32k, 64k}; report retrieval accuracy, tokens/sec, and the max stable context length": "Hard", + "Implement a curriculum sampler as a HuggingFace Trainer callback that schedules sample difficulty over epochs; compare convergence and final eval metrics vs random sampling": "Hard", + "Add on-the-fly near-duplicate filtering during training using SimHash over token ids; log per-epoch removal rates and verify no convergence regressions vs a deduplicated baseline": "Hard", + "Deploy a dual-backend inference router using vLLM and TensorRT-LLM that selects backend per prompt length to minimize latency; maintain deterministic outputs across 3 seeds and sustain 300 QPS with p95 latency SLOs": "Very hard", + "Identify max_position_embeddings and whether rope_scaling is enabled for model M from its config; output both values.": "Easy", + "List five Vision Transformer models on the Hub that provide safetensors and have a default image size >= 384; output model ids.": "Easy", + "Find three Spaces that stream machine-translation outputs token-by-token; return Space ids and whether they use SSE or websockets.": "Easy", + "Diagnose bursts of [UNK] after adding special tokens to tokenizer T; enable byte_fallback, retrain embeddings for 2k steps, and show unknown-token rate <= baseline+0.1% on corpus C.": "Medium", + "Create a dataset viewer Space for a dataset with a nested JSON column; convert to Arrow struct arrays, implement server-side filtering on nested keys, and verify row counts match the source.": "Medium", + "Set up a GitHub Action that hits /health and a no-op inference on Space S after each deploy; fail if cold-start median latency >10s and attach server logs as an artifact.": "Medium", + "Implement a SQL grammar-constrained Transformers LogitsProcessor using an LL(1) parser; evaluate on Spider dev and report exact match and p95 latency overhead vs nucleus sampling.": "Hard", + "Add CPU-tier KV-cache offloading with pinned memory for model M in a custom generation loop; compare tokens/sec and peak VRAM vs baseline at context lengths {4k, 16k, 32k}.": "Hard", + "Deploy a batched cross-encoder reranker microservice using bge-reranker-base; keep recall@10 within 1% of single-request baseline and achieve >=2\u00d7 QPS at 100 concurrent users.": "Hard", + "Build a heterogeneous inference gateway that routes requests to vLLM or llama.cpp based on prompt length and GPU load; ensure identical normalized outputs across 3 seeds and sustain 200 QPS with p95 latency <300 ms.": "Very hard", + "Determine whether tokenizer T strips accents (strip_accents); output true/false and the file path where the setting is defined.": "Easy", + "List four Hub datasets for hate-speech detection in English; return dataset ids and license tags.": "Easy", + "Write a Datasets loader for a paginated OAuth2 REST API; cache pages, support streaming, and provide deterministic sharding across 8 workers; verify identical row counts across two runs.": "Medium", + "Add request-level caching (ETag/If-None-Match) to a Gradio summarization Space; achieve >=1.8\u00d7 QPS at 50 concurrent users and report cache hit ratio and p95 latency.": "Medium", + "Enable HuggingFace tokenizers parallelism and batched encoding for corpus C; benchmark throughput and memory on 10M lines and ensure deterministic outputs across 3 runs.": "Medium", + "Set up CI to lint dataset cards in repos {R_i} for required fields {license, citation, dataset_summary}; fail PRs and post a summary comment with missing keys.": "Medium", + "Run a parameter-efficient finetuning sweep comparing LoRA, IA3, and prefix-tuning on RoBERTa-base for MNLI; report accuracy, training time, and peak VRAM; select a Pareto-optimal config.": "Hard", + "Implement a Transformers LogitsProcessor that enforces balanced parentheses and proper quoted-string escaping; add unit tests and benchmark latency overhead on dataset D.": "Hard", + "Export Whisper-medium to ONNX with dynamic axes and int8 weights; verify word-timestamp parity on 500 clips and measure CPU real-time factor improvement >=1.3\u00d7 vs PyTorch.": "Hard", + "Deploy a geo-replicated RAG service: shard FAISS HNSW across three regions with conflict-free index metadata sync; sustain 300 QPS with p95 latency <450 ms and recall@10 within 1% of single-region baseline.": "Very hard", + "Compare cased vs uncased tokenization for BERT on CoNLL-2003 NER; train both, and report F1, average tokens per sentence, and training time.": "Medium", + "Create a HuggingFace Datasets loader for EPUB files: extract chapter text and embedded images into Arrow columns, support streaming and deterministic sharding across 8 workers; verify identical row counts across two runs.": "Medium", + "Configure a Hub webhook to trigger CI when a model card (README.md) changes; fail the job if sections {intended use, limitations} are missing and post a checklist comment on the PR.": "Medium", + "Add a reranking cache to a RAG service keyed by (query, candidate_ids); achieve >=50% cache hit at 100 QPS and keep recall@10 within 0.5% of baseline.": "Hard", + "Fix torch.compile graph breaks in a Transformers training loop; patch non-compilable ops, re-enable compilation, and demonstrate >=1.4\u00d7 step-time speedup with matching loss over 2,000 steps.": "Hard", + "Compute 95% bootstrap confidence intervals for ROUGE-L on dataset D over 3 random seeds; flag regressions when the new CI lies entirely below last week's baseline CI.": "Medium", + "Build a batch image-captioning Space with ViT-GPT2: accept ZIP uploads, use queue-based batching, and keep p95 latency <2s for 32 images.": "Medium", + "Implement hybrid parallelism (tensor + pipeline) for a 13B encoder-decoder using Accelerate; scale across 8 GPUs with <=15% gap from linear, support elastic resize (8->6 GPUs) without losing determinism, and verify checkpoint save/restore.": "Very hard", + "Find five Spaces that stream live vision-language captioning (e.g., LLaVA or BLIP); return Space ids and reported FPS.": "Easy", + "Identify whether tokenizer T applies Unicode normalization (NFKC/NFC/NFD/NFKD) and where it is configured; output the mode and file path.": "Easy", + "Identify whether model repo M stores weights exclusively as safetensors; output true/false and list the .safetensors file paths.": "Easy", + "List three multilingual sentence-embedding models on the Hub that provide ONNX exports; return model ids.": "Easy", + "Determine if tokenizer T lowercases text (do_lower_case or lowercase flag); output true/false and the file path or JSON key where it is set.": "Easy", + "Set up a GitHub Action to run a smoke-test text generation for model M on each push; fail if median time to first token >2s and attach container logs as an artifact.": "Medium", + "Create a Datasets preprocessing pipeline that tokenizes to max_length=512 with stride=64 and retains an 'orig_text' column; verify row counts match input and no NaNs after caching.": "Medium", + "Resolve 'git-lfs: command not found' when pushing model repo R to the Hub; install and configure Git LFS, set an appropriate large file threshold, and provide a minimal repro plus the verified fix.": "Medium", + "Enable KV-cache CPU offloading in a custom Transformers generation loop for model M; benchmark tokens/sec and peak VRAM vs baseline at context lengths {4k, 8k}.": "Hard", + "Implement LoRA rank warmup (r: 4\u219232 over the first 1,000 steps) in a custom Trainer; fine-tune model M on dataset D and report validation perplexity and peak VRAM vs fixed r=32.": "Hard", + "Export Whisper-small to TensorRT via ONNX (opset 18) with dynamic axes; verify word-timestamp parity (median diff \u22640.05s) on 300 clips and measure \u22651.3\u00d7 GPU speedup vs PyTorch.": "Hard", + "Deploy a multi-tenant RAG service that hot-loads per-tenant FAISS indices from S3, shares a reranker, and sustains 200 QPS with p95 latency <350 ms across 1,000 tenants; maintain recall@10 within 1% of a single-tenant baseline.": "Very hard" +} \ No newline at end of file diff --git a/eval/hf_agent_connector.py b/eval/hf_agent_connector.py new file mode 100644 index 0000000000000000000000000000000000000000..3740d32abe22b93911df8104650dded656d89d7a --- /dev/null +++ b/eval/hf_agent_connector.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +import asyncio +import sys +from pathlib import Path +from typing import Any + +from lmnr import observe + +from agent.config import Config, load_config +from agent.core.agent_loop import Handlers +from agent.core.session import Session +from agent.core.tools import ToolRouter + +PROJECT_ROOT = Path(__file__).resolve().parents[1] +if str(PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(PROJECT_ROOT)) + + +def _resolve_project_path(path: str | Path) -> Path: + candidate = Path(path) + if candidate.is_absolute(): + return candidate + return (PROJECT_ROOT / candidate).resolve() + + +class AgentResponseGenerator: + """ + Thin async wrapper that executes the existing agent loop once and + returns the assistant's final message. + """ + + def __init__(self, config_path: str | Path, max_iterations: int = 10) -> None: + self.config_path = _resolve_project_path(config_path) + self.config: Config = load_config(str(self.config_path)) + self.max_iterations = max_iterations + + @property + def model_name(self) -> str: + """Expose the agent model name for downstream logging.""" + return self.config.model_name + + @observe(name="eval_run") + async def run(self, prompt: str) -> str: + """ + Execute the agent loop for a single prompt and return the assistant reply. + """ + tool_router = ToolRouter(self.config.mcpServers) + + async with tool_router: + session = Session(asyncio.Queue(), config=self.config) + session.tool_router = tool_router + await Handlers.run_agent( + session, + prompt, + max_iterations=self.max_iterations, + ) + return self._latest_assistant_response(session) + + def _latest_assistant_response(self, session: Session) -> str: + """ + Extract the final assistant response from the session history. + """ + for message in reversed(session.context_manager.items): + if getattr(message, "role", None) == "assistant": + return _content_to_text(getattr(message, "content", "")) + + raise RuntimeError("Agent did not produce an assistant message.") + + +def _content_to_text(content: Any) -> str: + """ + Convert LiteLLM content payloads (str or list[dict]) into plain text. + """ + if isinstance(content, str): + return content + + if isinstance(content, list): + parts: list[str] = [] + for block in content: + if isinstance(block, dict): + text = block.get("text") + if text: + parts.append(str(text)) + else: + text = getattr(block, "text", None) + if text: + parts.append(str(text)) + return "\n".join(parts) + + return str(content) diff --git a/eval/hf_io.py b/eval/hf_io.py new file mode 100644 index 0000000000000000000000000000000000000000..0f26899ce70ed5490757fca69a12b802bfa76b35 --- /dev/null +++ b/eval/hf_io.py @@ -0,0 +1,215 @@ +""" +HuggingFace Dataset I/O Utilities + +Reusable functions for uploading and downloading JSONL data to/from HuggingFace Hub. +Supports the dataset_name@config_name notation for managing multiple configurations. +""" + +from typing import List, Optional + +import pandas as pd +from datasets import Dataset, load_dataset + + +def list_dataset_configs(dataset_name: str) -> Optional[List[str]]: + """ + List all available configs for a dataset on HuggingFace Hub. + + Args: + dataset_name: Name of the dataset (e.g., "username/my-dataset") + + Returns: + List of config names, or None if unable to retrieve + + Example: + >>> configs = list_dataset_configs("username/hf-agent-benchmark") + >>> print(configs) + ['default', 'rubrics', 'evaluations'] + """ + try: + from datasets import get_dataset_config_names + + configs = get_dataset_config_names(dataset_name) + return configs + except Exception as e: + print(f"βœ— Failed to list configs: {type(e).__name__}: {str(e)}") + return None + + +def df_to_hub( + df: pd.DataFrame, + dataset_spec: str, + split: str = "train", + private: bool = False, +) -> bool: + """ + Upload a pandas DataFrame directly to HuggingFace Hub as a dataset. + + This function converts a pandas DataFrame to a HuggingFace Dataset and uploads + it to the Hub. This is useful for uploading data directly without creating an + intermediate JSONL file. + + Args: + df: pandas DataFrame to upload. All column types should be serializable. + Example DataFrame: + ``` + | question | solution | rubric | + |----------|----------|--------| + | "How..." | "You..." | {...} | + ``` + + dataset_spec: Dataset specification in the format "dataset_name" or + "dataset_name@config_name". Examples: + - "username/my-dataset" (uses "default" config) + - "username/my-dataset@rubrics" (uses "rubrics" config) + - "username/my-dataset@evaluations" (uses "evaluations" config) + + split: The dataset split name. Defaults to "train". Common values: + - "train": Training or main data + - "validation": Validation data + - "test": Test data + + private: Whether to create a private dataset. Defaults to False (public). + + Returns: + bool: True if upload succeeded, False otherwise + + Raises: + ValueError: If DataFrame is empty + Exception: For HuggingFace Hub upload errors + + Example: + >>> import pandas as pd + >>> df = pd.DataFrame({ + ... "question": ["How to train?", "What is fine-tuning?"], + ... "solution": ["Use trainer...", "Fine-tuning is..."], + ... "rubric": ['[{"title": "...", ...}]', '[{"title": "...", ...}]'] + ... }) + >>> upload_dataframe_to_hf(df, "username/dataset@rubrics") + + Notes: + - Requires authentication via `huggingface-cli login` or HF_TOKEN env var + - DataFrame columns with complex objects should be serialized first (e.g., to JSON strings) + - If the dataset doesn't exist, it will be created automatically + - Empty DataFrames will raise ValueError to prevent uploading invalid data + """ + # Validate DataFrame + if df.empty: + raise ValueError("DataFrame is empty") + + # Parse dataset specification + if "@" in dataset_spec: + dataset_name, config_name = dataset_spec.split("@", 1) + else: + dataset_name = dataset_spec + config_name = "default" + + try: + print("\nUploading DataFrame to HuggingFace Hub...") + print(f" Dataset: {dataset_name}") + print(f" Config: {config_name}") + print(f" Split: {split}") + print(f" Rows: {len(df)}") + print(f" Columns: {list(df.columns)}") + + # Convert DataFrame to HuggingFace Dataset + dataset = Dataset.from_pandas(df) + + # Upload to HuggingFace Hub + dataset.push_to_hub( + dataset_name, + config_name=config_name, + split=split, + private=private, + ) + + print( + f"βœ“ Successfully uploaded to {dataset_name}@{config_name} (split: {split})" + ) + return True + + except Exception as e: + print(f"βœ— Failed to upload to HuggingFace: {type(e).__name__}: {str(e)}") + return False + + +def hub_to_df( + dataset_spec: str, + split: str = "train", +) -> Optional[pd.DataFrame]: + """ + Download a dataset from HuggingFace Hub as a pandas DataFrame. + + This function downloads a dataset from the HuggingFace Hub and returns it as a + pandas DataFrame for immediate use in Python. + + Args: + dataset_spec: Dataset specification in the format "dataset_name" or + "dataset_name@config_name". Examples: + - "username/my-dataset" (uses "default" config) + - "username/my-dataset@rubrics" (uses "rubrics" config) + - "username/my-dataset@evaluations" (uses "evaluations" config) + + split: The dataset split to download. Defaults to "train". Common values: + - "train": Training or main data + - "validation": Validation data + - "test": Test data + + Returns: + pd.DataFrame: Downloaded data as pandas DataFrame, or None if failed + + Raises: + ValueError: If the dataset/config/split doesn't exist + Exception: For HuggingFace Hub download errors + + Example: + >>> # Download rubrics from specific config + >>> df = hub_to_df("username/hf-agent-benchmark@rubrics") + >>> print(df.head()) + >>> print(f"Shape: {df.shape}") + + >>> # Download evaluation results + >>> results_df = download_hf_to_dataframe( + ... "username/hf-agent-benchmark@evaluations", + ... split="test" + ... ) + + Notes: + - Requires authentication for private datasets via `huggingface-cli login` + - Downloaded data will be in the same format as uploaded (preserves structure) + - Large datasets may take time to download and consume significant memory + - For very large datasets, consider using streaming or download_hf_to_jsonl + """ + # Parse dataset specification + if "@" in dataset_spec: + dataset_name, config_name = dataset_spec.split("@", 1) + else: + dataset_name = dataset_spec + config_name = "default" + + try: + print("\nDownloading from HuggingFace Hub...") + print(f" Dataset: {dataset_name}") + print(f" Config: {config_name}") + print(f" Split: {split}") + + # Download dataset from HuggingFace Hub + dataset = load_dataset( + dataset_name, + name=config_name, + split=split, + ) + + print(f" Downloaded {len(dataset)} records") + + # Convert to pandas DataFrame + df = dataset.to_pandas() + + print("βœ“ Successfully loaded as DataFrame") + print(f" Shape: {df.shape}") + print(f" Columns: {list(df.columns)}") + return df + + except Exception as e: + print(f"βœ— Failed to download from HuggingFace: {type(e).__name__}: {str(e)}") + return None diff --git a/eval/leaderboard.py b/eval/leaderboard.py new file mode 100644 index 0000000000000000000000000000000000000000..00444bc342df6b398365270961d9987e2aad981e --- /dev/null +++ b/eval/leaderboard.py @@ -0,0 +1,172 @@ +""" +Utilities for logging solver scores to a Hugging Face dataset. +""" + +from __future__ import annotations + +import json +import re +import shutil +import subprocess +import tempfile +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +from huggingface_hub import HfApi, hf_hub_download + +AVERAGE_RE = re.compile(r"Average normalized score:\s*([0-9.]+)") +DEFAULT_FILENAME = "records.jsonl" + + +def _hydra_join(*parts: str | None) -> str: + tokens = [str(part).strip().replace(" ", "_") for part in parts if part] + return "/".join(tokens) if tokens else "default" + + +def detect_agent_version(config_path: str = "agent/config_mcp_example.json") -> str: + """ + Returns a short string identifying the current agent version: + -. + """ + + try: + commit = ( + subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]) + .decode() + .strip() + ) + except Exception: + commit = "unknown" + + config_file = Path(config_path) + config_stem = config_file.stem or "config" + parent_name = config_file.parent.name if config_file.parent.name else None + return _hydra_join(parent_name, config_stem, commit) + + +def parse_average_score(text: str) -> float | None: + """Extracts the 'Average normalized score' value from Inspect logs.""" + + match = AVERAGE_RE.search(text) + if match: + try: + return float(match.group(1)) + except ValueError: + return None + return None + + +def latest_log_file( + log_dir: Path, extensions: tuple[str, ...] = (".eval", ".json") +) -> Path | None: + """Returns the most recent log file in log_dir matching the provided extensions.""" + + if not log_dir.exists(): + return None + + files: list[Path] = [] + for ext in extensions: + files.extend(log_dir.glob(f"*{ext}")) + + if not files: + return None + + files.sort(key=lambda path: path.stat().st_mtime) + return files[-1] + + +@dataclass +class LeaderboardClient: + """Simple helper to append JSONL rows to a HF dataset.""" + + repo_id: str + token: str + filename: str = DEFAULT_FILENAME + + def append_record(self, record: dict[str, Any]) -> None: + tmp_dir = Path(tempfile.mkdtemp(prefix="leaderboard_")) + local_file = tmp_dir / self.filename + + self._download_existing(local_file) + if not local_file.exists(): + local_file.write_text("", encoding="utf-8") + + with local_file.open("a", encoding="utf-8") as fh: + fh.write(json.dumps(record) + "\n") + + HfApi(token=self.token).upload_file( + path_or_fileobj=str(local_file), + path_in_repo=self.filename, + repo_id=self.repo_id, + repo_type="dataset", + ) + + try: + local_file.unlink() + tmp_dir.rmdir() + except OSError: + pass + + def _download_existing(self, destination: Path) -> None: + destination.parent.mkdir(parents=True, exist_ok=True) + + try: + downloaded = hf_hub_download( + repo_id=self.repo_id, + filename=self.filename, + repo_type="dataset", + token=self.token, + ) + shutil.copy(Path(downloaded), destination) + except Exception: + destination.write_text("", encoding="utf-8") + + +def build_record( + solver_name: str, + solver_kwargs: dict[str, Any], + dataset_name: str, + dataset_split: str, + limit: int | None, + score: float, + command: list[str], + log_path: Path | None, + criterion_checks: list[dict[str, Any]] | None = None, +) -> dict[str, Any]: + """Assembles a JSON-serialisable record for the leaderboard dataset.""" + + record = { + "timestamp": datetime.now(timezone.utc).isoformat(), + "solver": solver_name, + "solver_kwargs": solver_kwargs, + "dataset_name": dataset_name, + "dataset_split": dataset_split, + "limit": limit, + "score": score, + "command": command, + } + + if solver_name == "hf_agent": + record["solver_version"] = detect_agent_version( + solver_kwargs.get("config_path", "agent/config_mcp_example.json") + ) + else: + version_spec = solver_kwargs.get("version") + if isinstance(version_spec, (list, tuple)): + record["solver_version"] = _hydra_join(*version_spec) + elif isinstance(version_spec, dict): + record["solver_version"] = _hydra_join( + *[f"{k}={v}" for k, v in version_spec.items()] + ) + elif isinstance(version_spec, str): + record["solver_version"] = version_spec + else: + record["solver_version"] = _hydra_join(solver_name, "default") + + if log_path: + record["log_artifact"] = str(log_path) + record["criterion_checks"] = criterion_checks or [] + + return record diff --git a/eval/models.py b/eval/models.py new file mode 100644 index 0000000000000000000000000000000000000000..d58f7f2478ebba8cb6cdbfb48cd2700f2322f77e --- /dev/null +++ b/eval/models.py @@ -0,0 +1,63 @@ +"""Shared data models for the HF agent project""" + +from datetime import datetime +from enum import Enum + +from pydantic import BaseModel, Field + + +class Discussion(BaseModel): + """Model for a discussion thread""" + + title: str + url: str + topic_id: int + category: int + created_at: datetime + + +class QuestionAndSolution(BaseModel): + """Model for a QA pair from a discussion""" + + discussion_title: str + discussion_url: str + discussion_topic_id: int + discussion_category: int + discussion_created_at: datetime + thread: list[dict] + question: str + solution: str + + +class Correctness(str, Enum): + yes = "yes" + no = "no" + + +class JudgementResult(BaseModel): + """Structured output for LLM judge evaluation""" + + extracted_final_answer: str = Field( + description="The final exact/snippet answer extracted from the response" + ) + reasoning: str = Field( + description="Explanation of why the answer is correct or incorrect" + ) + correct: Correctness = Field(description="'yes' if correct, 'no' if incorrect") + confidence: int = Field( + description="Confidence score between 0 and 100", ge=0, le=100 + ) + + +class EvaluationResult(BaseModel): + """Model for evaluation results including metadata""" + + success: bool + judgement: JudgementResult | None = None + error: str | None = None + + +class EvaluatedQuestionAndSolution(QuestionAndSolution): + """Model for a QA pair with its evaluation result""" + + evaluation: JudgementResult diff --git a/eval/rubric_eval.py b/eval/rubric_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..ce706de91c531f862d50af8f39cfd9a4f32e01da --- /dev/null +++ b/eval/rubric_eval.py @@ -0,0 +1,142 @@ +""" +Rubric-based evaluation following the "Rubrics as Rewards" paper. + +Implements RaR-Explicit: Weighted sum of individual criterion scores (Equation 1) +""" + +from typing import List, Optional + +import litellm +from pydantic import BaseModel + + +class CriterionCheck(BaseModel): + """Result of checking a single rubric criterion.""" + + title: str + description: str + weight: int + satisfied: bool + reasoning: Optional[str] = None + + +class RubricEvaluation(BaseModel): + """Complete rubric-based evaluation result.""" + + criterion_checks: List[CriterionCheck] + raw_score: float # Unnormalized score + normalized_score: float # Score normalized to [0, 1] + + +CRITERION_PROMPT = """You are evaluating whether a response satisfies a specific evaluation criterion. + +Question: {question} + +Response to evaluate: {response} + +Evaluation Criterion: +{criterion_description} + +Your task: Determine if the response satisfies this criterion. + +Output a JSON object with: +- "satisfied": true or false +- "reasoning": Brief explanation (1-2 sentences) of why it does or doesn't satisfy the criterion + +Be strict but fair. The criterion must be clearly satisfied for you to answer true.""" + + +class RubricData(BaseModel): + """Rubric data loaded from file.""" + + title: str + description: str + weight: int + + +def check_criterion( + question: str, response: str, criterion: RubricData, model: str = "gpt-4o-mini" +) -> CriterionCheck: + """ + Check if response satisfies a single criterion. + + Args: + question: The question being answered + response: The response to evaluate + criterion: The rubric criterion to check + model: LLM model for judging + + Returns: + CriterionCheck with satisfaction result + """ + prompt = CRITERION_PROMPT.format( + question=question, + response=response, + criterion_description=criterion.description, + ) + + llm_response = litellm.completion( + model=model, + messages=[ + { + "role": "system", + "content": "You are an expert evaluator for rubric-based assessment.", + }, + {"role": "user", "content": prompt}, + ], + temperature=0.0, + response_format=CriterionCheck, + ) + + result = CriterionCheck.model_validate_json(llm_response.choices[0].message.content) + + return result + + +def evaluate_with_rubrics( + question: str, + response: str, + rubrics: List[RubricData], + model: str = "gpt-5-nano", +) -> RubricEvaluation: + """ + Evaluate response using RaR-Explicit method (weighted sum). + + Implements Equation 1 from paper: + r(x, Ε·) = Ξ£(w_j * c_j(x, Ε·)) / Ξ£(w_j) + + Args: + question: The question + response: Response to evaluate + reference_answer: Reference answer (not directly used, but available) + rubrics: List of rubric criteria + model: LLM model for judging + + Returns: + RubricEvaluation with normalized score + """ + # Check each criterion independently + checks = [] + for rubric in rubrics: + check = check_criterion(question, response, rubric, model) + checks.append(check) + + # Calculate weighted score (Equation 1) + # Only positive weights contribute to denominator + positive_weights = sum(abs(r.weight) for r in rubrics if r.weight > 0) + + raw_score = 0.0 + for check in checks: + if check.satisfied: + raw_score += check.weight + + # Normalize to [0, 1] + normalized_score = raw_score / positive_weights if positive_weights > 0 else 0.0 + # Clip to [0, 1] in case pitfalls make it negative + normalized_score = max(0.0, min(1.0, normalized_score)) + + return RubricEvaluation( + raw_score=raw_score, + normalized_score=normalized_score, + criterion_checks=checks, + ) diff --git a/eval/run_eval_with_leaderboard.py b/eval/run_eval_with_leaderboard.py new file mode 100644 index 0000000000000000000000000000000000000000..031246179bf2f2eb8ae1b84b64985e82fd01bebe --- /dev/null +++ b/eval/run_eval_with_leaderboard.py @@ -0,0 +1,215 @@ +from __future__ import annotations + +import argparse +import json +import os +import re +import subprocess +import sys +from pathlib import Path +from typing import Any + +from dotenv import load_dotenv +from leaderboard import LeaderboardClient, build_record, latest_log_file + +load_dotenv() + + +def run_command(cmd: list[str]) -> subprocess.CompletedProcess[str]: + print(f"[leaderboard] running: {' '.join(cmd)}") + return subprocess.run(cmd, capture_output=True, text=True) + + +def build_inspect_command(args: argparse.Namespace) -> list[str]: + cmd = [] + cmd.extend(args.inspect_launch) + cmd.append(args.inspect_task) + + def add_task_arg(key: str, value: Any) -> None: + if value is None: + return + cmd.extend(["-T", f"{key}={value}"]) + + add_task_arg("solver_name", args.solver_name) + add_task_arg("solver_kwargs", json.dumps(args.solver_kwargs)) + add_task_arg("dataset_name", args.dataset) + if args.limit is not None: + add_task_arg("limit", args.limit) + + cmd.extend(["--log-dir", args.log_dir]) + if args.log_format: + cmd.extend(["--log-format", args.log_format]) + + if args.extra_inspect_args: + cmd.extend(args.extra_inspect_args) + + return cmd + + +def parse_score_from_outputs(log_dir: Path) -> tuple[float, Path, list[dict[str, Any]]]: + log_path = latest_log_file(log_dir) + if not log_path: + raise RuntimeError("Inspect log file not found.") + + # Sanitization + content = log_path.read_text(encoding="utf-8") + # Regex to match hf_ followed by 34 alphanumeric chars + sanitized_content = re.sub(r"hf_[a-zA-Z0-9]{34}", "", content) + + if content != sanitized_content: + log_path.write_text(sanitized_content, encoding="utf-8") + print(f"[leaderboard] Redacted HF tokens in {log_path}") + content = sanitized_content + + data = json.loads(content) + results = data.get("results", {}) + scores = results.get("scores", []) + score_value = None + criterion_checks: list[dict[str, Any]] = [] + + for score_entry in scores: + metrics = score_entry.get("metrics", {}) + for metric in metrics.values(): + value = metric.get("value") + if isinstance(value, (int, float)): + score_value = float(value) + break + if score_value is not None: + break + + if score_value is None: + raise RuntimeError("Could not find a numeric metric value in the Inspect log.") + + for sample in data.get("samples", []): + # Grab the question from metadata (fallback to input) + question = "Unknown Question" + if "metadata" in sample and "question" in sample["metadata"]: + question = sample["metadata"]["question"] + elif "input" in sample: + question = sample["input"] + + # Check if any scorer produced criterion_checks + for scorer in sample.get("scores", {}).values(): + metadata = scorer.get("metadata") or {} + checks = metadata.get("criterion_checks") + + if isinstance(checks, list) and checks: + # Create a grouped entry for this question/sample + grouped_entry = {"question": question, "checks": []} + for check in checks: + if isinstance(check, dict): + grouped_entry["checks"].append(check) + + if grouped_entry["checks"]: + criterion_checks.append(grouped_entry) + + return score_value, log_path, criterion_checks + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Run Inspect eval and append the resulting score to a HF dataset." + ) + parser.add_argument( + "--hf-dataset", + default="akseljoonas/hf-agent-leaderboard", + help="HF dataset repo id for the leaderboard (e.g. user/leaderboard).", + ) + + parser.add_argument( + "--solver-name", + required=True, + help="Solver name used in the Inspect task (e.g. hf_agent).", + ) + parser.add_argument( + "--solver-kwargs", + type=json.loads, + default="{}", + help="JSON string with solver kwargs passed to the Inspect task.", + ) + parser.add_argument( + "--dataset", + default="akseljoonas/hf-agent-rubrics@train", + help="Dataset spec in the form author/dataset@split.", + ) + parser.add_argument( + "--limit", + type=int, + default=None, + help="Optional sample limit passed to Inspect.", + ) + parser.add_argument( + "--inspect-task", + default="eval/task.py@hf-benchmark-with-rubrics", + help="Inspect task reference.", + ) + parser.add_argument( + "--inspect-launch", + nargs="+", + default=["uv", "run", "inspect", "eval"], + help="Command used to invoke Inspect (default: uv run inspect eval).", + ) + parser.add_argument( + "--log-dir", + default="logs/leaderboard", + help="Directory where Inspect outputs .eval logs.", + ) + parser.add_argument( + "--extra-inspect-args", + nargs="*", + help="Additional args forwarded to Inspect after the standard task arguments.", + ) + parser.add_argument( + "--log-format", + default="json", + help="Log format passed to Inspect (default: json).", + ) + + args = parser.parse_args() + + if isinstance(args.solver_kwargs, str): + args.solver_kwargs = json.loads(args.solver_kwargs or "{}") + + hf_token = os.getenv("HF_TOKEN") + if not hf_token: + print("ERROR: set HF_TOKEN in your environment.", file=sys.stderr) + sys.exit(1) + + if "@" not in args.dataset: + raise ValueError("Dataset must be in the format 'author/dataset@split'.") + dataset_name, dataset_split = args.dataset.split("@", 1) + + log_dir = Path(args.log_dir) + log_dir.mkdir(parents=True, exist_ok=True) + + inspect_cmd = build_inspect_command(args) + result = run_command(inspect_cmd) + + if result.returncode != 0: + print(result.stdout) + print(result.stderr, file=sys.stderr) + raise SystemExit(result.returncode) + + score, log_path, criterion_checks = parse_score_from_outputs(log_dir) + + client = LeaderboardClient(repo_id=args.hf_dataset, token=hf_token) + record = build_record( + solver_name=args.solver_name, + solver_kwargs=args.solver_kwargs, + dataset_name=dataset_name, + dataset_split=dataset_split, + limit=args.limit, + score=score, + command=inspect_cmd, + log_path=log_path, + criterion_checks=criterion_checks, + ) + client.append_record(record) + + print( + f"[leaderboard] recorded score {score:.3f} for solver '{args.solver_name}' to {args.hf_dataset}" + ) + + +if __name__ == "__main__": + main() diff --git a/eval/scrape_discussions/discussions_scraper.py b/eval/scrape_discussions/discussions_scraper.py new file mode 100644 index 0000000000000000000000000000000000000000..506cf97302e033e0e63fa3a125c0d6ba0b438f6b --- /dev/null +++ b/eval/scrape_discussions/discussions_scraper.py @@ -0,0 +1,98 @@ +import sys +import time +from pathlib import Path + +import requests +from tenacity import ( + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +# Add parent directory to path to import models +sys.path.insert(0, str(Path(__file__).parent.parent)) +from models import Discussion, QuestionAndSolution + +BASE_URL = "https://discuss.huggingface.co" + + +# configure retry decorator for your requests +@retry( + stop=stop_after_attempt(5), + wait=wait_exponential(multiplier=1, min=1, max=60), + retry=retry_if_exception_type(requests.HTTPError), +) +def safe_get(url, **kwargs): + resp = requests.get(url, **kwargs) + if resp.status_code == 422: + # read retry‐after header if present + retry_after = resp.headers.get("Retry-After") + if retry_after: + delay = float(retry_after) + else: + # fallback to guess + delay = 30 + print(f"429 hit β€” waiting {delay} seconds...") + time.sleep(delay) + resp.raise_for_status() + else: + resp.raise_for_status() + return resp + + +def get_solved_discussions(n_posts: int = 50): + page = 1 + discussions = [] + while len(discussions) < n_posts: + url = f"{BASE_URL}/search.json?q=status:solved+order:latest&page={page}" + resp = safe_get(url) + topics = resp.json()["topics"] + if not topics: + break + for post in topics: + discussions.append( + Discussion( + title=post["fancy_title"], + url=f"{BASE_URL}/t/{post['slug']}/{post['id']}", + topic_id=post["id"], + category=post["category_id"], + created_at=post["created_at"], + ) + ) + if len(discussions) >= n_posts: + break + page += 1 + time.sleep(0.5) # simple pacing to avoid bursts + return discussions + + +def get_qa_pair(discussions, start_idx: int = 0): + for discussion in discussions[start_idx:]: + resp = safe_get(discussion.url + ".json") + data = resp.json() + posts = data["post_stream"]["posts"] + accepted_nr = min( + max(data["accepted_answer"]["post_number"] - 1, 0), len(posts) - 1 + ) + question = posts[0]["cooked"] + solution = posts[accepted_nr]["cooked"] + yield QuestionAndSolution( + discussion_title=discussion.title, + discussion_url=discussion.url, + discussion_topic_id=discussion.topic_id, + discussion_category=discussion.category, + discussion_created_at=discussion.created_at, + question=question, + solution=solution, + thread=posts, + ) + time.sleep(0.5) + + +if __name__ == "__main__": + discussions = get_solved_discussions(n_posts=300) + print(f"Fetched {len(discussions)} discussions") + with open("qa_pairs.jsonl", "a") as f: + for qa_pair in get_qa_pair(discussions): + f.write(qa_pair.model_dump_json() + "\n") diff --git a/eval/solvers.py b/eval/solvers.py new file mode 100644 index 0000000000000000000000000000000000000000..9a48eed88f74bf60a7c34ebffd0a324556841220 --- /dev/null +++ b/eval/solvers.py @@ -0,0 +1,170 @@ +""" +Collection of Inspect AI solvers used by the rubric task. +""" + +from __future__ import annotations + +import asyncio +import json +import os +import tempfile +from typing import Callable, Dict, List, Sequence + +import litellm +from inspect_ai.model import ChatMessageAssistant, ModelOutput +from inspect_ai.solver import Solver, solver +from inspect_ai.solver._task_state import TaskState +from lmnr import Laminar, LaminarLiteLLMCallback + +from eval.hf_agent_connector import AgentResponseGenerator + + +async def _run_subprocess(command: Sequence[str]) -> str: + process = await asyncio.create_subprocess_exec( + *command, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, stderr = await process.communicate() + if process.returncode != 0: + raise RuntimeError( + f"Command {' '.join(command)} failed with code {process.returncode}:\n" + f"{stderr.decode().strip()}" + ) + return stdout.decode().strip() + + +@solver(name="hf_agent") +def hf_agent( + config_path: str = "agent/config_mcp_example.json", + max_iterations: int = 10, +) -> Solver: + # init lmnr for observability + Laminar.initialize(project_api_key=os.environ.get("LMNR_API_KEY")) + litellm.callbacks = [LaminarLiteLLMCallback()] + print("βœ… Laminar initialized") + + runner = AgentResponseGenerator( + config_path=config_path, + max_iterations=max_iterations, + ) + + async def solve(state: TaskState, generate) -> TaskState: + response = await runner.run(state.input_text) + assistant_message = ChatMessageAssistant( + content=response, + model=runner.model_name, + source="generate", + ) + state.messages.append(assistant_message) + state.output = ModelOutput.from_message(assistant_message) + state.completed = True + return state + + return solve + + +@solver(name="claude_code") +def claude_code( + output_format: str = "json", + mcp_config: str | None = None, +) -> Solver: + if output_format not in {"text", "json", "stream-json"}: + raise ValueError("output_format must be one of: text, json, stream-json") + + async def solve(state: TaskState, generate) -> TaskState: + prompt = state.input_text + + cmd: List[str] = ["claude", "-p", prompt, "--output-format", output_format] + if mcp_config: + cmd += ["--mcp-config", mcp_config] + + stdout = await _run_subprocess(cmd) + response_text = stdout + session_id = None + + if output_format in {"json", "stream-json"}: + # stream-json may emit multiple JSON objects; take the last complete line + candidate_line = stdout.strip().splitlines()[-1] + try: + payload = json.loads(candidate_line) + response_text = ( + payload.get("result") or payload.get("message", "") or stdout + ) + session_id = payload.get("session_id") + except (json.JSONDecodeError, AttributeError): + response_text = stdout + + assistant_message = ChatMessageAssistant( + content=response_text, + model="claude-code", + source="generate", + metadata={"session_id": session_id} if session_id else None, + ) + state.messages.append(assistant_message) + state.output = ModelOutput.from_message(assistant_message) + state.completed = True + return state + + return solve + + +@solver(name="claude_code+hf_mcp") +def claude_code_hf_mcp( + output_format: str = "json", + hf_token: str | None = None, +) -> Solver: + """ + A solver that uses Claude Code with the Hugging Face MCP server. + Requires HF_TOKEN in environment variables or passed as argument. + """ + token = hf_token or os.environ.get("HF_TOKEN") + if not token: + raise ValueError( + "HF_TOKEN not found. Please set HF_TOKEN env var or pass it to the solver." + ) + + # Construct the MCP configuration for Hugging Face + mcp_config = { + "mcpServers": { + "huggingface": { + "type": "http", + "url": "https://huggingface.co/mcp", + "headers": {"Authorization": f"Bearer {token}"}, + } + } + } + + async def solve(state: TaskState, generate) -> TaskState: + # Write config to a temporary file + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmp: + json.dump(mcp_config, tmp, indent=2) + tmp_path = tmp.name + + try: + # Delegate to the base claude_code solver + delegate = claude_code(output_format=output_format, mcp_config=tmp_path) + return await delegate(state, generate) + finally: + # Clean up the temporary file + if os.path.exists(tmp_path): + os.remove(tmp_path) + + return solve + + +SOLVER_REGISTRY: Dict[str, Callable[..., Solver]] = { + "hf_agent": hf_agent, + "claude_code": claude_code, + "claude_code+hf_mcp": claude_code_hf_mcp, +} + + +def get_solver(name: str, **kwargs) -> Solver: + try: + factory = SOLVER_REGISTRY[name] + except KeyError as exc: + available = ", ".join(sorted(SOLVER_REGISTRY)) + raise ValueError(f"Unknown solver '{name}'. Available: {available}") from exc + + return factory(**kwargs) diff --git a/eval/task.py b/eval/task.py new file mode 100644 index 0000000000000000000000000000000000000000..134e2119d597ceb9c97666bae161549f766cd4cd --- /dev/null +++ b/eval/task.py @@ -0,0 +1,121 @@ +""" +Inspect AI task definition that runs the existing agent and reuses the rubric scorer. +""" + +from __future__ import annotations + +import asyncio +import json +import sys +from pathlib import Path +from typing import Any, Sequence + +from inspect_ai import Task, task +from inspect_ai.dataset import Sample, hf_dataset +from inspect_ai.scorer import Score, Target, mean, scorer +from inspect_ai.solver._task_state import TaskState +import litellm + +PROJECT_ROOT = Path(__file__).resolve().parents[1] +if str(PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(PROJECT_ROOT)) + +from eval.rubric_eval import RubricData, evaluate_with_rubrics # noqa: E402 +from eval.solvers import get_solver # noqa: E402 + + +def _record_to_sample(record: dict[str, Any]) -> Sample: + rubric_payload = json.loads(record["rubric"]) + rubrics = rubric_payload.get("rubrics", []) + + metadata = { + "question": record["question"], + "discussion_title": record.get("discussion_title"), + "discussion_url": record.get("discussion_url"), + "rubric_title": rubric_payload.get("title"), + "rubric_description": rubric_payload.get("description"), + "rubrics": rubrics, + } + + return Sample( + input=record["question"], + target=record["solution"], + id=record.get("discussion_topic_id"), + metadata=metadata, + ) + + +def _load_dataset(dataset_name: str, split: str, limit: int | None) -> Sequence[Sample]: + return hf_dataset( + dataset_name, sample_fields=_record_to_sample, split=split, limit=limit + ) + + +def _metadata_to_rubrics(metadata: dict[str, Any]) -> list[RubricData]: + raw_rubrics = metadata.get("rubrics", []) + return [RubricData(**rubric) for rubric in raw_rubrics] + + +@scorer(metrics=[mean()], name="rubric_scorer") +def rubric_scorer(judge_model: str = "gpt-5-mini"): + async def score(state: TaskState, target: Target) -> Score: + response_text = state.output.completion or state.output.message.text + question = state.metadata.get("question", state.input_text) + rubrics = _metadata_to_rubrics(state.metadata) + + evaluation = await asyncio.to_thread( + evaluate_with_rubrics, + question, + response_text, + rubrics, + judge_model, + ) + + score_metadata = { + "raw_score": evaluation.raw_score, + "criterion_checks": [ + check.model_dump() for check in evaluation.criterion_checks + ], + "discussion_title": state.metadata.get("discussion_title"), + "discussion_url": state.metadata.get("discussion_url"), + "reference_answer": target.text, + } + + return Score( + value=evaluation.normalized_score, + answer=response_text, + explanation=f"Normalized score {evaluation.normalized_score:.3f}", + metadata=score_metadata, + ) + + return score + + +@task(name="hf-benchmark-with-rubrics") +def hf_benchmark_with_rubrics( + solver_name: str = "hf_agent", + solver_kwargs: dict[str, Any] = { + "max_iterations": 10, + "config_path": "agent/config_mcp_example.json", + }, + dataset_name: str = "akseljoonas/hf-agent-rubrics@train", + limit: int | None = None, + judge_model: str = "gpt-5-mini", +) -> Task: + litellm.drop_params = True + if "@" not in dataset_name: + raise ValueError("Dataset name must be in the format 'author/dataset@split'") + dataset_name, dataset_split = dataset_name.split("@") + dataset = _load_dataset(dataset_name, dataset_split, limit=limit) + + return Task( + dataset=dataset, + solver=get_solver(solver_name, **solver_kwargs), + scorer=rubric_scorer(judge_model=judge_model), + metadata={ + "dataset_name": dataset_name, + "dataset_split": dataset_split, + "solver_name": solver_name, + "judge_model": judge_model, + }, + ) diff --git a/frontend/index.html b/frontend/index.html index e5acd2f134c91f0715d7b259ba2eb842a655757b..6d7b512a1076c222ddf20efbf45893c43f5b33d5 100644 --- a/frontend/index.html +++ b/frontend/index.html @@ -2,9 +2,9 @@ - + - ML Intern + HF Agent diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 0d7602b5b273bc63a4864ce026a37e98c28d10e4..a800dd3f254b2ff725890c4f250e34d7490bf52d 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -8,12 +8,10 @@ "name": "hf-agent-frontend", "version": "1.0.0", "dependencies": { - "@ai-sdk/react": "^3.0.93", "@emotion/react": "^11.13.0", "@emotion/styled": "^11.13.0", "@mui/icons-material": "^6.1.0", "@mui/material": "^6.1.0", - "ai": "^6.0.91", "react": "^18.3.1", "react-dom": "^18.3.1", "react-markdown": "^9.0.1", @@ -36,70 +34,6 @@ "vite": "^5.4.10" } }, - "node_modules/@ai-sdk/gateway": { - "version": "3.0.50", - "resolved": "https://registry.npmjs.org/@ai-sdk/gateway/-/gateway-3.0.50.tgz", - "integrity": "sha512-Jdd1a8VgbD7l7r+COj0h5SuaYRfPvOJ/AO6l0OrmTPEcI2MUQPr3C4JttfpNkcheEN+gOdy0CtZWuG17bW2fjw==", - "license": "Apache-2.0", - "dependencies": { - "@ai-sdk/provider": "3.0.8", - "@ai-sdk/provider-utils": "4.0.15", - "@vercel/oidc": "3.1.0" - }, - "engines": { - "node": ">=18" - }, - "peerDependencies": { - "zod": "^3.25.76 || ^4.1.8" - } - }, - "node_modules/@ai-sdk/provider": { - "version": "3.0.8", - "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-3.0.8.tgz", - "integrity": "sha512-oGMAgGoQdBXbZqNG0Ze56CHjDZ1IDYOwGYxYjO5KLSlz5HiNQ9udIXsPZ61VWaHGZ5XW/jyjmr6t2xz2jGVwbQ==", - "license": "Apache-2.0", - "dependencies": { - "json-schema": "^0.4.0" - }, - "engines": { - "node": ">=18" - } - }, - "node_modules/@ai-sdk/provider-utils": { - "version": "4.0.15", - "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-4.0.15.tgz", - "integrity": "sha512-8XiKWbemmCbvNN0CLR9u3PQiet4gtEVIrX4zzLxnCj06AwsEDJwJVBbKrEI4t6qE8XRSIvU2irka0dcpziKW6w==", - "license": "Apache-2.0", - "dependencies": { - "@ai-sdk/provider": "3.0.8", - "@standard-schema/spec": "^1.1.0", - "eventsource-parser": "^3.0.6" - }, - "engines": { - "node": ">=18" - }, - "peerDependencies": { - "zod": "^3.25.76 || ^4.1.8" - } - }, - "node_modules/@ai-sdk/react": { - "version": "3.0.93", - "resolved": "https://registry.npmjs.org/@ai-sdk/react/-/react-3.0.93.tgz", - "integrity": "sha512-FY1HmeAfCpiAGLhIZh2QR8QFzHFZfhjMmkA9D5KC/O3eGqPeY7CwBABLkzRH+5Gkf+MfxXnEm4VF0MpmvDMjpg==", - "license": "Apache-2.0", - "dependencies": { - "@ai-sdk/provider-utils": "4.0.15", - "ai": "6.0.91", - "swr": "^2.2.5", - "throttleit": "2.1.0" - }, - "engines": { - "node": ">=18" - }, - "peerDependencies": { - "react": "^18 || ~19.0.1 || ~19.1.2 || ^19.2.1" - } - }, "node_modules/@babel/code-frame": { "version": "7.28.6", "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.28.6.tgz", @@ -130,6 +64,7 @@ "integrity": "sha512-H3mcG6ZDLTlYfaSNi0iOKkigqMFvkTKlGUYlD8GW7nNOYRrevuA46iTypPyv+06V3fEmvvazfntkBU34L0azAw==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@babel/code-frame": "^7.28.6", "@babel/generator": "^7.28.6", @@ -446,6 +381,7 @@ "resolved": "https://registry.npmjs.org/@emotion/react/-/react-11.14.0.tgz", "integrity": "sha512-O000MLDBDdk/EohJPFUqvnp4qnHeYkVP5B0xEG0D/L7cOKP9kefu2DXn8dj74cQfsEzUqh+sr1RzFqiL1o+PpA==", "license": "MIT", + "peer": true, "dependencies": { "@babel/runtime": "^7.18.3", "@emotion/babel-plugin": "^11.13.5", @@ -489,6 +425,7 @@ "resolved": "https://registry.npmjs.org/@emotion/styled/-/styled-11.14.1.tgz", "integrity": "sha512-qEEJt42DuToa3gurlH4Qqc1kVpNq8wO8cJtDzU46TjlzWjDlsVyevtYCRijVq3SrHsROS+gVQ8Fnea108GnKzw==", "license": "MIT", + "peer": true, "dependencies": { "@babel/runtime": "^7.18.3", "@emotion/babel-plugin": "^11.13.5", @@ -1221,6 +1158,7 @@ "resolved": "https://registry.npmjs.org/@mui/material/-/material-6.5.0.tgz", "integrity": "sha512-yjvtXoFcrPLGtgKRxFaH6OQPtcLPhkloC0BML6rBG5UeldR0nPULR/2E2BfXdo5JNV7j7lOzrrLX2Qf/iSidow==", "license": "MIT", + "peer": true, "dependencies": { "@babel/runtime": "^7.26.0", "@mui/core-downloads-tracker": "^6.5.0", @@ -1410,15 +1348,6 @@ } } }, - "node_modules/@opentelemetry/api": { - "version": "1.9.0", - "resolved": "https://registry.npmjs.org/@opentelemetry/api/-/api-1.9.0.tgz", - "integrity": "sha512-3giAOQvZiH5F9bMlMiv8+GSPMeqg0dbaeo58/0SlA9sxSqZhnUtxzX9/2FzyhS9sWQf5S0GJE0AKBrFqjpeYcg==", - "license": "Apache-2.0", - "engines": { - "node": ">=8.0.0" - } - }, "node_modules/@popperjs/core": { "version": "2.11.8", "resolved": "https://registry.npmjs.org/@popperjs/core/-/core-2.11.8.tgz", @@ -1437,9 +1366,9 @@ "license": "MIT" }, "node_modules/@rollup/rollup-android-arm-eabi": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.60.1.tgz", - "integrity": "sha512-d6FinEBLdIiK+1uACUttJKfgZREXrF0Qc2SmLII7W2AD8FfiZ9Wjd+rD/iRuf5s5dWrr1GgwXCvPqOuDquOowA==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.55.1.tgz", + "integrity": "sha512-9R0DM/ykwfGIlNu6+2U09ga0WXeZ9MRC2Ter8jnz8415VbuIykVuc6bhdrbORFZANDmTDvq26mJrEVTl8TdnDg==", "cpu": [ "arm" ], @@ -1451,9 +1380,9 @@ ] }, "node_modules/@rollup/rollup-android-arm64": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm64/-/rollup-android-arm64-4.60.1.tgz", - "integrity": "sha512-YjG/EwIDvvYI1YvYbHvDz/BYHtkY4ygUIXHnTdLhG+hKIQFBiosfWiACWortsKPKU/+dUwQQCKQM3qrDe8c9BA==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm64/-/rollup-android-arm64-4.55.1.tgz", + "integrity": "sha512-eFZCb1YUqhTysgW3sj/55du5cG57S7UTNtdMjCW7LwVcj3dTTcowCsC8p7uBdzKsZYa8J7IDE8lhMI+HX1vQvg==", "cpu": [ "arm64" ], @@ -1465,9 +1394,9 @@ ] }, "node_modules/@rollup/rollup-darwin-arm64": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.60.1.tgz", - "integrity": "sha512-mjCpF7GmkRtSJwon+Rq1N8+pI+8l7w5g9Z3vWj4T7abguC4Czwi3Yu/pFaLvA3TTeMVjnu3ctigusqWUfjZzvw==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.55.1.tgz", + "integrity": "sha512-p3grE2PHcQm2e8PSGZdzIhCKbMCw/xi9XvMPErPhwO17vxtvCN5FEA2mSLgmKlCjHGMQTP6phuQTYWUnKewwGg==", "cpu": [ "arm64" ], @@ -1479,9 +1408,9 @@ ] }, "node_modules/@rollup/rollup-darwin-x64": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-x64/-/rollup-darwin-x64-4.60.1.tgz", - "integrity": "sha512-haZ7hJ1JT4e9hqkoT9R/19XW2QKqjfJVv+i5AGg57S+nLk9lQnJ1F/eZloRO3o9Scy9CM3wQ9l+dkXtcBgN5Ew==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-x64/-/rollup-darwin-x64-4.55.1.tgz", + "integrity": "sha512-rDUjG25C9qoTm+e02Esi+aqTKSBYwVTaoS1wxcN47/Luqef57Vgp96xNANwt5npq9GDxsH7kXxNkJVEsWEOEaQ==", "cpu": [ "x64" ], @@ -1493,9 +1422,9 @@ ] }, "node_modules/@rollup/rollup-freebsd-arm64": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-arm64/-/rollup-freebsd-arm64-4.60.1.tgz", - "integrity": "sha512-czw90wpQq3ZsAVBlinZjAYTKduOjTywlG7fEeWKUA7oCmpA8xdTkxZZlwNJKWqILlq0wehoZcJYfBvOyhPTQ6w==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-arm64/-/rollup-freebsd-arm64-4.55.1.tgz", + "integrity": "sha512-+JiU7Jbp5cdxekIgdte0jfcu5oqw4GCKr6i3PJTlXTCU5H5Fvtkpbs4XJHRmWNXF+hKmn4v7ogI5OQPaupJgOg==", "cpu": [ "arm64" ], @@ -1507,9 +1436,9 @@ ] }, "node_modules/@rollup/rollup-freebsd-x64": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-x64/-/rollup-freebsd-x64-4.60.1.tgz", - "integrity": "sha512-KVB2rqsxTHuBtfOeySEyzEOB7ltlB/ux38iu2rBQzkjbwRVlkhAGIEDiiYnO2kFOkJp+Z7pUXKyrRRFuFUKt+g==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-x64/-/rollup-freebsd-x64-4.55.1.tgz", + "integrity": "sha512-V5xC1tOVWtLLmr3YUk2f6EJK4qksksOYiz/TCsFHu/R+woubcLWdC9nZQmwjOAbmExBIVKsm1/wKmEy4z4u4Bw==", "cpu": [ "x64" ], @@ -1521,16 +1450,13 @@ ] }, "node_modules/@rollup/rollup-linux-arm-gnueabihf": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-gnueabihf/-/rollup-linux-arm-gnueabihf-4.60.1.tgz", - "integrity": "sha512-L+34Qqil+v5uC0zEubW7uByo78WOCIrBvci69E7sFASRl0X7b/MB6Cqd1lky/CtcSVTydWa2WZwFuWexjS5o6g==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-gnueabihf/-/rollup-linux-arm-gnueabihf-4.55.1.tgz", + "integrity": "sha512-Rn3n+FUk2J5VWx+ywrG/HGPTD9jXNbicRtTM11e/uorplArnXZYsVifnPPqNNP5BsO3roI4n8332ukpY/zN7rQ==", "cpu": [ "arm" ], "dev": true, - "libc": [ - "glibc" - ], "license": "MIT", "optional": true, "os": [ @@ -1538,16 +1464,13 @@ ] }, "node_modules/@rollup/rollup-linux-arm-musleabihf": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-musleabihf/-/rollup-linux-arm-musleabihf-4.60.1.tgz", - "integrity": "sha512-n83O8rt4v34hgFzlkb1ycniJh7IR5RCIqt6mz1VRJD6pmhRi0CXdmfnLu9dIUS6buzh60IvACM842Ffb3xd6Gg==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-musleabihf/-/rollup-linux-arm-musleabihf-4.55.1.tgz", + "integrity": "sha512-grPNWydeKtc1aEdrJDWk4opD7nFtQbMmV7769hiAaYyUKCT1faPRm2av8CX1YJsZ4TLAZcg9gTR1KvEzoLjXkg==", "cpu": [ "arm" ], "dev": true, - "libc": [ - "musl" - ], "license": "MIT", "optional": true, "os": [ @@ -1555,16 +1478,13 @@ ] }, "node_modules/@rollup/rollup-linux-arm64-gnu": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-gnu/-/rollup-linux-arm64-gnu-4.60.1.tgz", - "integrity": "sha512-Nql7sTeAzhTAja3QXeAI48+/+GjBJ+QmAH13snn0AJSNL50JsDqotyudHyMbO2RbJkskbMbFJfIJKWA6R1LCJQ==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-gnu/-/rollup-linux-arm64-gnu-4.55.1.tgz", + "integrity": "sha512-a59mwd1k6x8tXKcUxSyISiquLwB5pX+fJW9TkWU46lCqD/GRDe9uDN31jrMmVP3feI3mhAdvcCClhV8V5MhJFQ==", "cpu": [ "arm64" ], "dev": true, - "libc": [ - "glibc" - ], "license": "MIT", "optional": true, "os": [ @@ -1572,16 +1492,13 @@ ] }, "node_modules/@rollup/rollup-linux-arm64-musl": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-musl/-/rollup-linux-arm64-musl-4.60.1.tgz", - "integrity": "sha512-+pUymDhd0ys9GcKZPPWlFiZ67sTWV5UU6zOJat02M1+PiuSGDziyRuI/pPue3hoUwm2uGfxdL+trT6Z9rxnlMA==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-musl/-/rollup-linux-arm64-musl-4.55.1.tgz", + "integrity": "sha512-puS1MEgWX5GsHSoiAsF0TYrpomdvkaXm0CofIMG5uVkP6IBV+ZO9xhC5YEN49nsgYo1DuuMquF9+7EDBVYu4uA==", "cpu": [ "arm64" ], "dev": true, - "libc": [ - "musl" - ], "license": "MIT", "optional": true, "os": [ @@ -1589,16 +1506,13 @@ ] }, "node_modules/@rollup/rollup-linux-loong64-gnu": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-gnu/-/rollup-linux-loong64-gnu-4.60.1.tgz", - "integrity": "sha512-VSvgvQeIcsEvY4bKDHEDWcpW4Yw7BtlKG1GUT4FzBUlEKQK0rWHYBqQt6Fm2taXS+1bXvJT6kICu5ZwqKCnvlQ==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-gnu/-/rollup-linux-loong64-gnu-4.55.1.tgz", + "integrity": "sha512-r3Wv40in+lTsULSb6nnoudVbARdOwb2u5fpeoOAZjFLznp6tDU8kd+GTHmJoqZ9lt6/Sys33KdIHUaQihFcu7g==", "cpu": [ "loong64" ], "dev": true, - "libc": [ - "glibc" - ], "license": "MIT", "optional": true, "os": [ @@ -1606,16 +1520,13 @@ ] }, "node_modules/@rollup/rollup-linux-loong64-musl": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-musl/-/rollup-linux-loong64-musl-4.60.1.tgz", - "integrity": "sha512-4LqhUomJqwe641gsPp6xLfhqWMbQV04KtPp7/dIp0nzPxAkNY1AbwL5W0MQpcalLYk07vaW9Kp1PBhdpZYYcEw==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-musl/-/rollup-linux-loong64-musl-4.55.1.tgz", + "integrity": "sha512-MR8c0+UxAlB22Fq4R+aQSPBayvYa3+9DrwG/i1TKQXFYEaoW3B5b/rkSRIypcZDdWjWnpcvxbNaAJDcSbJU3Lw==", "cpu": [ "loong64" ], "dev": true, - "libc": [ - "musl" - ], "license": "MIT", "optional": true, "os": [ @@ -1623,16 +1534,13 @@ ] }, "node_modules/@rollup/rollup-linux-ppc64-gnu": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-gnu/-/rollup-linux-ppc64-gnu-4.60.1.tgz", - "integrity": "sha512-tLQQ9aPvkBxOc/EUT6j3pyeMD6Hb8QF2BTBnCQWP/uu1lhc9AIrIjKnLYMEroIz/JvtGYgI9dF3AxHZNaEH0rw==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-gnu/-/rollup-linux-ppc64-gnu-4.55.1.tgz", + "integrity": "sha512-3KhoECe1BRlSYpMTeVrD4sh2Pw2xgt4jzNSZIIPLFEsnQn9gAnZagW9+VqDqAHgm1Xc77LzJOo2LdigS5qZ+gw==", "cpu": [ "ppc64" ], "dev": true, - "libc": [ - "glibc" - ], "license": "MIT", "optional": true, "os": [ @@ -1640,16 +1548,13 @@ ] }, "node_modules/@rollup/rollup-linux-ppc64-musl": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-musl/-/rollup-linux-ppc64-musl-4.60.1.tgz", - "integrity": "sha512-RMxFhJwc9fSXP6PqmAz4cbv3kAyvD1etJFjTx4ONqFP9DkTkXsAMU4v3Vyc5BgzC+anz7nS/9tp4obsKfqkDHg==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-musl/-/rollup-linux-ppc64-musl-4.55.1.tgz", + "integrity": "sha512-ziR1OuZx0vdYZZ30vueNZTg73alF59DicYrPViG0NEgDVN8/Jl87zkAPu4u6VjZST2llgEUjaiNl9JM6HH1Vdw==", "cpu": [ "ppc64" ], "dev": true, - "libc": [ - "musl" - ], "license": "MIT", "optional": true, "os": [ @@ -1657,16 +1562,13 @@ ] }, "node_modules/@rollup/rollup-linux-riscv64-gnu": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-gnu/-/rollup-linux-riscv64-gnu-4.60.1.tgz", - "integrity": "sha512-QKgFl+Yc1eEk6MmOBfRHYF6lTxiiiV3/z/BRrbSiW2I7AFTXoBFvdMEyglohPj//2mZS4hDOqeB0H1ACh3sBbg==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-gnu/-/rollup-linux-riscv64-gnu-4.55.1.tgz", + "integrity": "sha512-uW0Y12ih2XJRERZ4jAfKamTyIHVMPQnTZcQjme2HMVDAHY4amf5u414OqNYC+x+LzRdRcnIG1YodLrrtA8xsxw==", "cpu": [ "riscv64" ], "dev": true, - "libc": [ - "glibc" - ], "license": "MIT", "optional": true, "os": [ @@ -1674,16 +1576,13 @@ ] }, "node_modules/@rollup/rollup-linux-riscv64-musl": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-musl/-/rollup-linux-riscv64-musl-4.60.1.tgz", - "integrity": "sha512-RAjXjP/8c6ZtzatZcA1RaQr6O1TRhzC+adn8YZDnChliZHviqIjmvFwHcxi4JKPSDAt6Uhf/7vqcBzQJy0PDJg==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-musl/-/rollup-linux-riscv64-musl-4.55.1.tgz", + "integrity": "sha512-u9yZ0jUkOED1BFrqu3BwMQoixvGHGZ+JhJNkNKY/hyoEgOwlqKb62qu+7UjbPSHYjiVy8kKJHvXKv5coH4wDeg==", "cpu": [ "riscv64" ], "dev": true, - "libc": [ - "musl" - ], "license": "MIT", "optional": true, "os": [ @@ -1691,16 +1590,13 @@ ] }, "node_modules/@rollup/rollup-linux-s390x-gnu": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-s390x-gnu/-/rollup-linux-s390x-gnu-4.60.1.tgz", - "integrity": "sha512-wcuocpaOlaL1COBYiA89O6yfjlp3RwKDeTIA0hM7OpmhR1Bjo9j31G1uQVpDlTvwxGn2nQs65fBFL5UFd76FcQ==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-s390x-gnu/-/rollup-linux-s390x-gnu-4.55.1.tgz", + "integrity": "sha512-/0PenBCmqM4ZUd0190j7J0UsQ/1nsi735iPRakO8iPciE7BQ495Y6msPzaOmvx0/pn+eJVVlZrNrSh4WSYLxNg==", "cpu": [ "s390x" ], "dev": true, - "libc": [ - "glibc" - ], "license": "MIT", "optional": true, "os": [ @@ -1708,16 +1604,13 @@ ] }, "node_modules/@rollup/rollup-linux-x64-gnu": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.60.1.tgz", - "integrity": "sha512-77PpsFQUCOiZR9+LQEFg9GClyfkNXj1MP6wRnzYs0EeWbPcHs02AXu4xuUbM1zhwn3wqaizle3AEYg5aeoohhg==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.55.1.tgz", + "integrity": "sha512-a8G4wiQxQG2BAvo+gU6XrReRRqj+pLS2NGXKm8io19goR+K8lw269eTrPkSdDTALwMmJp4th2Uh0D8J9bEV1vg==", "cpu": [ "x64" ], "dev": true, - "libc": [ - "glibc" - ], "license": "MIT", "optional": true, "os": [ @@ -1725,16 +1618,13 @@ ] }, "node_modules/@rollup/rollup-linux-x64-musl": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-musl/-/rollup-linux-x64-musl-4.60.1.tgz", - "integrity": "sha512-5cIATbk5vynAjqqmyBjlciMJl1+R/CwX9oLk/EyiFXDWd95KpHdrOJT//rnUl4cUcskrd0jCCw3wpZnhIHdD9w==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-musl/-/rollup-linux-x64-musl-4.55.1.tgz", + "integrity": "sha512-bD+zjpFrMpP/hqkfEcnjXWHMw5BIghGisOKPj+2NaNDuVT+8Ds4mPf3XcPHuat1tz89WRL+1wbcxKY3WSbiT7w==", "cpu": [ "x64" ], "dev": true, - "libc": [ - "musl" - ], "license": "MIT", "optional": true, "os": [ @@ -1742,9 +1632,9 @@ ] }, "node_modules/@rollup/rollup-openbsd-x64": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-openbsd-x64/-/rollup-openbsd-x64-4.60.1.tgz", - "integrity": "sha512-cl0w09WsCi17mcmWqqglez9Gk8isgeWvoUZ3WiJFYSR3zjBQc2J5/ihSjpl+VLjPqjQ/1hJRcqBfLjssREQILw==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-openbsd-x64/-/rollup-openbsd-x64-4.55.1.tgz", + "integrity": "sha512-eLXw0dOiqE4QmvikfQ6yjgkg/xDM+MdU9YJuP4ySTibXU0oAvnEWXt7UDJmD4UkYialMfOGFPJnIHSe/kdzPxg==", "cpu": [ "x64" ], @@ -1756,9 +1646,9 @@ ] }, "node_modules/@rollup/rollup-openharmony-arm64": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-openharmony-arm64/-/rollup-openharmony-arm64-4.60.1.tgz", - "integrity": "sha512-4Cv23ZrONRbNtbZa37mLSueXUCtN7MXccChtKpUnQNgF010rjrjfHx3QxkS2PI7LqGT5xXyYs1a7LbzAwT0iCA==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-openharmony-arm64/-/rollup-openharmony-arm64-4.55.1.tgz", + "integrity": "sha512-xzm44KgEP11te3S2HCSyYf5zIzWmx3n8HDCc7EE59+lTcswEWNpvMLfd9uJvVX8LCg9QWG67Xt75AuHn4vgsXw==", "cpu": [ "arm64" ], @@ -1770,9 +1660,9 @@ ] }, "node_modules/@rollup/rollup-win32-arm64-msvc": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-arm64-msvc/-/rollup-win32-arm64-msvc-4.60.1.tgz", - "integrity": "sha512-i1okWYkA4FJICtr7KpYzFpRTHgy5jdDbZiWfvny21iIKky5YExiDXP+zbXzm3dUcFpkEeYNHgQ5fuG236JPq0g==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-arm64-msvc/-/rollup-win32-arm64-msvc-4.55.1.tgz", + "integrity": "sha512-yR6Bl3tMC/gBok5cz/Qi0xYnVbIxGx5Fcf/ca0eB6/6JwOY+SRUcJfI0OpeTpPls7f194as62thCt/2BjxYN8g==", "cpu": [ "arm64" ], @@ -1784,9 +1674,9 @@ ] }, "node_modules/@rollup/rollup-win32-ia32-msvc": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-ia32-msvc/-/rollup-win32-ia32-msvc-4.60.1.tgz", - "integrity": "sha512-u09m3CuwLzShA0EYKMNiFgcjjzwqtUMLmuCJLeZWjjOYA3IT2Di09KaxGBTP9xVztWyIWjVdsB2E9goMjZvTQg==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-ia32-msvc/-/rollup-win32-ia32-msvc-4.55.1.tgz", + "integrity": "sha512-3fZBidchE0eY0oFZBnekYCfg+5wAB0mbpCBuofh5mZuzIU/4jIVkbESmd2dOsFNS78b53CYv3OAtwqkZZmU5nA==", "cpu": [ "ia32" ], @@ -1798,9 +1688,9 @@ ] }, "node_modules/@rollup/rollup-win32-x64-gnu": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-gnu/-/rollup-win32-x64-gnu-4.60.1.tgz", - "integrity": "sha512-k+600V9Zl1CM7eZxJgMyTUzmrmhB/0XZnF4pRypKAlAgxmedUA+1v9R+XOFv56W4SlHEzfeMtzujLJD22Uz5zg==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-gnu/-/rollup-win32-x64-gnu-4.55.1.tgz", + "integrity": "sha512-xGGY5pXj69IxKb4yv/POoocPy/qmEGhimy/FoTpTSVju3FYXUQQMFCaZZXJVidsmGxRioZAwpThl/4zX41gRKg==", "cpu": [ "x64" ], @@ -1812,9 +1702,9 @@ ] }, "node_modules/@rollup/rollup-win32-x64-msvc": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-msvc/-/rollup-win32-x64-msvc-4.60.1.tgz", - "integrity": "sha512-lWMnixq/QzxyhTV6NjQJ4SFo1J6PvOX8vUx5Wb4bBPsEb+8xZ89Bz6kOXpfXj9ak9AHTQVQzlgzBEc1SyM27xQ==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-msvc/-/rollup-win32-x64-msvc-4.55.1.tgz", + "integrity": "sha512-SPEpaL6DX4rmcXtnhdrQYgzQ5W2uW3SCJch88lB2zImhJRhIIK44fkUrgIV/Q8yUNfw5oyZ5vkeQsZLhCb06lw==", "cpu": [ "x64" ], @@ -1825,12 +1715,6 @@ "win32" ] }, - "node_modules/@standard-schema/spec": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/@standard-schema/spec/-/spec-1.1.0.tgz", - "integrity": "sha512-l2aFy5jALhniG5HgqrD6jXLi/rUWrKvqN/qJx6yoJsgKhblVd+iqqU4RCXavm/jPityDo5TCvKMnpjKnOriy0w==", - "license": "MIT" - }, "node_modules/@types/babel__core": { "version": "7.20.5", "resolved": "https://registry.npmjs.org/@types/babel__core/-/babel__core-7.20.5.tgz", @@ -1954,6 +1838,7 @@ "resolved": "https://registry.npmjs.org/@types/react/-/react-18.3.27.tgz", "integrity": "sha512-cisd7gxkzjBKU2GgdYrTdtQx1SORymWyaAFhaxQPK9bYO9ot3Y5OikQRvY0VYQtvwjeQnizCINJAenh/V7MK2w==", "license": "MIT", + "peer": true, "dependencies": { "@types/prop-types": "*", "csstype": "^3.2.2" @@ -2039,6 +1924,7 @@ "integrity": "sha512-npiaib8XzbjtzS2N4HlqPvlpxpmZ14FjSJrteZpPxGUaYPlvhzlzUZ4mZyABo0EFrOWnvyd0Xxroq//hKhtAWg==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@typescript-eslint/scope-manager": "8.53.0", "@typescript-eslint/types": "8.53.0", @@ -2183,9 +2069,9 @@ } }, "node_modules/@typescript-eslint/typescript-estree/node_modules/brace-expansion": { - "version": "2.0.3", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.3.tgz", - "integrity": "sha512-MCV/fYJEbqx68aE58kv2cA/kiky1G8vux3OR6/jbS+jIMe/6fJWa0DTzJU7dqijOWYwHi1t29FlfYI9uytqlpA==", + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", + "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", "dev": true, "license": "MIT", "dependencies": { @@ -2193,13 +2079,13 @@ } }, "node_modules/@typescript-eslint/typescript-estree/node_modules/minimatch": { - "version": "9.0.9", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.9.tgz", - "integrity": "sha512-OBwBN9AL4dqmETlpS2zasx+vTeWclWzkblfZk7KTA5j3jeOONz/tRCnZomUyvNg83wL5Zv9Ss6HMJXAgL8R2Yg==", + "version": "9.0.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", + "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", "dev": true, "license": "ISC", "dependencies": { - "brace-expansion": "^2.0.2" + "brace-expansion": "^2.0.1" }, "engines": { "node": ">=16 || 14 >=14.17" @@ -2269,15 +2155,6 @@ "integrity": "sha512-WmoN8qaIAo7WTYWbAZuG8PYEhn5fkz7dZrqTBZ7dtt//lL2Gwms1IcnQ5yHqjDfX8Ft5j4YzDM23f87zBfDe9g==", "license": "ISC" }, - "node_modules/@vercel/oidc": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/@vercel/oidc/-/oidc-3.1.0.tgz", - "integrity": "sha512-Fw28YZpRnA3cAHHDlkt7xQHiJ0fcL+NRcIqsocZQUSmbzeIKRpwttJjik5ZGanXP+vlA4SbTg+AbA3bP363l+w==", - "license": "Apache-2.0", - "engines": { - "node": ">= 20" - } - }, "node_modules/@vitejs/plugin-react": { "version": "4.7.0", "resolved": "https://registry.npmjs.org/@vitejs/plugin-react/-/plugin-react-4.7.0.tgz", @@ -2305,6 +2182,7 @@ "integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==", "dev": true, "license": "MIT", + "peer": true, "bin": { "acorn": "bin/acorn" }, @@ -2322,28 +2200,10 @@ "acorn": "^6.0.0 || ^7.0.0 || ^8.0.0" } }, - "node_modules/ai": { - "version": "6.0.91", - "resolved": "https://registry.npmjs.org/ai/-/ai-6.0.91.tgz", - "integrity": "sha512-k1/8BusZMhYVxxLZt0BUZzm9HVDCCh117nyWfWUx5xjR2+tWisJbXgysL7EBMq2lgyHwgpA1jDR3tVjWSdWZXw==", - "license": "Apache-2.0", - "dependencies": { - "@ai-sdk/gateway": "3.0.50", - "@ai-sdk/provider": "3.0.8", - "@ai-sdk/provider-utils": "4.0.15", - "@opentelemetry/api": "1.9.0" - }, - "engines": { - "node": ">=18" - }, - "peerDependencies": { - "zod": "^3.25.76 || ^4.1.8" - } - }, "node_modules/ajv": { - "version": "6.14.0", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.14.0.tgz", - "integrity": "sha512-IWrosm/yrn43eiKqkfkHis7QioDleaXQHdDVPKg0FSwwd/DuvyX79TZnFOnYpB7dcsFAMmtFztZuXPDvSePkFw==", + "version": "6.12.6", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", + "integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==", "dev": true, "license": "MIT", "dependencies": { @@ -2423,9 +2283,9 @@ } }, "node_modules/brace-expansion": { - "version": "1.1.13", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.13.tgz", - "integrity": "sha512-9ZLprWS6EENmhEOpjCYW2c8VkmOvckIJZfkr7rBW6dObmfgJ/L1GpSYW5Hpo9lDz4D1+n0Ckz8rU7FwHDQiG/w==", + "version": "1.1.12", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz", + "integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==", "dev": true, "license": "MIT", "dependencies": { @@ -2453,6 +2313,7 @@ } ], "license": "MIT", + "peer": true, "dependencies": { "baseline-browser-mapping": "^2.9.0", "caniuse-lite": "^1.0.30001759", @@ -2805,6 +2666,7 @@ "integrity": "sha512-LEyamqS7W5HB3ujJyvi0HQK/dtVINZvd5mAAp9eT5S/ujByGjiZLCzPcHVzuXbpJDJF/cxwHlfceVUDZ2lnSTw==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@eslint-community/eslint-utils": "^4.8.0", "@eslint-community/regexpp": "^4.12.1", @@ -2986,15 +2848,6 @@ "node": ">=0.10.0" } }, - "node_modules/eventsource-parser": { - "version": "3.0.6", - "resolved": "https://registry.npmjs.org/eventsource-parser/-/eventsource-parser-3.0.6.tgz", - "integrity": "sha512-Vo1ab+QXPzZ4tCa8SwIHJFaSzy4R6SHf7BY79rFBDf0idraZWAkYrDjDj8uWaSm3S2TK+hJ7/t1CEmZ7jXw+pg==", - "license": "MIT", - "engines": { - "node": ">=18.0.0" - } - }, "node_modules/extend": { "version": "3.0.2", "resolved": "https://registry.npmjs.org/extend/-/extend-3.0.2.tgz", @@ -3104,9 +2957,9 @@ } }, "node_modules/flatted": { - "version": "3.4.2", - "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.4.2.tgz", - "integrity": "sha512-PjDse7RzhcPkIJwy5t7KPWQSZ9cAbzQXcafsetQoD7sOJRQlGikNbx7yZp2OotDnJyrDcbyRq3Ttb18iYOqkxA==", + "version": "3.3.3", + "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.3.3.tgz", + "integrity": "sha512-GX+ysw4PBCz0PzosHDepZGANEuFCMLrnRTiEy9McGjmkCQYwRq4A/X786G/fjM/+OjsWSU1ZrY5qyARZmO/uwg==", "dev": true, "license": "ISC" }, @@ -3503,12 +3356,6 @@ "integrity": "sha512-xyFwyhro/JEof6Ghe2iz2NcXoj2sloNsWr/XsERDK/oiPCfaNhl5ONfp+jQdAZRQQ0IJWNzH9zIZF7li91kh2w==", "license": "MIT" }, - "node_modules/json-schema": { - "version": "0.4.0", - "resolved": "https://registry.npmjs.org/json-schema/-/json-schema-0.4.0.tgz", - "integrity": "sha512-es94M3nTIfsEPisRafak+HDLfHXnKBhV3vU5eqPcS3flIWqcxJWgXHXiey3YrpaNsanY5ei1VoYEbOzijuq9BA==", - "license": "(AFL-2.1 OR BSD-3-Clause)" - }, "node_modules/json-schema-traverse": { "version": "0.4.1", "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz", @@ -4491,9 +4338,9 @@ "license": "MIT" }, "node_modules/minimatch": { - "version": "3.1.5", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.5.tgz", - "integrity": "sha512-VgjWUsnnT6n+NUk6eZq77zeFdpW2LWDzP6zFGrCbHXiYNul5Dzqk2HHQ5uFH2DNW5Xbp8+jVzaeNt94ssEEl4w==", + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", + "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", "dev": true, "license": "ISC", "dependencies": { @@ -4698,11 +4545,12 @@ "license": "ISC" }, "node_modules/picomatch": { - "version": "4.0.4", - "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.4.tgz", - "integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==", + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz", + "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", "dev": true, "license": "MIT", + "peer": true, "engines": { "node": ">=12" }, @@ -4800,6 +4648,7 @@ "resolved": "https://registry.npmjs.org/react/-/react-18.3.1.tgz", "integrity": "sha512-wS+hAgJShR0KhEvPJArfuPVN1+Hz1t0Y6n5jLrGQbkb4urgPE/0Rve+1kMB1v/oWgHgm4WIcV+i7F2pTVj+2iQ==", "license": "MIT", + "peer": true, "dependencies": { "loose-envify": "^1.1.0" }, @@ -4812,6 +4661,7 @@ "resolved": "https://registry.npmjs.org/react-dom/-/react-dom-18.3.1.tgz", "integrity": "sha512-5m4nQKp+rZRb09LNH59GM4BxTh9251/ylbKIbpe7TpGxfJ+9kv6BLkLBXIjjspbgbnIBNqlI23tRnTWT0snUIw==", "license": "MIT", + "peer": true, "dependencies": { "loose-envify": "^1.1.0", "scheduler": "^0.23.2" @@ -5011,9 +4861,9 @@ } }, "node_modules/rollup": { - "version": "4.60.1", - "resolved": "https://registry.npmjs.org/rollup/-/rollup-4.60.1.tgz", - "integrity": "sha512-VmtB2rFU/GroZ4oL8+ZqXgSA38O6GR8KSIvWmEFv63pQ0G6KaBH9s07PO8XTXP4vI+3UJUEypOfjkGfmSBBR0w==", + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/rollup/-/rollup-4.55.1.tgz", + "integrity": "sha512-wDv/Ht1BNHB4upNbK74s9usvl7hObDnvVzknxqY/E/O3X6rW1U1rV1aENEfJ54eFZDTNo7zv1f5N4edCluH7+A==", "dev": true, "license": "MIT", "dependencies": { @@ -5027,31 +4877,31 @@ "npm": ">=8.0.0" }, "optionalDependencies": { - "@rollup/rollup-android-arm-eabi": "4.60.1", - "@rollup/rollup-android-arm64": "4.60.1", - "@rollup/rollup-darwin-arm64": "4.60.1", - "@rollup/rollup-darwin-x64": "4.60.1", - "@rollup/rollup-freebsd-arm64": "4.60.1", - "@rollup/rollup-freebsd-x64": "4.60.1", - "@rollup/rollup-linux-arm-gnueabihf": "4.60.1", - "@rollup/rollup-linux-arm-musleabihf": "4.60.1", - "@rollup/rollup-linux-arm64-gnu": "4.60.1", - "@rollup/rollup-linux-arm64-musl": "4.60.1", - "@rollup/rollup-linux-loong64-gnu": "4.60.1", - "@rollup/rollup-linux-loong64-musl": "4.60.1", - "@rollup/rollup-linux-ppc64-gnu": "4.60.1", - "@rollup/rollup-linux-ppc64-musl": "4.60.1", - "@rollup/rollup-linux-riscv64-gnu": "4.60.1", - "@rollup/rollup-linux-riscv64-musl": "4.60.1", - "@rollup/rollup-linux-s390x-gnu": "4.60.1", - "@rollup/rollup-linux-x64-gnu": "4.60.1", - "@rollup/rollup-linux-x64-musl": "4.60.1", - "@rollup/rollup-openbsd-x64": "4.60.1", - "@rollup/rollup-openharmony-arm64": "4.60.1", - "@rollup/rollup-win32-arm64-msvc": "4.60.1", - "@rollup/rollup-win32-ia32-msvc": "4.60.1", - "@rollup/rollup-win32-x64-gnu": "4.60.1", - "@rollup/rollup-win32-x64-msvc": "4.60.1", + "@rollup/rollup-android-arm-eabi": "4.55.1", + "@rollup/rollup-android-arm64": "4.55.1", + "@rollup/rollup-darwin-arm64": "4.55.1", + "@rollup/rollup-darwin-x64": "4.55.1", + "@rollup/rollup-freebsd-arm64": "4.55.1", + "@rollup/rollup-freebsd-x64": "4.55.1", + "@rollup/rollup-linux-arm-gnueabihf": "4.55.1", + "@rollup/rollup-linux-arm-musleabihf": "4.55.1", + "@rollup/rollup-linux-arm64-gnu": "4.55.1", + "@rollup/rollup-linux-arm64-musl": "4.55.1", + "@rollup/rollup-linux-loong64-gnu": "4.55.1", + "@rollup/rollup-linux-loong64-musl": "4.55.1", + "@rollup/rollup-linux-ppc64-gnu": "4.55.1", + "@rollup/rollup-linux-ppc64-musl": "4.55.1", + "@rollup/rollup-linux-riscv64-gnu": "4.55.1", + "@rollup/rollup-linux-riscv64-musl": "4.55.1", + "@rollup/rollup-linux-s390x-gnu": "4.55.1", + "@rollup/rollup-linux-x64-gnu": "4.55.1", + "@rollup/rollup-linux-x64-musl": "4.55.1", + "@rollup/rollup-openbsd-x64": "4.55.1", + "@rollup/rollup-openharmony-arm64": "4.55.1", + "@rollup/rollup-win32-arm64-msvc": "4.55.1", + "@rollup/rollup-win32-ia32-msvc": "4.55.1", + "@rollup/rollup-win32-x64-gnu": "4.55.1", + "@rollup/rollup-win32-x64-msvc": "4.55.1", "fsevents": "~2.3.2" } }, @@ -5202,31 +5052,6 @@ "url": "https://github.com/sponsors/ljharb" } }, - "node_modules/swr": { - "version": "2.4.0", - "resolved": "https://registry.npmjs.org/swr/-/swr-2.4.0.tgz", - "integrity": "sha512-sUlC20T8EOt1pHmDiqueUWMmRRX03W7w5YxovWX7VR2KHEPCTMly85x05vpkP5i6Bu4h44ePSMD9Tc+G2MItFw==", - "license": "MIT", - "dependencies": { - "dequal": "^2.0.3", - "use-sync-external-store": "^1.6.0" - }, - "peerDependencies": { - "react": "^16.11.0 || ^17.0.0 || ^18.0.0 || ^19.0.0" - } - }, - "node_modules/throttleit": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/throttleit/-/throttleit-2.1.0.tgz", - "integrity": "sha512-nt6AMGKW1p/70DF/hGBdJB57B8Tspmbp5gfJ8ilhLnt7kkr2ye7hzD6NVG8GGErk2HWF34igrL2CXmNIkzKqKw==", - "license": "MIT", - "engines": { - "node": ">=18" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/tinyglobby": { "version": "0.2.15", "resolved": "https://registry.npmjs.org/tinyglobby/-/tinyglobby-0.2.15.tgz", @@ -5296,6 +5121,7 @@ "integrity": "sha512-hjcS1mhfuyi4WW8IWtjP7brDrG2cuDZukyrYrSauoXGNgx0S7zceP07adYkJycEr56BOUTNPzbInooiN3fn1qw==", "dev": true, "license": "Apache-2.0", + "peer": true, "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" @@ -5456,15 +5282,6 @@ "punycode": "^2.1.0" } }, - "node_modules/use-sync-external-store": { - "version": "1.6.0", - "resolved": "https://registry.npmjs.org/use-sync-external-store/-/use-sync-external-store-1.6.0.tgz", - "integrity": "sha512-Pp6GSwGP/NrPIrxVFAIkOQeyw8lFenOHijQWkUTrDvrF4ALqylP2C/KCkeS9dpUM3KvYRQhna5vt7IL95+ZQ9w==", - "license": "MIT", - "peerDependencies": { - "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0" - } - }, "node_modules/vfile": { "version": "6.0.3", "resolved": "https://registry.npmjs.org/vfile/-/vfile-6.0.3.tgz", @@ -5499,6 +5316,7 @@ "integrity": "sha512-o5a9xKjbtuhY6Bi5S3+HvbRERmouabWbyUcpXXUA1u+GNUKoROi9byOJ8M0nHbHYHkYICiMlqxkg1KkYmm25Sw==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "esbuild": "^0.21.3", "postcss": "^8.4.43", @@ -5587,9 +5405,9 @@ "license": "ISC" }, "node_modules/yaml": { - "version": "1.10.3", - "resolved": "https://registry.npmjs.org/yaml/-/yaml-1.10.3.tgz", - "integrity": "sha512-vIYeF1u3CjlhAFekPPAk2h/Kv4T3mAkMox5OymRiJQB0spDP10LHvt+K7G9Ny6NuuMAb25/6n1qyUjAcGNf/AA==", + "version": "1.10.2", + "resolved": "https://registry.npmjs.org/yaml/-/yaml-1.10.2.tgz", + "integrity": "sha512-r3vXyErRCYJ7wg28yvBY5VSoAF8ZvlcW9/BwUzEtUsjvX/DKs24dIkuwjtuprwJJHsbyUbLApepYTR1BN4uHrg==", "license": "ISC", "engines": { "node": ">= 6" @@ -5608,16 +5426,6 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/zod": { - "version": "4.3.6", - "resolved": "https://registry.npmjs.org/zod/-/zod-4.3.6.tgz", - "integrity": "sha512-rftlrkhHZOcjDwkGlnUtZZkvaPHCsDATp4pGpuOOMDaTdDDXF91wuVDJoWoPsKX/3YPQ5fHuF3STjcYyKr+Qhg==", - "license": "MIT", - "peer": true, - "funding": { - "url": "https://github.com/sponsors/colinhacks" - } - }, "node_modules/zustand": { "version": "5.0.10", "resolved": "https://registry.npmjs.org/zustand/-/zustand-5.0.10.tgz", diff --git a/frontend/package.json b/frontend/package.json index 9efe3dced3118cbf0976e413f376f1050f1b2853..553726bae62a96f8869c8bec29bf3fbad511bc0c 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -10,12 +10,10 @@ "preview": "vite preview" }, "dependencies": { - "@ai-sdk/react": "^3.0.93", "@emotion/react": "^11.13.0", "@emotion/styled": "^11.13.0", "@mui/icons-material": "^6.1.0", "@mui/material": "^6.1.0", - "ai": "^6.0.91", "react": "^18.3.1", "react-dom": "^18.3.1", "react-markdown": "^9.0.1", diff --git a/frontend/public/smolagents.webp b/frontend/public/smolagents.webp deleted file mode 100644 index 4be2c482082e0d2f08c88336d89a71f2b1f2f55e..0000000000000000000000000000000000000000 Binary files a/frontend/public/smolagents.webp and /dev/null differ diff --git a/frontend/src/components/Chat/ActivityStatusBar.tsx b/frontend/src/components/Chat/ActivityStatusBar.tsx deleted file mode 100644 index 3dd0af534ec1b7fa5861116b865b388f76a42eaa..0000000000000000000000000000000000000000 --- a/frontend/src/components/Chat/ActivityStatusBar.tsx +++ /dev/null @@ -1,146 +0,0 @@ -import { Box, Typography } from '@mui/material'; -import { keyframes } from '@mui/system'; -import { useAgentStore, type ActivityStatus } from '@/store/agentStore'; - -const shimmer = keyframes` - 0% { background-position: -100% center; } - 50% { background-position: 200% center; } - 100% { background-position: -100% center; } -`; - -const TOOL_LABELS: Record = { - sandbox_create: 'Creating sandbox for code development, this might take 1-2 minutes', - bash: 'Running command in sandbox', - hf_jobs: 'Running a GPU job, this might take a while', - hf_repo_files: 'Uploading file', - hf_repo_git: 'Git operation', - hf_inspect_dataset: 'Inspecting dataset', - hf_search: 'Searching', - plan_tool: 'Planning', - research: 'Researching', -}; - -/** Format raw research log into a clean status label. */ -function formatResearchStatus(raw: string): string { - const s = raw.replace(/^β–Έ\s*/, ''); - const jsonStart = s.indexOf('{'); - const toolName = jsonStart > 0 ? s.slice(0, jsonStart).trim() : s.trim(); - let args: Record = {}; - if (jsonStart > 0) { - const jsonStr = s.slice(jsonStart); - try { - const parsed = JSON.parse(jsonStr); - for (const [k, v] of Object.entries(parsed)) { - if (typeof v === 'string') args[k] = v; - } - } catch { - // JSON is likely truncated β€” extract complete "key": "value" pairs - for (const m of jsonStr.matchAll(/"(\w+)":\s*"([^"]*)"/g)) { - args[m[1]] = m[2]; - } - // Also try to extract a truncated value for known keys if not found yet - if (!args.query && !args.arxiv_id) { - const partial = jsonStr.match(/"(query|arxiv_id)":\s*"([^"]*)/); - if (partial) args[partial[1]] = partial[2]; - } - } - } - - if (toolName === 'github_find_examples') { - const d = (args.keyword) || (args.repo); - return d ? `Finding examples: ${d}` : 'Finding examples'; - } - if (toolName === 'github_read_file') { - const f = ((args.path) || '').split('/').pop(); - return f ? `Reading ${f}` : 'Reading file'; - } - if (toolName === 'explore_hf_docs') { - const d = (args.endpoint) || (args.query); - return d ? `Exploring docs: ${d}` : 'Exploring docs'; - } - if (toolName === 'fetch_hf_docs') { - const p = ((args.url) || '').split('/').pop()?.replace(/\.md$/, ''); - return p ? `Reading docs: ${p}` : 'Fetching docs'; - } - if (toolName === 'hf_inspect_dataset') { - const d = args.dataset as string; - return d ? `Inspecting dataset: ${d}` : 'Inspecting dataset'; - } - if (toolName === 'hf_papers') { - const op = args.operation as string; - const detail = (args.query) || (args.arxiv_id) || (args.positive_ids); - const opLabels: Record = { - trending: 'Browsing trending papers', - search: 'Searching papers', - paper_details: 'Reading paper details', - read_paper: 'Reading paper', - citation_graph: 'Tracing citations', - snippet_search: 'Searching paper passages', - recommend: 'Finding similar papers', - find_datasets: 'Finding paper datasets', - find_models: 'Finding paper models', - find_collections: 'Finding paper collections', - find_all_resources: 'Finding paper resources', - }; - const base = (op && opLabels[op]) || 'Searching papers'; - return detail ? `${base}: ${detail}` : base; - } - if (toolName === 'find_hf_api') { - const d = (args.query) || (args.tag); - return d ? `Finding API: ${d}` : 'Finding API endpoints'; - } - if (toolName === 'hf_repo_files') { - const d = (args.repo_id) || (args.repo); - return d ? `Reading ${d} files` : 'Reading repo files'; - } - return 'Researching'; -} - -function statusLabel(status: ActivityStatus): string { - switch (status.type) { - case 'thinking': return 'Thinking'; - case 'streaming': return 'Writing'; - case 'tool': { - if (status.toolName === 'research' && status.description) { - return formatResearchStatus(status.description); - } - const base = status.description || TOOL_LABELS[status.toolName] || `Running ${status.toolName}`; - if (status.toolName === 'bash' && status.description && /install/i.test(status.description)) { - return `${base} β€” this can take a few minutes, sit tight`; - } - return base; - } - case 'waiting-approval': return 'Waiting for approval'; - case 'cancelled': return 'What should the agent do instead?'; - default: return ''; - } -} - -export default function ActivityStatusBar() { - const activityStatus = useAgentStore(s => s.activityStatus); - - if (activityStatus.type === 'idle') return null; - - const label = statusLabel(activityStatus); - - return ( - - - {label}{activityStatus.type !== 'cancelled' && '…'} - - - ); -} diff --git a/frontend/src/components/Chat/AssistantMessage.tsx b/frontend/src/components/Chat/AssistantMessage.tsx index 91c7b8c1012bf1513ca141999d1acc7cfa23284f..9cd0d0597c8bd723300321587c11ce8ae4993822 100644 --- a/frontend/src/components/Chat/AssistantMessage.tsx +++ b/frontend/src/components/Chat/AssistantMessage.tsx @@ -1,91 +1,54 @@ -import { useMemo, useState } from 'react'; -import { Box, IconButton, Stack, Tooltip, Typography } from '@mui/material'; -import ThumbUpOutlined from '@mui/icons-material/ThumbUpOutlined'; -import ThumbUp from '@mui/icons-material/ThumbUp'; -import ThumbDownOutlined from '@mui/icons-material/ThumbDownOutlined'; -import ThumbDown from '@mui/icons-material/ThumbDown'; +import { Box, Stack, Typography } from '@mui/material'; import MarkdownContent from './MarkdownContent'; import ToolCallGroup from './ToolCallGroup'; -import { apiFetch } from '@/utils/api'; -import type { UIMessage } from 'ai'; -import type { MessageMeta } from '@/types/agent'; +import type { Message } from '@/types/agent'; interface AssistantMessageProps { - message: UIMessage; + message: Message; + /** True when this message is actively receiving streaming chunks. */ isStreaming?: boolean; - sessionId?: string | null; - approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null }>) => Promise; } -/** - * Groups consecutive tool parts together so they render as a single - * ToolCallGroup (visually identical to the old segments approach). - */ -type DynamicToolPart = Extract; - -function groupParts(parts: UIMessage['parts']) { - const groups: Array< - | { kind: 'text'; text: string; idx: number } - | { kind: 'tools'; tools: DynamicToolPart[]; idx: number } - > = []; - - for (let i = 0; i < parts.length; i++) { - const part = parts[i]; - - if (part.type === 'text') { - groups.push({ kind: 'text', text: part.text, idx: i }); - } else if (part.type === 'dynamic-tool') { - const toolPart = part as DynamicToolPart; - const last = groups[groups.length - 1]; - if (last?.kind === 'tools') { - last.tools.push(toolPart); - } else { - groups.push({ kind: 'tools', tools: [toolPart], idx: i }); +export default function AssistantMessage({ message, isStreaming = false }: AssistantMessageProps) { + const renderSegments = () => { + if (message.segments && message.segments.length > 0) { + // Find the index of the last text segment (that's the one being streamed) + let lastTextIdx = -1; + for (let i = message.segments.length - 1; i >= 0; i--) { + if (message.segments[i].type === 'text') { + lastTextIdx = i; + break; + } } - } - // step-start, step-end, etc. are ignored visually - } - - return groups; -} - -export default function AssistantMessage({ message, isStreaming = false, sessionId, approveTools }: AssistantMessageProps) { - const groups = useMemo(() => groupParts(message.parts), [message.parts]); - const [feedback, setFeedback] = useState<'up' | 'down' | null>(null); - const [feedbackBusy, setFeedbackBusy] = useState(false); - const sendFeedback = async (rating: 'up' | 'down') => { - if (!sessionId || feedbackBusy) return; - setFeedbackBusy(true); - // Optimistic toggle β€” feedback is observability, not a hard requirement. - setFeedback(rating); - try { - await apiFetch(`/api/feedback/${sessionId}`, { - method: 'POST', - body: JSON.stringify({ rating, message_id: message.id }), + return message.segments.map((segment, idx) => { + if (segment.type === 'text' && segment.content) { + return ( + + ); + } + if (segment.type === 'tools' && segment.tools && segment.tools.length > 0) { + return ; + } + return null; }); - } catch { - // Silently swallow β€” don't block chat UX on a telemetry write. - } finally { - setFeedbackBusy(false); } - }; - // Find the last text group index for streaming cursor - let lastTextIdx = -1; - for (let i = groups.length - 1; i >= 0; i--) { - if (groups[i].kind === 'text') { lastTextIdx = i; break; } - } - - const meta = message.metadata as MessageMeta | undefined; - const timeStr = meta?.createdAt - ? new Date(meta.createdAt).toLocaleTimeString([], { hour: '2-digit', minute: '2-digit' }) - : null; + // Fallback: render raw content + if (message.content) { + return ; + } - if (groups.length === 0) return null; + return null; + }; return ( + {/* Role label + timestamp */} Assistant - {timeStr && ( - - {timeStr} - - )} + + {new Date(message.timestamp).toLocaleTimeString([], { hour: '2-digit', minute: '2-digit' })} + + {/* Message bubble */} - {groups.map((group, i) => { - if (group.kind === 'text' && group.text) { - return ( - - ); - } - if (group.kind === 'tools' && group.tools.length > 0) { - return ( - - ); - } - return null; - })} + {renderSegments()} - {!isStreaming && sessionId && ( - - - sendFeedback('up')}> - {feedback === 'up' ? : } - - - - sendFeedback('down')}> - {feedback === 'down' ? : } - - - - )} ); } diff --git a/frontend/src/components/Chat/ChatInput.tsx b/frontend/src/components/Chat/ChatInput.tsx index 8a8810eac905be0639b450d11d35a8ca9c675252..2a1e75e6b857a4b790fa0eedf3738bf63095cc7c 100644 --- a/frontend/src/components/Chat/ChatInput.tsx +++ b/frontend/src/components/Chat/ChatInput.tsx @@ -1,34 +1,8 @@ import { useState, useCallback, useEffect, useRef, KeyboardEvent } from 'react'; -import { - Alert, - Box, - TextField, - IconButton, - CircularProgress, - Typography, - Menu, - MenuItem, - ListItemIcon, - ListItemText, - Chip, - Snackbar, -} from '@mui/material'; +import { Box, TextField, IconButton, CircularProgress, Typography, Menu, MenuItem, ListItemIcon, ListItemText, Chip } from '@mui/material'; import ArrowUpwardIcon from '@mui/icons-material/ArrowUpward'; import ArrowDropDownIcon from '@mui/icons-material/ArrowDropDown'; -import StopIcon from '@mui/icons-material/Stop'; import { apiFetch } from '@/utils/api'; -import { useUserQuota } from '@/hooks/useUserQuota'; -import ClaudeCapDialog from '@/components/ClaudeCapDialog'; -import JobsUpgradeDialog from '@/components/JobsUpgradeDialog'; -import { useAgentStore } from '@/store/agentStore'; -import { useSessionStore } from '@/store/sessionStore'; -import { - CLAUDE_MODEL_PATH, - FIRST_FREE_MODEL_PATH, - GPT_55_MODEL_PATH, - isClaudePath, - isPremiumPath, -} from '@/utils/model'; // Model configuration interface ModelOption { @@ -45,199 +19,83 @@ const getHfAvatarUrl = (modelId: string) => { return `https://huggingface.co/api/avatars/${org}`; }; -const DEFAULT_MODEL_OPTIONS: ModelOption[] = [ +const MODEL_OPTIONS: ModelOption[] = [ { - id: 'kimi-k2.6', - name: 'Kimi K2.6', - description: 'Novita', - modelPath: 'moonshotai/Kimi-K2.6', - avatarUrl: getHfAvatarUrl('moonshotai/Kimi-K2.6'), + id: 'minimax-m2.1', + name: 'MiniMax M2.1', + description: 'Via Novita', + modelPath: 'huggingface/novita/MiniMaxAI/MiniMax-M2.1', + avatarUrl: getHfAvatarUrl('MiniMaxAI/MiniMax-M2.1'), recommended: true, }, { id: 'claude-opus', - name: 'Claude Opus 4.6', + name: 'Claude Opus 4.5', description: 'Anthropic', - modelPath: CLAUDE_MODEL_PATH, + modelPath: 'anthropic/claude-opus-4-5-20251101', avatarUrl: 'https://huggingface.co/api/avatars/Anthropic', recommended: true, }, { - id: 'gpt-5.5', - name: 'GPT-5.5', - description: 'OpenAI', - modelPath: GPT_55_MODEL_PATH, - avatarUrl: 'https://huggingface.co/api/avatars/openai', + id: 'kimi-k2.5', + name: 'Kimi K2.5', + description: 'Via Novita', + modelPath: 'huggingface/novita/moonshotai/Kimi-K2.5', + avatarUrl: getHfAvatarUrl('moonshotai/Kimi-K2.5'), }, { - id: 'minimax-m2.7', - name: 'MiniMax M2.7', - description: 'Novita', - modelPath: 'MiniMaxAI/MiniMax-M2.7', - avatarUrl: getHfAvatarUrl('MiniMaxAI/MiniMax-M2.7'), - }, - { - id: 'glm-5.1', - name: 'GLM 5.1', - description: 'Together', - modelPath: 'zai-org/GLM-5.1', - avatarUrl: getHfAvatarUrl('zai-org/GLM-5.1'), - }, - { - id: 'deepseek-v4-pro', - name: 'DeepSeek V4 Pro', - description: 'DeepInfra', - modelPath: 'deepseek-ai/DeepSeek-V4-Pro:deepinfra', - avatarUrl: getHfAvatarUrl('deepseek-ai/DeepSeek-V4-Pro'), + id: 'glm-5', + name: 'GLM 5', + description: 'Via Novita', + modelPath: 'huggingface/novita/zai-org/GLM-5', + avatarUrl: getHfAvatarUrl('zai-org/GLM-5'), }, ]; -const findModelByPath = (path: string, options: ModelOption[]): ModelOption | undefined => { - if (isClaudePath(path)) { - const claude = options.find(isClaudeModel); - if (claude) return claude; - } - return options.find(m => m.modelPath === path || path?.includes(m.id)); -}; - -const readApiErrorMessage = async (res: Response, fallback: string): Promise => { - try { - const data = await res.json(); - const detail = data?.detail; - if (typeof detail === 'string') return detail; - if (detail && typeof detail.message === 'string') return detail.message; - if (detail && typeof detail.error === 'string') return detail.error; - } catch { - /* ignore malformed error bodies */ - } - return fallback; +const findModelByPath = (path: string): ModelOption | undefined => { + return MODEL_OPTIONS.find(m => m.modelPath === path || path?.includes(m.id)); }; interface ChatInputProps { - sessionId?: string; - initialModelPath?: string | null; onSend: (text: string) => void; - onStop?: () => void; - isProcessing?: boolean; disabled?: boolean; - placeholder?: string; } -const isClaudeModel = (m: ModelOption) => isClaudePath(m.modelPath); -const isPremiumModel = (m: ModelOption) => isPremiumPath(m.modelPath); -const firstFreeModel = (options: ModelOption[]) => options.find(m => !isPremiumModel(m)) ?? options[0]; - -export default function ChatInput({ sessionId, initialModelPath, onSend, onStop, isProcessing = false, disabled = false, placeholder = 'Ask anything...' }: ChatInputProps) { +export default function ChatInput({ onSend, disabled = false }: ChatInputProps) { const [input, setInput] = useState(''); const inputRef = useRef(null); - const [modelOptions, setModelOptions] = useState(DEFAULT_MODEL_OPTIONS); - const modelOptionsRef = useRef(DEFAULT_MODEL_OPTIONS); - const sessionIdRef = useRef(sessionId); - const [selectedModelId, setSelectedModelId] = useState( - () => findModelByPath(initialModelPath ?? '', DEFAULT_MODEL_OPTIONS)?.id ?? DEFAULT_MODEL_OPTIONS[0].id, - ); + const [selectedModelId, setSelectedModelId] = useState(MODEL_OPTIONS[0].id); const [modelAnchorEl, setModelAnchorEl] = useState(null); - const { quota, refresh: refreshQuota } = useUserQuota(); - // The daily-cap dialog is triggered from two places: (a) a 429 returned - // from the chat transport when the user tries to send on a premium model over cap β€” - // surfaced via the agent-store flag β€” and (b) nothing else right now - // (switching models is free). Keeping the open state in the store means - // the hook layer can flip it without threading props through. - const claudeQuotaExhausted = useAgentStore((s) => s.claudeQuotaExhausted); - const setClaudeQuotaExhausted = useAgentStore((s) => s.setClaudeQuotaExhausted); - const jobsUpgradeRequired = useAgentStore((s) => s.jobsUpgradeRequired); - const setJobsUpgradeRequired = useAgentStore((s) => s.setJobsUpgradeRequired); - const updateSessionModel = useSessionStore((s) => s.updateSessionModel); - const [awaitingTopUp, setAwaitingTopUp] = useState(false); - const [modelSwitchError, setModelSwitchError] = useState(null); - const lastSentRef = useRef(''); - - useEffect(() => { - modelOptionsRef.current = modelOptions; - }, [modelOptions]); - - useEffect(() => { - sessionIdRef.current = sessionId; - }, [sessionId]); - - useEffect(() => { - let cancelled = false; - apiFetch('/api/config/model') - .then((res) => (res.ok ? res.json() : null)) - .then((data) => { - if (cancelled || !data?.available) return; - const claude = data.available.find((m: { provider?: string; id?: string }) => ( - m.provider === 'anthropic' && m.id - )); - if (!claude?.id) return; - - const next = DEFAULT_MODEL_OPTIONS.map((option) => ( - isClaudeModel(option) - ? { ...option, modelPath: claude.id, name: claude.label ?? option.name } - : option - )); - modelOptionsRef.current = next; - setModelOptions(next); - if (!sessionIdRef.current) { - const current = data.current ? findModelByPath(data.current, next) : null; - if (current) setSelectedModelId(current.id); - } - }) - .catch(() => { /* ignore */ }); - return () => { cancelled = true; }; - }, []); - // Model is per-session: fetch this tab's current model every time the - // session changes. Other tabs keep their own selections independently. + // Sync with backend on mount useEffect(() => { - if (!sessionId) return; - let cancelled = false; - apiFetch(`/api/session/${sessionId}`) + fetch('/api/config/model') .then((res) => (res.ok ? res.json() : null)) .then((data) => { - if (cancelled) return; - if (data?.model) { - const model = findModelByPath(data.model, modelOptionsRef.current); + if (data?.current) { + const model = findModelByPath(data.current); if (model) setSelectedModelId(model.id); - updateSessionModel(sessionId, data.model); } }) .catch(() => { /* ignore */ }); - return () => { cancelled = true; }; - }, [sessionId, updateSessionModel]); + }, []); - const selectedModel = modelOptions.find(m => m.id === selectedModelId) || modelOptions[0]; + const selectedModel = MODEL_OPTIONS.find(m => m.id === selectedModelId) || MODEL_OPTIONS[0]; - // Auto-focus the textarea when the session becomes ready + // Auto-focus the textarea when the session becomes ready (disabled -> false) useEffect(() => { - if (!disabled && !isProcessing && inputRef.current) { + if (!disabled && inputRef.current) { inputRef.current.focus(); } - }, [disabled, isProcessing]); + }, [disabled]); const handleSend = useCallback(() => { if (input.trim() && !disabled) { - lastSentRef.current = input; onSend(input); setInput(''); } }, [input, disabled, onSend]); - // When the chat transport reports a premium-model quota 429, restore the typed - // text so the user doesn't lose their message. - useEffect(() => { - if (claudeQuotaExhausted && lastSentRef.current) { - setInput(lastSentRef.current); - } - }, [claudeQuotaExhausted]); - - // Refresh the quota display whenever the session changes (user might - // have started another tab that spent quota). - useEffect(() => { - if (sessionId) refreshQuota(); - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [sessionId]); - const handleKeyDown = useCallback( (e: KeyboardEvent) => { if (e.key === 'Enter' && !e.shiftKey) { @@ -258,116 +116,16 @@ export default function ChatInput({ sessionId, initialModelPath, onSend, onStop, const handleSelectModel = async (model: ModelOption) => { handleModelClose(); - if (!sessionId) return; try { - const res = await apiFetch(`/api/session/${sessionId}/model`, { + const res = await apiFetch('/api/config/model', { method: 'POST', body: JSON.stringify({ model: model.modelPath }), }); if (res.ok) { setSelectedModelId(model.id); - updateSessionModel(sessionId, model.modelPath); - setModelSwitchError(null); - return; - } - setModelSwitchError(await readApiErrorMessage(res, 'Could not switch model.')); - } catch (error) { - setModelSwitchError(error instanceof Error ? error.message : 'Could not switch model.'); - } - }; - - // Dialog close: just clear the flag. The typed text is already restored. - const handleCapDialogClose = useCallback(() => { - setClaudeQuotaExhausted(false); - }, [setClaudeQuotaExhausted]); - - // "Use a free model" β€” switch the current session to Kimi (or the first - // non-premium option) and auto-retry the send that tripped the cap. - const handleUseFreeModel = useCallback(async () => { - setClaudeQuotaExhausted(false); - if (!sessionId) return; - const free = modelOptions.find(m => m.modelPath === FIRST_FREE_MODEL_PATH) - ?? firstFreeModel(modelOptions); - try { - const res = await apiFetch(`/api/session/${sessionId}/model`, { - method: 'POST', - body: JSON.stringify({ model: free.modelPath }), - }); - if (res.ok) { - setSelectedModelId(free.id); - updateSessionModel(sessionId, free.modelPath); - const retryText = lastSentRef.current; - if (retryText) { - onSend(retryText); - setInput(''); - lastSentRef.current = ''; - } } } catch { /* ignore */ } - }, [sessionId, onSend, setClaudeQuotaExhausted, modelOptions, updateSessionModel]); - - const handlePremiumUpgradeClick = useCallback(async () => { - if (!sessionId) return; - try { - await apiFetch(`/api/pro-click/${sessionId}`, { - method: 'POST', - body: JSON.stringify({ source: 'premium_cap_dialog', target: 'pro_pricing' }), - }); - } catch { - /* tracking is best-effort */ - } - }, [sessionId]); - - const handleJobsUpgradeClose = useCallback(() => { - setJobsUpgradeRequired(null); - setAwaitingTopUp(false); - }, [setJobsUpgradeRequired]); - - const handleJobsUpgradeClick = useCallback(async () => { - setAwaitingTopUp(true); - if (!sessionId || !jobsUpgradeRequired) return; - try { - await apiFetch(`/api/pro-click/${sessionId}`, { - method: 'POST', - body: JSON.stringify({ source: 'hf_jobs_billing_dialog', target: 'hf_billing' }), - }); - } catch { - /* tracking is best-effort */ - } - }, [sessionId, jobsUpgradeRequired]); - - const handleJobsRetry = useCallback(() => { - const namespace = jobsUpgradeRequired?.namespace; - setJobsUpgradeRequired(null); - setAwaitingTopUp(false); - const msg = namespace - ? `I just added credits to the \`${namespace}\` namespace. Please retry the previous job.` - : "I just added credits. Please retry the previous job."; - onSend(msg); - }, [jobsUpgradeRequired, setJobsUpgradeRequired, onSend]); - - // Auto-retry when the user comes back to this tab after clicking "Add credits". - // Browsers fire visibilitychange when the tab regains focus from a sibling tab. - useEffect(() => { - if (!awaitingTopUp || !jobsUpgradeRequired) return; - const onVisible = () => { - if (document.visibilityState === 'visible') { - handleJobsRetry(); - } - }; - document.addEventListener('visibilitychange', onVisible); - return () => document.removeEventListener('visibilitychange', onVisible); - }, [awaitingTopUp, jobsUpgradeRequired, handleJobsRetry]); - - // Hide the chip until the user has actually burned quota; opening a - // premium-model session without sending should not populate a counter. - const premiumChip = (() => { - if (!quota || quota.premiumUsedToday === 0) return null; - if (quota.plan === 'free') { - return quota.premiumRemaining > 0 ? 'Free today' : 'Pro only'; - } - return `${quota.premiumUsedToday}/${quota.premiumDailyCap} today`; - })(); + }; return ( setInput(e.target.value)} onKeyDown={handleKeyDown} - placeholder={placeholder} - disabled={disabled || isProcessing} + placeholder="Ask anything..." + disabled={disabled} variant="standard" inputRef={inputRef} InputProps={{ @@ -431,49 +189,26 @@ export default function ChatInput({ sessionId, initialModelPath, onSend, onStop, } }} /> - {isProcessing ? ( - - - - - - - ) : ( - - - - )} + + {disabled ? : } + {/* Powered By Badge */} @@ -531,7 +266,7 @@ export default function ChatInput({ sessionId, initialModelPath, onSend, onStop, } }} > - {modelOptions.map((model) => ( + {MODEL_OPTIONS.map((model) => ( handleSelectModel(model)} @@ -567,19 +302,6 @@ export default function ChatInput({ sessionId, initialModelPath, onSend, onStop, }} /> )} - {isPremiumModel(model) && premiumChip && ( - - )} } secondary={model.description} @@ -590,38 +312,6 @@ export default function ChatInput({ sessionId, initialModelPath, onSend, onStop, ))} - - - - setModelSwitchError(null)} - autoHideDuration={6000} - > - setModelSwitchError(null)} - sx={{ fontSize: '0.8rem', maxWidth: 480 }} - > - {modelSwitchError} - - ); diff --git a/frontend/src/components/Chat/ExpiredBanner.tsx b/frontend/src/components/Chat/ExpiredBanner.tsx deleted file mode 100644 index 32f638c245089fc1d26561c8f2706609a5a48345..0000000000000000000000000000000000000000 --- a/frontend/src/components/Chat/ExpiredBanner.tsx +++ /dev/null @@ -1,114 +0,0 @@ -/** - * Shown inline in a chat when the backend no longer recognizes the - * session id (typically: Space was restarted). Lets the user catch the - * agent up with a summary of the prior conversation, or start over. - */ -import { useState, useCallback } from 'react'; -import { Box, Button, CircularProgress, Typography } from '@mui/material'; -import { apiFetch } from '@/utils/api'; -import { useSessionStore } from '@/store/sessionStore'; -import { useAgentStore } from '@/store/agentStore'; -import { loadBackendMessages } from '@/lib/backend-message-store'; -import { loadMessages } from '@/lib/chat-message-store'; -import { uiMessagesToLLMMessages } from '@/lib/convert-llm-messages'; -import { logger } from '@/utils/logger'; - -interface Props { - sessionId: string; -} - -export default function ExpiredBanner({ sessionId }: Props) { - const { renameSession, deleteSession, updateSessionModel } = useSessionStore(); - const [busy, setBusy] = useState<'catch-up' | 'start-over' | null>(null); - const [error, setError] = useState(null); - - const handleCatchUp = useCallback(async () => { - setBusy('catch-up'); - setError(null); - try { - // Prefer the raw backend-message cache; fall back to reconstructing - // from UIMessages (for sessions that predate the backend cache). - let messages = loadBackendMessages(sessionId); - if (!messages || messages.length === 0) { - const uiMsgs = loadMessages(sessionId); - if (uiMsgs.length > 0) messages = uiMessagesToLLMMessages(uiMsgs); - } - if (!messages || messages.length === 0) { - setError('Nothing to summarize from this chat.'); - setBusy(null); - return; - } - - const res = await apiFetch('/api/session/restore-summary', { - method: 'POST', - body: JSON.stringify({ messages }), - }); - if (!res.ok) throw new Error(`restore-summary failed: ${res.status}`); - const data = await res.json(); - const newId = data.session_id as string | undefined; - if (!newId) throw new Error('no session_id in response'); - - useAgentStore.getState().clearSessionState(sessionId); - renameSession(sessionId, newId); - if (data.model) updateSessionModel(newId, data.model); - } catch (e) { - logger.warn('Catch-up failed:', e); - setError("Couldn't catch up β€” try starting over."); - setBusy(null); - } - }, [sessionId, renameSession, updateSessionModel]); - - const handleStartOver = useCallback(() => { - setBusy('start-over'); - useAgentStore.getState().clearSessionState(sessionId); - deleteSession(sessionId); - }, [sessionId, deleteSession]); - - return ( - - - Where were we? - - - Let me skim the conversation so far and pick up right where we left - off β€” or we can start something new. - - - - - - {error && ( - - {error} - - )} - - ); -} diff --git a/frontend/src/components/Chat/MarkdownContent.tsx b/frontend/src/components/Chat/MarkdownContent.tsx index 0d1e69171d3955e998d78807006862bd95422c34..beb682720bf2b4d846b67a86d45607bc4544044b 100644 --- a/frontend/src/components/Chat/MarkdownContent.tsx +++ b/frontend/src/components/Chat/MarkdownContent.tsx @@ -1,4 +1,4 @@ -import { useMemo, useRef, useState, useEffect, type ComponentPropsWithoutRef } from 'react'; +import { useMemo, useRef, useState, useEffect } from 'react'; import { Box } from '@mui/material'; import ReactMarkdown from 'react-markdown'; import remarkGfm from 'remark-gfm'; @@ -70,30 +70,16 @@ const markdownSx: SxProps = { width: '100%', my: 2, fontSize: '0.85rem', - display: 'block', - overflowX: 'auto', - WebkitOverflowScrolling: 'touch', - }, - '& thead': { - position: 'sticky', - top: 0, }, '& th': { borderBottom: '2px solid var(--border-hover)', - bgcolor: 'var(--hover-bg)', textAlign: 'left', - px: 1.5, - py: 0.75, + p: 1, fontWeight: 600, - whiteSpace: 'nowrap', }, '& td': { borderBottom: '1px solid var(--tool-border)', - px: 1.5, - py: 0.75, - }, - '& tr:nth-of-type(even) td': { - bgcolor: 'color-mix(in srgb, var(--hover-bg) 50%, transparent)', + p: 1, }, '& hr': { @@ -166,17 +152,9 @@ export default function MarkdownContent({ content, sx, isStreaming = false }: Ma const remarkPlugins = useMemo(() => [remarkGfm], []); - const components = useMemo(() => ({ - a: ({ href, children, ...props }: ComponentPropsWithoutRef<'a'>) => ( - - {children} - - ), - }), []); - return ( - {displayContent} + {displayContent} ); } diff --git a/frontend/src/components/Chat/MessageBubble.tsx b/frontend/src/components/Chat/MessageBubble.tsx index ab971205c18a9c60bb23b398e83cf1090dcd5116..d7d36330bd762d41b267323b6b3f79242e4feef5 100644 --- a/frontend/src/components/Chat/MessageBubble.tsx +++ b/frontend/src/components/Chat/MessageBubble.tsx @@ -1,50 +1,51 @@ import UserMessage from './UserMessage'; import AssistantMessage from './AssistantMessage'; -import type { UIMessage } from 'ai'; +import type { Message } from '@/types/agent'; interface MessageBubbleProps { - message: UIMessage; + message: Message; + /** True if this is the user message that starts the last turn. */ isLastTurn?: boolean; + /** Callback to undo (remove) the last turn. */ onUndoTurn?: () => void; - onEditAndRegenerate?: (messageId: string, newText: string) => void | Promise; + /** Whether the agent is currently processing. */ isProcessing?: boolean; + /** True when this message is actively receiving streaming chunks. */ isStreaming?: boolean; - sessionId?: string | null; - approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null }>) => Promise; } +/** + * Thin dispatcher β€” routes each message to the correct + * specialised component based on its role / content. + */ export default function MessageBubble({ message, isLastTurn = false, onUndoTurn, - onEditAndRegenerate, isProcessing = false, isStreaming = false, - sessionId, - approveTools, }: MessageBubbleProps) { + // Legacy approval-only messages (from old localStorage data) β€” skip them. + // Approvals are now rendered inline within ToolCallGroup. + if (message.approval && !message.content && !message.segments?.length) { + return null; + } + if (message.role === 'user') { return ( ); } if (message.role === 'assistant') { - return ( - - ); + return ; } + // Fallback (tool messages, etc.) return null; } diff --git a/frontend/src/components/Chat/MessageList.tsx b/frontend/src/components/Chat/MessageList.tsx index 5e3efcaea901bf97970f7644fae162046e3382b2..ca1201303490a52568f95e7b298412611cae76f9 100644 --- a/frontend/src/components/Chat/MessageList.tsx +++ b/frontend/src/components/Chat/MessageList.tsx @@ -1,17 +1,16 @@ -import { useCallback, useEffect, useRef, useMemo } from 'react'; +import { useEffect, useRef, useMemo, useCallback } from 'react'; import { Box, Stack, Typography } from '@mui/material'; import MessageBubble from './MessageBubble'; -import ActivityStatusBar from './ActivityStatusBar'; +import ThinkingIndicator from './ThinkingIndicator'; import { useAgentStore } from '@/store/agentStore'; -import type { UIMessage } from 'ai'; +import { useSessionStore } from '@/store/sessionStore'; +import { apiFetch } from '@/utils/api'; +import { logger } from '@/utils/logger'; +import type { Message } from '@/types/agent'; interface MessageListProps { - messages: UIMessage[]; + messages: Message[]; isProcessing: boolean; - sessionId?: string | null; - approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null }>) => Promise; - onUndoLastTurn: () => void | Promise; - onEditAndRegenerate?: (messageId: string, newText: string) => void | Promise; } function getGreeting(): string { @@ -21,6 +20,7 @@ function getGreeting(): string { return 'Evening'; } +/** Minimal greeting shown when the conversation is empty. */ function WelcomeGreeting() { const { user } = useAgentStore(); const firstName = user?.name?.split(' ')[0] || user?.username; @@ -58,40 +58,58 @@ function WelcomeGreeting() { ); } -export default function MessageList({ messages, isProcessing, sessionId, approveTools, onUndoLastTurn, onEditAndRegenerate }: MessageListProps) { +export default function MessageList({ messages, isProcessing }: MessageListProps) { const scrollContainerRef = useRef(null); const stickToBottom = useRef(true); + const { activeSessionId } = useSessionStore(); + const { removeLastTurn, currentTurnMessageId } = useAgentStore(); + // ── Scroll-to-bottom helper ───────────────────────────────────── const scrollToBottom = useCallback(() => { const el = scrollContainerRef.current; if (el) el.scrollTop = el.scrollHeight; }, []); + // ── Track user scroll intent ──────────────────────────────────── useEffect(() => { const el = scrollContainerRef.current; if (!el) return; + const onScroll = () => { const distFromBottom = el.scrollHeight - el.scrollTop - el.clientHeight; stickToBottom.current = distFromBottom < 80; }; + el.addEventListener('scroll', onScroll, { passive: true }); return () => el.removeEventListener('scroll', onScroll); }, []); + // ── Auto-scroll on new messages / state changes ───────────────── useEffect(() => { if (stickToBottom.current) scrollToBottom(); }, [messages, isProcessing, scrollToBottom]); + // ── Auto-scroll on DOM mutations (streaming content growth) ───── useEffect(() => { const el = scrollContainerRef.current; if (!el) return; + const observer = new MutationObserver(() => { - if (stickToBottom.current) el.scrollTop = el.scrollHeight; + if (stickToBottom.current) { + el.scrollTop = el.scrollHeight; + } + }); + + observer.observe(el, { + childList: true, + subtree: true, + characterData: true, }); - observer.observe(el, { childList: true, subtree: true, characterData: true }); + return () => observer.disconnect(); }, []); + // Find the index of the last user message (start of the last turn) const lastUserMsgId = useMemo(() => { for (let i = messages.length - 1; i >= 0; i--) { if (messages[i].role === 'user') return messages[i].id; @@ -99,13 +117,15 @@ export default function MessageList({ messages, isProcessing, sessionId, approve return null; }, [messages]); - // The last assistant message is "streaming" when we're processing - const lastAssistantId = useMemo(() => { - for (let i = messages.length - 1; i >= 0; i--) { - if (messages[i].role === 'assistant') return messages[i].id; + const handleUndoLastTurn = useCallback(async () => { + if (!activeSessionId) return; + try { + await apiFetch(`/api/undo/${activeSessionId}`, { method: 'POST' }); + removeLastTurn(activeSessionId); + } catch (e) { + logger.error('Undo failed:', e); } - return null; - }, [messages]); + }, [activeSessionId, removeLastTurn]); return ( )) )} - + {/* Show thinking dots only when processing but no streaming message yet */} + {isProcessing && !currentTurnMessageId && } + {/* Sentinel β€” keeps scroll anchor at the bottom */}

diff --git a/frontend/src/components/Chat/ToolCallGroup.tsx b/frontend/src/components/Chat/ToolCallGroup.tsx index 9f09b6b96e9c9775f33b67583b019a52645d0492..caa31e85c40061a37a4acd97936c740c80948e74 100644 --- a/frontend/src/components/Chat/ToolCallGroup.tsx +++ b/frontend/src/components/Chat/ToolCallGroup.tsx @@ -1,642 +1,113 @@ -import { useCallback, useEffect, useMemo, useRef, useState } from 'react'; -import { Alert, Box, Stack, Typography, Chip, Button, TextField, IconButton, Link, CircularProgress } from '@mui/material'; +import { useCallback, useState } from 'react'; +import { Box, Stack, Typography, Chip, Button, TextField, IconButton, Link } from '@mui/material'; import CheckCircleOutlineIcon from '@mui/icons-material/CheckCircleOutline'; import ErrorOutlineIcon from '@mui/icons-material/ErrorOutline'; +import MoreHorizIcon from '@mui/icons-material/MoreHoriz'; import OpenInNewIcon from '@mui/icons-material/OpenInNew'; import HourglassEmptyIcon from '@mui/icons-material/HourglassEmpty'; import LaunchIcon from '@mui/icons-material/Launch'; import SendIcon from '@mui/icons-material/Send'; -import BlockIcon from '@mui/icons-material/Block'; -import { useAgentStore, type ResearchAgentState } from '@/store/agentStore'; +import { useAgentStore } from '@/store/agentStore'; import { useLayoutStore } from '@/store/layoutStore'; +import { useSessionStore } from '@/store/sessionStore'; +import { apiFetch } from '@/utils/api'; import { logger } from '@/utils/logger'; -import { RESEARCH_MAX_STEPS } from '@/lib/research-store'; -import type { UIMessage } from 'ai'; - -// --------------------------------------------------------------------------- -// Type helpers β€” extract the dynamic-tool part type from UIMessage -// --------------------------------------------------------------------------- -type DynamicToolPart = Extract; - -type ToolPartState = DynamicToolPart['state']; - -/** Check if a tool part was cancelled (output-error with cancellation message). */ -function isCancelledTool(tool: DynamicToolPart): boolean { - return tool.state === 'output-error' && - typeof (tool as Record).errorText === 'string' && - ((tool as Record).errorText as string).includes('Cancelled by user'); -} +import type { TraceLog } from '@/types/agent'; interface ToolCallGroupProps { - tools: DynamicToolPart[]; - approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null; edited_script?: string | null }>) => Promise; -} - -// --------------------------------------------------------------------------- -// Research sub-steps (inline under the research tool row) -// --------------------------------------------------------------------------- - -/** Hook that forces a re-render every second while enabled β€” used so each - * research card can compute its own elapsed seconds synchronously from - * Date.now() without needing its own timer. */ -function useSecondTick(enabled: boolean): void { - const [, setTick] = useState(0); - useEffect(() => { - if (!enabled) return; - const id = setInterval(() => setTick(t => t + 1), 1000); - return () => clearInterval(id); - }, [enabled]); -} - -/** Compute elapsed seconds from startedAt (or null). Call under useSecondTick. */ -function computeElapsed(startedAt: number | null): number | null { - if (startedAt === null) return null; - return Math.round((Date.now() - startedAt) / 1000); -} - -/** Format token count like the CLI: "12.4k" or "800". */ -function formatTokens(tokens: number): string { - return tokens >= 1000 ? `${(tokens / 1000).toFixed(1)}k` : String(tokens); -} - -/** Format elapsed seconds like the CLI: "18s" or "2m 5s". */ -function formatElapsed(seconds: number): string { - if (seconds < 60) return `${seconds}s`; - return `${Math.floor(seconds / 60)}m ${seconds % 60}s`; + tools: TraceLog[]; } -/** Build the research stats chip label. */ -function researchChipLabel( - stats: { toolCount: number; tokenCount: number; startedAt: number | null; finalElapsed: number | null }, - liveElapsed: number | null, -): string | null { - const elapsed = stats.finalElapsed ?? liveElapsed; - if (elapsed === null && stats.toolCount === 0) return null; - const parts: string[] = []; - if (stats.startedAt !== null) parts.push('running'); - if (stats.toolCount > 0) parts.push(`${stats.toolCount} tools`); - if (stats.tokenCount > 0) parts.push(`${formatTokens(stats.tokenCount)} tokens`); - if (elapsed !== null) parts.push(formatElapsed(elapsed)); - return parts.join(' \u00B7 '); +/** Check if a running tool has been stuck for too long (5 minutes). */ +const TOOL_TIMEOUT_MS = 5 * 60 * 1000; +function isTimedOut(log: TraceLog): boolean { + if (log.completed || log.approvalStatus === 'pending') return false; + const elapsed = Date.now() - new Date(log.timestamp).getTime(); + return elapsed > TOOL_TIMEOUT_MS; } -/** Parse JSON args from a step string like "tool_name {json}" (may be truncated at 80 chars). */ -function parseStepArgs(step: string): Record { - const jsonStart = step.indexOf('{'); - if (jsonStart < 0) return {}; - const jsonStr = step.slice(jsonStart); - try { - const parsed = JSON.parse(jsonStr); - const result: Record = {}; - for (const [k, v] of Object.entries(parsed)) { - if (typeof v === 'string') result[k] = v; - } - return result; - } catch { - // JSON likely truncated β€” extract key-value pairs via regex - const result: Record = {}; - // Match complete "key": "value" pairs - for (const m of jsonStr.matchAll(/"(\w+)":\s*"([^"]*)"/g)) { - result[m[1]] = m[2]; - } - // Match truncated trailing value: "key": "value... (no closing quote) - if (Object.keys(result).length === 0 || !result.query) { - const trunc = jsonStr.match(/"(\w+)":\s*"([^"]+)$/); - if (trunc && !result[trunc[1]]) { - result[trunc[1]] = trunc[2]; - } - } - return result; - } -} - -/** Pretty labels for research sub-agent tool calls */ -function formatResearchStep(raw: string): { label: string } { - // Backend sends logs like "β–Έ tool_name {args}" β€” strip the prefix - const step = raw.replace(/^β–Έ\s*/, ''); - const args = parseStepArgs(step); - - if (step.startsWith('github_find_examples')) { - const detail = (args.keyword) || (args.repo); - return { label: detail ? `Finding examples: ${detail}` : 'Finding examples' }; - } - if (step.startsWith('github_read_file')) { - const path = (args.path) || ''; - const filename = path.split('/').pop() || path; - return { label: filename ? `Reading ${filename}` : 'Reading file' }; - } - if (step.startsWith('explore_hf_docs')) { - const endpoint = (args.endpoint) || (args.query); - return { label: endpoint ? `Exploring docs: ${endpoint}` : 'Exploring docs' }; +// ── Status icon based on tool state ───────────────────────────────── +function StatusIcon({ log }: { log: TraceLog }) { + // Awaiting approval + if (log.approvalStatus === 'pending') { + return ; } - if (step.startsWith('fetch_hf_docs')) { - const url = (args.url) || ''; - const page = url.split('/').pop()?.replace(/\.md$/, ''); - return { label: page ? `Reading docs: ${page}` : 'Fetching docs' }; + // Rejected + if (log.approvalStatus === 'rejected') { + return ; } - if (step.startsWith('hf_inspect_dataset')) { - const dataset = (args.dataset); - return { label: dataset ? `Inspecting dataset: ${dataset}` : 'Inspecting dataset' }; + // Timed out + if (isTimedOut(log)) { + return ; } - if (step.startsWith('hf_papers')) { - const op = args.operation as string; - const detail = (args.query) || (args.arxiv_id); - const opLabels: Record = { - trending: 'Browsing trending papers', - search: 'Searching papers', - paper_details: 'Reading paper details', - read_paper: 'Reading paper', - citation_graph: 'Tracing citations', - snippet_search: 'Searching paper snippets', - recommend: 'Finding related papers', - find_datasets: 'Finding paper datasets', - find_models: 'Finding paper models', - find_collections: 'Finding paper collections', - find_all_resources: 'Finding paper resources', - }; - const base = (op && opLabels[op]) || 'Searching papers'; - return { label: detail ? `${base}: ${detail}` : base }; - } - if (step.startsWith('find_hf_api')) { - const detail = (args.query) || (args.tag); - return { label: detail ? `Finding API: ${detail}` : 'Finding API endpoints' }; - } - if (step.startsWith('hf_repo_files')) { - const repo = (args.repo_id) || (args.repo); - return { label: repo ? `Reading ${repo} files` : 'Reading repo files' }; - } - if (step.startsWith('read')) { - const path = (args.path) || ''; - const filename = path.split('/').pop(); - return { label: filename ? `Reading ${filename}` : 'Reading file' }; - } - if (step.startsWith('bash')) { - const cmd = args.command as string; - const short = cmd && cmd.length > 40 ? cmd.slice(0, 40) + '...' : cmd; - return { label: short ? `Running: ${short}` : 'Running command' }; - } - return { label: step.replace(/^β–Έ\s*/, '') }; -} - -/** Rolling display of research sub-tool calls for a single agent. */ -function ResearchSteps({ steps }: { steps: string[] }) { - const visible = steps.slice(-RESEARCH_MAX_STEPS); - if (visible.length === 0) return null; - - return ( - - {visible.map((step, i) => { - const { label } = formatResearchStep(step); - const isLast = i === visible.length - 1; - return ( - - {isLast ? ( - - ) : ( - - )} - - {label} - - - ); - })} - - ); -} - -// --------------------------------------------------------------------------- -// Trackio dashboard embed -// --------------------------------------------------------------------------- - -// HF repo IDs are `/` where each segment is alphanumerics plus -// `_`, `.`, `-`. Anything else (slashes, spaces, query params, missing owner) -// would let an attacker-controlled string redirect the embed to a different -// Space, so we refuse to render rather than build a malformed URL. -const SPACE_ID_PATTERN = /^[a-zA-Z0-9_.-]+\/[a-zA-Z0-9_.-]+$/; - -function isValidSpaceId(spaceId: string): boolean { - return SPACE_ID_PATTERN.test(spaceId); -} - -/** HF Space embed subdomain: 'user/space_name' β†’ 'user-space-name'. */ -function spaceIdToSubdomain(spaceId: string): string { - return spaceId - .toLowerCase() - .replace(/[/_.]/g, '-') - .replace(/-+/g, '-') - .replace(/^-|-$/g, ''); -} - -function buildTrackioEmbedUrl(spaceId: string, project?: string): string { - // __theme=dark is gradio's standard query param to force the embedded - // dashboard into dark mode so it blends with the surrounding chat instead - // of flashing a bright white panel inside the dark UI. - const params = new URLSearchParams({ - sidebar: 'hidden', - footer: 'false', - __theme: 'dark', - }); - if (project) params.set('project', project); - return `https://${spaceIdToSubdomain(spaceId)}.hf.space/?${params.toString()}`; -} - -function buildTrackioPageUrl(spaceId: string, project?: string): string { - const qs = project ? `?${new URLSearchParams({ project }).toString()}` : ''; - return `https://huggingface.co/spaces/${spaceId}${qs}`; -} - -function TrackioEmbed({ spaceId, project }: { spaceId: string; project?: string }) { - const [expanded, setExpanded] = useState(true); - const [iframeLoaded, setIframeLoaded] = useState(false); - const embedUrl = useMemo(() => buildTrackioEmbedUrl(spaceId, project), [spaceId, project]); - const pageUrl = useMemo(() => buildTrackioPageUrl(spaceId, project), [spaceId, project]); - const label = project ? `${spaceId} Β· ${project}` : spaceId; - - if (!isValidSpaceId(spaceId)) return null; - - return ( - - - e.stopPropagation()} - sx={{ - px: 1.25, - py: 0.5, - borderBottom: expanded ? '1px solid var(--tool-border)' : 'none', - }} - > - - trackio - - - {label} - - e.stopPropagation()} - sx={{ - display: 'inline-flex', - alignItems: 'center', - gap: 0.4, - color: 'var(--accent-yellow)', - fontSize: '0.65rem', - textDecoration: 'none', - '&:hover': { textDecoration: 'underline' }, - }} - > - - Open - - - - {expanded && ( - -