Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- Dockerfile +20 -0
- LICENSE +201 -0
- README.md +48 -6
- __init__.py +6 -0
- baseline_agent.py +185 -0
- client.py +42 -0
- gradio_app.py +756 -0
- models.py +41 -0
- openenv.yaml +77 -0
- openenv_whipstudio.egg-info/PKG-INFO +16 -0
- openenv_whipstudio.egg-info/SOURCES.txt +25 -0
- openenv_whipstudio.egg-info/dependency_links.txt +1 -0
- openenv_whipstudio.egg-info/entry_points.txt +3 -0
- openenv_whipstudio.egg-info/requires.txt +10 -0
- openenv_whipstudio.egg-info/top_level.txt +1 -0
- project_status_report_2703.md +70 -0
- pyproject.toml +31 -0
- server/__init__.py +1 -0
- server/app.py +254 -0
- server/environment.py +165 -0
- server/requirements.txt +11 -0
- server/sandbox.py +88 -0
- server/start.sh +7 -0
- server/tasks/__init__.py +1 -0
- server/tasks/graders.py +256 -0
- server/tasks/task1_broken_loop.py +33 -0
- server/tasks/task2_nan_loss.py +32 -0
- server/tasks/task3_oom_leakage.py +50 -0
- server/tasks/task4_wrong_loss.py +62 -0
- server/tasks/task5_frozen_backbone.py +62 -0
- uv.lock +0 -0
Dockerfile
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
RUN pip install --no-cache-dir \
|
| 6 |
+
torch==2.2.0+cpu torchvision==0.17.0+cpu \
|
| 7 |
+
--index-url https://download.pytorch.org/whl/cpu
|
| 8 |
+
|
| 9 |
+
RUN apt-get update && apt-get install -y curl && rm -rf /var/lib/apt/lists/*
|
| 10 |
+
|
| 11 |
+
COPY server/requirements.txt .
|
| 12 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 13 |
+
|
| 14 |
+
COPY . .
|
| 15 |
+
|
| 16 |
+
EXPOSE 7860
|
| 17 |
+
ENV PYTHONUNBUFFERED=1
|
| 18 |
+
ENV ENABLE_WEB_INTERFACE=false
|
| 19 |
+
|
| 20 |
+
CMD ["python", "-m", "uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
|
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
README.md
CHANGED
|
@@ -1,10 +1,52 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
-
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: WhipStudio Env
|
| 3 |
+
emoji: 🤖
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
base_path: /ui
|
| 9 |
---
|
| 10 |
|
| 11 |
+
# ML Debug Environment
|
| 12 |
+
|
| 13 |
+
An OpenEnv-compatible RL environment where agents debug broken PyTorch training scripts.
|
| 14 |
+
|
| 15 |
+
## Environment Description
|
| 16 |
+
The agent receives a broken Python training script and must return a corrected version.
|
| 17 |
+
Five tasks simulate real ML production bugs with increasing complexity.
|
| 18 |
+
|
| 19 |
+
## Action Space
|
| 20 |
+
- fixed_code (str, required): Complete corrected Python script
|
| 21 |
+
- explanation (str, optional): Description of bugs found
|
| 22 |
+
- attempt_number (int, 1-3): Which fix attempt this is
|
| 23 |
+
|
| 24 |
+
## Observation Space
|
| 25 |
+
- task_id: Which task (task1/task2/task3)
|
| 26 |
+
- task_description: Plain English instructions
|
| 27 |
+
- buggy_code: The broken script
|
| 28 |
+
- error_log: stdout+stderr from previous attempt
|
| 29 |
+
- last_reward: Score from previous attempt (0.0 on first step)
|
| 30 |
+
- metrics: {exit_code, elapsed_seconds, timed_out, step, best_reward_so_far}
|
| 31 |
+
|
| 32 |
+
## Reward Function
|
| 33 |
+
Continuous score 0.0–1.0. Partial credit for every improvement.
|
| 34 |
+
See `server/tasks/graders.py` for per-task scoring logic.
|
| 35 |
+
|
| 36 |
+
## Tasks
|
| 37 |
+
| Task | Difficulty | Bug Type |
|
| 38 |
+
|------|-----------|----------|
|
| 39 |
+
| task1 | Easy | Wrong optimizer order + bad learning rate |
|
| 40 |
+
| task2 | Medium | Silent NaN from log(0) numerical instability |
|
| 41 |
+
| task3 | Hard | OOM memory leak + train/val data leakage |
|
| 42 |
+
| task4 | Medium | Wrong loss function |
|
| 43 |
+
| task5 | Medium | Frozen backbone |
|
| 44 |
+
|
| 45 |
+
## Setup
|
| 46 |
+
```bash
|
| 47 |
+
pip install openenv-core
|
| 48 |
+
uvicorn server.app:app --host 0.0.0.0 --port 7860
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
## Endpoints
|
| 52 |
+
POST /reset, POST /step, GET /state, GET /tasks, POST /grader, GET /baseline
|
__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ML Debug OpenEnv package."""
|
| 2 |
+
|
| 3 |
+
from .models import MLDebugAction, MLDebugObservation
|
| 4 |
+
from .client import MLDebugEnv
|
| 5 |
+
|
| 6 |
+
__all__ = ["MLDebugAction", "MLDebugObservation", "MLDebugEnv"]
|
baseline_agent.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
from dotenv import load_dotenv
|
| 5 |
+
load_dotenv(override=True)
|
| 6 |
+
|
| 7 |
+
import httpx
|
| 8 |
+
|
| 9 |
+
SYSTEM_PROMPT = """
|
| 10 |
+
You are an expert PyTorch debugging agent.
|
| 11 |
+
You receive a broken training script and must fix ALL bugs in it.
|
| 12 |
+
Rules:
|
| 13 |
+
- Return ONLY the complete corrected Python code, nothing else.
|
| 14 |
+
- No markdown, no backticks, no explanation text.
|
| 15 |
+
- The script must print losses in format: LOSSES:[v1, v2, ...]
|
| 16 |
+
- For task3, also print: VAL_ACCS:[v1,...] and FINAL_LOSS:X.XX
|
| 17 |
+
- Keep all torch.manual_seed() calls intact.
|
| 18 |
+
""".strip()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_model():
|
| 22 |
+
from smolagents import InferenceClientModel
|
| 23 |
+
|
| 24 |
+
hf_token = os.environ.get("HF_TOKEN")
|
| 25 |
+
if not hf_token:
|
| 26 |
+
raise RuntimeError(
|
| 27 |
+
"HF_TOKEN is not set. Set HF_TOKEN to run /baseline with InferenceClientModel."
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
return InferenceClientModel(
|
| 31 |
+
model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
|
| 32 |
+
token=hf_token,
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _generate_fixed_code(model, prompt: str) -> str:
|
| 37 |
+
def _extract_text(response) -> str:
|
| 38 |
+
if isinstance(response, str):
|
| 39 |
+
return response
|
| 40 |
+
|
| 41 |
+
if hasattr(response, "content"):
|
| 42 |
+
content = getattr(response, "content")
|
| 43 |
+
if isinstance(content, str):
|
| 44 |
+
return content
|
| 45 |
+
if isinstance(content, list):
|
| 46 |
+
chunks = []
|
| 47 |
+
for item in content:
|
| 48 |
+
if isinstance(item, str):
|
| 49 |
+
chunks.append(item)
|
| 50 |
+
elif isinstance(item, dict):
|
| 51 |
+
text = item.get("text") or item.get("content")
|
| 52 |
+
if text:
|
| 53 |
+
chunks.append(str(text))
|
| 54 |
+
if chunks:
|
| 55 |
+
return "\n".join(chunks)
|
| 56 |
+
|
| 57 |
+
if isinstance(response, dict):
|
| 58 |
+
text = response.get("content") or response.get("text")
|
| 59 |
+
if isinstance(text, str):
|
| 60 |
+
return text
|
| 61 |
+
|
| 62 |
+
return str(response)
|
| 63 |
+
|
| 64 |
+
if hasattr(model, "generate"):
|
| 65 |
+
generate = getattr(model, "generate")
|
| 66 |
+
messages = [
|
| 67 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 68 |
+
{"role": "user", "content": prompt},
|
| 69 |
+
]
|
| 70 |
+
try:
|
| 71 |
+
return _extract_text(generate(messages=messages))
|
| 72 |
+
except TypeError:
|
| 73 |
+
return _extract_text(generate(messages))
|
| 74 |
+
|
| 75 |
+
if callable(model):
|
| 76 |
+
try:
|
| 77 |
+
return _extract_text(model(prompt, system_prompt=SYSTEM_PROMPT))
|
| 78 |
+
except TypeError:
|
| 79 |
+
return _extract_text(model(prompt))
|
| 80 |
+
|
| 81 |
+
raise AttributeError("Model does not support callable() or generate() inference APIs")
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
async def run_single_task(task_id: str, env_url: str = "http://localhost:7860") -> float:
|
| 85 |
+
"""Backwards-compatible wrapper that returns just the score."""
|
| 86 |
+
result = await run_single_task_detailed(task_id, env_url)
|
| 87 |
+
return result["score"]
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
async def run_single_task_detailed(task_id: str, env_url: str = "http://localhost:7860") -> dict:
|
| 91 |
+
"""Run the baseline agent on a single task. Returns detailed results."""
|
| 92 |
+
model = get_model()
|
| 93 |
+
timeout = httpx.Timeout(900.0, connect=10.0)
|
| 94 |
+
|
| 95 |
+
attempts_log = []
|
| 96 |
+
|
| 97 |
+
async with httpx.AsyncClient(timeout=timeout) as client:
|
| 98 |
+
reset_resp = await client.post(f"{env_url}/reset", json={"task_id": task_id})
|
| 99 |
+
reset_resp.raise_for_status()
|
| 100 |
+
obs = reset_resp.json().get("observation", reset_resp.json())
|
| 101 |
+
|
| 102 |
+
best_reward = 0.0
|
| 103 |
+
best_code = ""
|
| 104 |
+
best_output = ""
|
| 105 |
+
|
| 106 |
+
for attempt in range(1, 4):
|
| 107 |
+
prompt = f"""
|
| 108 |
+
Task: {obs.get('task_description', '')}
|
| 109 |
+
Buggy code:
|
| 110 |
+
{obs.get('buggy_code', '')}
|
| 111 |
+
Previous execution output (if any):
|
| 112 |
+
{obs.get('error_log', 'None')}
|
| 113 |
+
Previous score: {obs.get('last_reward', 0.0)}
|
| 114 |
+
""".strip()
|
| 115 |
+
|
| 116 |
+
fixed_code = _generate_fixed_code(model, prompt)
|
| 117 |
+
if "```python" in fixed_code:
|
| 118 |
+
fixed_code = fixed_code.split("```python", 1)[1].split("```", 1)[0].strip()
|
| 119 |
+
elif "```" in fixed_code:
|
| 120 |
+
fixed_code = fixed_code.split("```", 1)[1].split("```", 1)[0].strip()
|
| 121 |
+
|
| 122 |
+
step_payload = {
|
| 123 |
+
"action": {
|
| 124 |
+
"fixed_code": fixed_code,
|
| 125 |
+
"attempt_number": attempt,
|
| 126 |
+
}
|
| 127 |
+
}
|
| 128 |
+
step_resp = await client.post(f"{env_url}/step", json=step_payload)
|
| 129 |
+
if step_resp.status_code == 422:
|
| 130 |
+
step_resp = await client.post(
|
| 131 |
+
f"{env_url}/step",
|
| 132 |
+
json={
|
| 133 |
+
"fixed_code": fixed_code,
|
| 134 |
+
"attempt_number": attempt,
|
| 135 |
+
},
|
| 136 |
+
)
|
| 137 |
+
step_resp.raise_for_status()
|
| 138 |
+
result = step_resp.json()
|
| 139 |
+
|
| 140 |
+
reward = float(result.get("reward", 0.0) or 0.0)
|
| 141 |
+
obs = result.get("observation", obs)
|
| 142 |
+
output_log = obs.get("error_log", "") if isinstance(obs, dict) else ""
|
| 143 |
+
|
| 144 |
+
attempts_log.append({
|
| 145 |
+
"attempt": attempt,
|
| 146 |
+
"code": fixed_code,
|
| 147 |
+
"output": output_log[:3000],
|
| 148 |
+
"reward": reward,
|
| 149 |
+
})
|
| 150 |
+
|
| 151 |
+
if reward > best_reward:
|
| 152 |
+
best_reward = reward
|
| 153 |
+
best_code = fixed_code
|
| 154 |
+
best_output = output_log
|
| 155 |
+
|
| 156 |
+
if result.get("done") or reward >= 0.95:
|
| 157 |
+
break
|
| 158 |
+
|
| 159 |
+
return {
|
| 160 |
+
"score": best_reward,
|
| 161 |
+
"fixed_code": best_code,
|
| 162 |
+
"output": best_output[:3000],
|
| 163 |
+
"attempts": attempts_log,
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
if __name__ == "__main__":
|
| 168 |
+
import argparse
|
| 169 |
+
|
| 170 |
+
parser = argparse.ArgumentParser()
|
| 171 |
+
parser.add_argument("--env-url", default="http://localhost:7860")
|
| 172 |
+
args = parser.parse_args()
|
| 173 |
+
|
| 174 |
+
async def main():
|
| 175 |
+
scores = {}
|
| 176 |
+
for tid in ["task1", "task2", "task3"]:
|
| 177 |
+
try:
|
| 178 |
+
s = await asyncio.wait_for(run_single_task(tid, args.env_url), timeout=95.0)
|
| 179 |
+
except TimeoutError:
|
| 180 |
+
s = 0.0
|
| 181 |
+
scores[tid] = round(s, 4)
|
| 182 |
+
print(f"{tid}: {s:.4f}")
|
| 183 |
+
print(f"Average: {sum(scores.values()) / 3:.4f}")
|
| 184 |
+
|
| 185 |
+
asyncio.run(main())
|
client.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict
|
| 2 |
+
|
| 3 |
+
from openenv.core import EnvClient
|
| 4 |
+
from openenv.core.client_types import StepResult
|
| 5 |
+
from openenv.core.env_server.types import State
|
| 6 |
+
|
| 7 |
+
from .models import MLDebugAction, MLDebugObservation
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class MLDebugEnv(EnvClient[MLDebugAction, MLDebugObservation, State]):
|
| 11 |
+
def _step_payload(self, action: MLDebugAction) -> Dict:
|
| 12 |
+
return {
|
| 13 |
+
"fixed_code": action.fixed_code,
|
| 14 |
+
"explanation": action.explanation,
|
| 15 |
+
"attempt_number": action.attempt_number,
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
def _parse_result(self, payload: Dict) -> StepResult[MLDebugObservation]:
|
| 19 |
+
obs_data = payload.get("observation", {})
|
| 20 |
+
observation = MLDebugObservation(
|
| 21 |
+
task_id=obs_data.get("task_id", "task1"),
|
| 22 |
+
task_description=obs_data.get("task_description", ""),
|
| 23 |
+
buggy_code=obs_data.get("buggy_code", ""),
|
| 24 |
+
error_log=obs_data.get("error_log", ""),
|
| 25 |
+
last_reward=obs_data.get("last_reward", 0.0),
|
| 26 |
+
metrics=obs_data.get("metrics", {}),
|
| 27 |
+
done=payload.get("done", False),
|
| 28 |
+
reward=payload.get("reward"),
|
| 29 |
+
metadata=obs_data.get("metadata", {}),
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
return StepResult(
|
| 33 |
+
observation=observation,
|
| 34 |
+
reward=payload.get("reward"),
|
| 35 |
+
done=payload.get("done", False),
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
def _parse_state(self, payload: Dict) -> State:
|
| 39 |
+
return State(
|
| 40 |
+
episode_id=payload.get("episode_id"),
|
| 41 |
+
step_count=payload.get("step_count", 0),
|
| 42 |
+
)
|
gradio_app.py
ADDED
|
@@ -0,0 +1,756 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
WhipStudio — ML Debug Arena
|
| 3 |
+
A polished Gradio UI for the ML Debugging RL environment.
|
| 4 |
+
Provides code editing, loss curve visualization, diff views, and episode history.
|
| 5 |
+
"""
|
| 6 |
+
import difflib
|
| 7 |
+
import json
|
| 8 |
+
import math
|
| 9 |
+
import re
|
| 10 |
+
from typing import Any
|
| 11 |
+
|
| 12 |
+
import gradio as gr
|
| 13 |
+
import httpx
|
| 14 |
+
import os
|
| 15 |
+
|
| 16 |
+
DEFAULT_BASE_URL = os.environ.get("BASE_URL", "http://localhost:7860")
|
| 17 |
+
|
| 18 |
+
# ── Task metadata ──────────────────────────────────────────────────────────
|
| 19 |
+
|
| 20 |
+
TASK_INFO = {
|
| 21 |
+
"task1": {
|
| 22 |
+
"name": "Broken Training Loop",
|
| 23 |
+
"difficulty": "🟢 Easy",
|
| 24 |
+
"description": "Fix optimizer order + learning rate bugs in a linear classifier.",
|
| 25 |
+
"hints": "Look at optimizer.step() / loss.backward() order and the learning rate.",
|
| 26 |
+
},
|
| 27 |
+
"task2": {
|
| 28 |
+
"name": "Silent NaN Loss",
|
| 29 |
+
"difficulty": "🟡 Medium",
|
| 30 |
+
"description": "Fix numerical instability causing NaN loss from log(0).",
|
| 31 |
+
"hints": "The loss computation uses torch.log() without clamping — pred can be 0.",
|
| 32 |
+
},
|
| 33 |
+
"task3": {
|
| 34 |
+
"name": "OOM + Data Leakage",
|
| 35 |
+
"difficulty": "🔴 Hard",
|
| 36 |
+
"description": "Fix memory leak (graph accumulation) AND train/val data leakage.",
|
| 37 |
+
"hints": "Two bugs: total_loss accumulates graph, and augmentation is applied before split.",
|
| 38 |
+
},
|
| 39 |
+
"task4": {
|
| 40 |
+
"name": "Wrong Loss Function",
|
| 41 |
+
"difficulty": "🟡 Medium",
|
| 42 |
+
"description": "Multi-label classification incorrectly using CrossEntropyLoss. Fix loss and eval.",
|
| 43 |
+
"hints": "Use BCEWithLogitsLoss for multi-label. Ensure predictions are multi-hot.",
|
| 44 |
+
},
|
| 45 |
+
"task5": {
|
| 46 |
+
"name": "Frozen Backbone",
|
| 47 |
+
"difficulty": "🟡 Medium",
|
| 48 |
+
"description": "Backbone frozen but its parameters are passed to the optimizer.",
|
| 49 |
+
"hints": "Unfreeze backend or only pass head parameters to Adam.",
|
| 50 |
+
},
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
# ── Theme ──────────────────────────────────────────────────────────────────
|
| 54 |
+
|
| 55 |
+
CUSTOM_CSS = """
|
| 56 |
+
/* Dark arena theme */
|
| 57 |
+
.gradio-container { max-width: 1400px !important; }
|
| 58 |
+
|
| 59 |
+
/* Score display */
|
| 60 |
+
.score-high { color: #22c55e !important; font-weight: 700; }
|
| 61 |
+
.score-med { color: #eab308 !important; font-weight: 700; }
|
| 62 |
+
.score-low { color: #ef4444 !important; font-weight: 700; }
|
| 63 |
+
|
| 64 |
+
/* Task badges */
|
| 65 |
+
.task-badge { display: inline-block; padding: 2px 8px; border-radius: 4px; font-size: 0.85em; }
|
| 66 |
+
.badge-easy { background: #16a34a22; color: #22c55e; border: 1px solid #22c55e44; }
|
| 67 |
+
.badge-medium { background: #ca8a0422; color: #eab308; border: 1px solid #eab30844; }
|
| 68 |
+
.badge-hard { background: #dc262622; color: #ef4444; border: 1px solid #ef444444; }
|
| 69 |
+
|
| 70 |
+
/* Diff highlighting */
|
| 71 |
+
.diff-add { color: #22c55e; background: #22c55e11; }
|
| 72 |
+
.diff-del { color: #ef4444; background: #ef444411; }
|
| 73 |
+
|
| 74 |
+
/* Score bar */
|
| 75 |
+
.score-bar-container { position: relative; height: 28px; background: #1e293b; border-radius: 6px; overflow: hidden; margin: 8px 0; }
|
| 76 |
+
.score-bar-fill { height: 100%; border-radius: 6px; transition: width 0.6s ease-in-out; }
|
| 77 |
+
.score-bar-label { position: absolute; right: 8px; top: 4px; font-weight: 600; font-size: 0.9em; }
|
| 78 |
+
|
| 79 |
+
/* Episode step indicators */
|
| 80 |
+
.step-indicator { display: inline-flex; align-items: center; gap: 6px; padding: 4px 12px; border-radius: 16px; font-size: 0.85em; margin: 2px; }
|
| 81 |
+
.step-done { background: #16a34a22; border: 1px solid #22c55e44; color: #22c55e; }
|
| 82 |
+
.step-active { background: #3b82f622; border: 1px solid #3b82f644; color: #60a5fa; }
|
| 83 |
+
.step-pending { background: #334155; border: 1px solid #47556966; color: #94a3b8; }
|
| 84 |
+
|
| 85 |
+
/* Header styling */
|
| 86 |
+
.arena-header { text-align: center; padding: 12px 0; }
|
| 87 |
+
.arena-header h1 { margin: 0; font-size: 1.8em; }
|
| 88 |
+
.arena-header p { margin: 4px 0 0 0; color: #94a3b8; }
|
| 89 |
+
"""
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# ── API helpers ────────────────────────────────────────────────────────────
|
| 93 |
+
|
| 94 |
+
def _api(base_url: str, method: str, path: str, payload: dict | None = None) -> dict:
|
| 95 |
+
"""Call the WhipStudio API and return parsed JSON."""
|
| 96 |
+
base_url = (base_url or DEFAULT_BASE_URL).strip().rstrip("/")
|
| 97 |
+
url = f"{base_url}{path}"
|
| 98 |
+
try:
|
| 99 |
+
with httpx.Client(timeout=90.0) as client:
|
| 100 |
+
if method == "GET":
|
| 101 |
+
resp = client.get(url)
|
| 102 |
+
else:
|
| 103 |
+
resp = client.post(url, json=payload or {})
|
| 104 |
+
resp.raise_for_status()
|
| 105 |
+
return resp.json()
|
| 106 |
+
except Exception as exc:
|
| 107 |
+
return {"error": f"{exc.__class__.__name__}: {exc}"}
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def _parse_losses_from_log(log: str) -> list[float]:
|
| 111 |
+
"""Extract LOSSES:[...] from stdout."""
|
| 112 |
+
match = re.search(r"LOSSES:\[([^\]]+)\]", log)
|
| 113 |
+
if not match:
|
| 114 |
+
return []
|
| 115 |
+
try:
|
| 116 |
+
return [float(x.strip()) for x in match.group(1).split(",")]
|
| 117 |
+
except Exception:
|
| 118 |
+
return []
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def _parse_val_accs_from_log(log: str) -> list[float]:
|
| 122 |
+
"""Extract VAL_ACCS:[...] from stdout."""
|
| 123 |
+
match = re.search(r"VAL_ACCS:\[([^\]]+)\]", log)
|
| 124 |
+
if not match:
|
| 125 |
+
return []
|
| 126 |
+
try:
|
| 127 |
+
return [float(x.strip()) for x in match.group(1).split(",")]
|
| 128 |
+
except Exception:
|
| 129 |
+
return []
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def _score_color(score: float) -> str:
|
| 133 |
+
if score >= 0.7:
|
| 134 |
+
return "#22c55e"
|
| 135 |
+
if score >= 0.4:
|
| 136 |
+
return "#eab308"
|
| 137 |
+
return "#ef4444"
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def _score_html(score: float) -> str:
|
| 141 |
+
pct = int(score * 100)
|
| 142 |
+
color = _score_color(score)
|
| 143 |
+
return f"""
|
| 144 |
+
<div style="text-align:center; margin: 8px 0;">
|
| 145 |
+
<div style="font-size: 2.4em; font-weight: 700; color: {color};">{score:.2f}</div>
|
| 146 |
+
<div style="position:relative; height:24px; background:#1e293b; border-radius:6px; overflow:hidden; margin:8px auto; max-width:280px;">
|
| 147 |
+
<div style="height:100%; width:{pct}%; background:linear-gradient(90deg, {color}88, {color}); border-radius:6px; transition:width 0.6s ease;"></div>
|
| 148 |
+
</div>
|
| 149 |
+
<div style="color:#94a3b8; font-size:0.85em;">{pct}% complete</div>
|
| 150 |
+
</div>"""
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def _diff_html(original: str, fixed: str) -> str:
|
| 154 |
+
"""Generate an HTML diff view between original and fixed code."""
|
| 155 |
+
orig_lines = original.strip().splitlines()
|
| 156 |
+
fixed_lines = fixed.strip().splitlines()
|
| 157 |
+
diff = difflib.unified_diff(orig_lines, fixed_lines, lineterm="", n=3)
|
| 158 |
+
lines = []
|
| 159 |
+
for line in diff:
|
| 160 |
+
if line.startswith("+++") or line.startswith("---"):
|
| 161 |
+
continue
|
| 162 |
+
if line.startswith("@@"):
|
| 163 |
+
lines.append(f'<div style="color:#60a5fa;margin:8px 0 2px 0;font-size:0.85em;">{line}</div>')
|
| 164 |
+
elif line.startswith("+"):
|
| 165 |
+
lines.append(f'<div style="color:#22c55e;background:#22c55e0d;padding:1px 6px;font-family:monospace;font-size:0.85em;">+ {line[1:]}</div>')
|
| 166 |
+
elif line.startswith("-"):
|
| 167 |
+
lines.append(f'<div style="color:#ef4444;background:#ef44440d;padding:1px 6px;font-family:monospace;font-size:0.85em;">- {line[1:]}</div>')
|
| 168 |
+
else:
|
| 169 |
+
lines.append(f'<div style="color:#94a3b8;padding:1px 6px;font-family:monospace;font-size:0.85em;"> {line}</div>')
|
| 170 |
+
if not lines:
|
| 171 |
+
return '<div style="color:#94a3b8;text-align:center;padding:20px;">No changes detected</div>'
|
| 172 |
+
return '<div style="background:#0f172a;border-radius:8px;padding:12px;overflow-x:auto;max-height:500px;overflow-y:auto;">' + "\n".join(lines) + "</div>"
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def _step_timeline_html(trajectory: list[dict], current_step: int, max_steps: int = 3) -> str:
|
| 176 |
+
"""Render step timeline as HTML."""
|
| 177 |
+
items = []
|
| 178 |
+
for i in range(1, max_steps + 1):
|
| 179 |
+
entry = next((t for t in trajectory if t.get("step") == i), None)
|
| 180 |
+
if entry:
|
| 181 |
+
r = entry["reward"]
|
| 182 |
+
color = _score_color(r)
|
| 183 |
+
items.append(
|
| 184 |
+
f'<span class="step-indicator step-done" style="border-color:{color}44;background:{color}11;color:{color};">'
|
| 185 |
+
f'Step {i} → {r:.2f}</span>'
|
| 186 |
+
)
|
| 187 |
+
elif i == current_step + 1:
|
| 188 |
+
items.append(f'<span class="step-indicator step-active">Step {i} ▶</span>')
|
| 189 |
+
else:
|
| 190 |
+
items.append(f'<span class="step-indicator step-pending">Step {i}</span>')
|
| 191 |
+
|
| 192 |
+
return '<div style="display:flex;gap:8px;justify-content:center;flex-wrap:wrap;padding:8px 0;">' + "".join(items) + "</div>"
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def _loss_plot(losses: list[float], title: str = "Loss Curve"):
|
| 196 |
+
"""Generate a matplotlib figure for loss curve."""
|
| 197 |
+
import matplotlib
|
| 198 |
+
matplotlib.use("Agg")
|
| 199 |
+
import matplotlib.pyplot as plt
|
| 200 |
+
|
| 201 |
+
fig, ax = plt.subplots(figsize=(5, 3))
|
| 202 |
+
fig.patch.set_facecolor("#0f172a")
|
| 203 |
+
ax.set_facecolor("#1e293b")
|
| 204 |
+
|
| 205 |
+
if losses:
|
| 206 |
+
valid_losses = [(i, l) for i, l in enumerate(losses) if not (math.isnan(l) or math.isinf(l))]
|
| 207 |
+
nan_steps = [i for i, l in enumerate(losses) if math.isnan(l) or math.isinf(l)]
|
| 208 |
+
|
| 209 |
+
if valid_losses:
|
| 210 |
+
steps, vals = zip(*valid_losses)
|
| 211 |
+
ax.plot(steps, vals, color="#60a5fa", linewidth=2, marker="o", markersize=3, zorder=3)
|
| 212 |
+
ax.fill_between(steps, vals, alpha=0.15, color="#60a5fa")
|
| 213 |
+
|
| 214 |
+
if nan_steps:
|
| 215 |
+
ax.scatter(nan_steps, [max(v for _, v in valid_losses) if valid_losses else 1.0] * len(nan_steps),
|
| 216 |
+
color="#ef4444", marker="x", s=60, zorder=4, label="NaN/Inf")
|
| 217 |
+
ax.legend(facecolor="#1e293b", edgecolor="#334155", labelcolor="#94a3b8")
|
| 218 |
+
|
| 219 |
+
ax.set_xlabel("Step", color="#94a3b8", fontsize=9)
|
| 220 |
+
ax.set_ylabel("Loss", color="#94a3b8", fontsize=9)
|
| 221 |
+
ax.set_title(title, color="#e2e8f0", fontsize=11, fontweight="bold")
|
| 222 |
+
ax.tick_params(colors="#64748b", labelsize=8)
|
| 223 |
+
for spine in ax.spines.values():
|
| 224 |
+
spine.set_color("#334155")
|
| 225 |
+
ax.grid(True, alpha=0.15, color="#475569")
|
| 226 |
+
|
| 227 |
+
fig.tight_layout()
|
| 228 |
+
return fig
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def _val_acc_plot(accs: list[float]):
|
| 232 |
+
"""Generate a matplotlib figure for validation accuracy."""
|
| 233 |
+
import matplotlib
|
| 234 |
+
matplotlib.use("Agg")
|
| 235 |
+
import matplotlib.pyplot as plt
|
| 236 |
+
|
| 237 |
+
fig, ax = plt.subplots(figsize=(5, 3))
|
| 238 |
+
fig.patch.set_facecolor("#0f172a")
|
| 239 |
+
ax.set_facecolor("#1e293b")
|
| 240 |
+
|
| 241 |
+
if accs:
|
| 242 |
+
epochs = list(range(1, len(accs) + 1))
|
| 243 |
+
ax.plot(epochs, accs, color="#a78bfa", linewidth=2, marker="o", markersize=3)
|
| 244 |
+
ax.fill_between(epochs, accs, alpha=0.15, color="#a78bfa")
|
| 245 |
+
ax.axhline(y=0.7, color="#22c55e", linestyle="--", alpha=0.5, label="Target (0.70)")
|
| 246 |
+
ax.legend(facecolor="#1e293b", edgecolor="#334155", labelcolor="#94a3b8")
|
| 247 |
+
|
| 248 |
+
ax.set_xlabel("Epoch", color="#94a3b8", fontsize=9)
|
| 249 |
+
ax.set_ylabel("Accuracy", color="#94a3b8", fontsize=9)
|
| 250 |
+
ax.set_title("Validation Accuracy", color="#e2e8f0", fontsize=11, fontweight="bold")
|
| 251 |
+
ax.set_ylim(0, 1.05)
|
| 252 |
+
ax.tick_params(colors="#64748b", labelsize=8)
|
| 253 |
+
for spine in ax.spines.values():
|
| 254 |
+
spine.set_color("#334155")
|
| 255 |
+
ax.grid(True, alpha=0.15, color="#475569")
|
| 256 |
+
|
| 257 |
+
fig.tight_layout()
|
| 258 |
+
return fig
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
# ── Episode state ──────────────────────────────────────────────────────────
|
| 262 |
+
|
| 263 |
+
class EpisodeState:
|
| 264 |
+
"""Track episode state across Gradio interactions."""
|
| 265 |
+
def __init__(self):
|
| 266 |
+
self.reset()
|
| 267 |
+
|
| 268 |
+
def reset(self):
|
| 269 |
+
self.task_id = ""
|
| 270 |
+
self.buggy_code = ""
|
| 271 |
+
self.task_description = ""
|
| 272 |
+
self.step = 0
|
| 273 |
+
self.best_reward = 0.0
|
| 274 |
+
self.last_reward = 0.0
|
| 275 |
+
self.error_log = ""
|
| 276 |
+
self.trajectory: list[dict] = []
|
| 277 |
+
self.done = False
|
| 278 |
+
self.last_fixed_code = ""
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
_state = EpisodeState()
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
# ── Action handlers ────────────────────────────────────────────────────────
|
| 285 |
+
|
| 286 |
+
def do_reset(base_url: str, task_id: str):
|
| 287 |
+
"""Reset the environment for a given task."""
|
| 288 |
+
_state.reset()
|
| 289 |
+
data = _api(base_url, "POST", "/reset", {"task_id": task_id})
|
| 290 |
+
if "error" in data:
|
| 291 |
+
return (
|
| 292 |
+
f"❌ Error: {data['error']}", # status
|
| 293 |
+
"", # code editor
|
| 294 |
+
"", # task desc
|
| 295 |
+
_score_html(0.0), # score
|
| 296 |
+
None, # loss plot
|
| 297 |
+
None, # acc plot
|
| 298 |
+
"", # diff
|
| 299 |
+
_step_timeline_html([], 0), # timeline
|
| 300 |
+
"", # error log
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
obs = data.get("observation", data)
|
| 304 |
+
_state.task_id = obs.get("task_id", task_id)
|
| 305 |
+
_state.buggy_code = obs.get("buggy_code", "")
|
| 306 |
+
_state.task_description = obs.get("task_description", "")
|
| 307 |
+
_state.step = 0
|
| 308 |
+
_state.done = False
|
| 309 |
+
|
| 310 |
+
info = TASK_INFO.get(task_id, {})
|
| 311 |
+
task_md = f"""### {info.get('name', task_id)} {info.get('difficulty', '')}
|
| 312 |
+
|
| 313 |
+
{_state.task_description.strip()}
|
| 314 |
+
|
| 315 |
+
**💡 Hint:** {info.get('hints', 'No hints available.')}
|
| 316 |
+
"""
|
| 317 |
+
|
| 318 |
+
return (
|
| 319 |
+
f"✅ Episode started — {info.get('name', task_id)}",
|
| 320 |
+
_state.buggy_code.strip(),
|
| 321 |
+
task_md,
|
| 322 |
+
_score_html(0.0),
|
| 323 |
+
_loss_plot([], "Loss Curve — Submit a fix to see results"),
|
| 324 |
+
None,
|
| 325 |
+
'<div style="color:#94a3b8;text-align:center;padding:20px;">Submit a fix to see the diff</div>',
|
| 326 |
+
_step_timeline_html([], 0),
|
| 327 |
+
"",
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def do_step(base_url: str, fixed_code: str):
|
| 332 |
+
"""Submit a fix attempt."""
|
| 333 |
+
if _state.done:
|
| 334 |
+
return (
|
| 335 |
+
"⚠️ Episode is done. Reset to start a new one.",
|
| 336 |
+
_score_html(_state.best_reward),
|
| 337 |
+
None, None, "", _step_timeline_html(_state.trajectory, _state.step), "",
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
if not fixed_code or not fixed_code.strip():
|
| 341 |
+
return (
|
| 342 |
+
"⚠️ Please enter code before submitting.",
|
| 343 |
+
_score_html(_state.last_reward),
|
| 344 |
+
None, None, "", _step_timeline_html(_state.trajectory, _state.step), "",
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
_state.step += 1
|
| 348 |
+
_state.last_fixed_code = fixed_code
|
| 349 |
+
|
| 350 |
+
payload = {"action": {"fixed_code": fixed_code, "attempt_number": _state.step}}
|
| 351 |
+
data = _api(base_url, "POST", "/step", payload)
|
| 352 |
+
|
| 353 |
+
if "error" in data:
|
| 354 |
+
_state.step -= 1
|
| 355 |
+
return (
|
| 356 |
+
f"❌ Error: {data['error']}",
|
| 357 |
+
_score_html(_state.last_reward),
|
| 358 |
+
None, None, "", _step_timeline_html(_state.trajectory, _state.step), "",
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
reward = float(data.get("reward", 0.0) or 0.0)
|
| 362 |
+
_state.last_reward = reward
|
| 363 |
+
_state.best_reward = max(_state.best_reward, reward)
|
| 364 |
+
_state.done = data.get("done", False)
|
| 365 |
+
|
| 366 |
+
obs = data.get("observation", {})
|
| 367 |
+
_state.error_log = obs.get("error_log", "")
|
| 368 |
+
metrics = obs.get("metrics", {})
|
| 369 |
+
|
| 370 |
+
_state.trajectory.append({
|
| 371 |
+
"step": _state.step,
|
| 372 |
+
"reward": reward,
|
| 373 |
+
"best_reward": _state.best_reward,
|
| 374 |
+
"metrics": metrics,
|
| 375 |
+
"done": _state.done,
|
| 376 |
+
})
|
| 377 |
+
|
| 378 |
+
# Parse outputs for visualization
|
| 379 |
+
losses = _parse_losses_from_log(_state.error_log)
|
| 380 |
+
val_accs = _parse_val_accs_from_log(_state.error_log)
|
| 381 |
+
|
| 382 |
+
loss_fig = _loss_plot(losses, f"Loss Curve — Step {_state.step}")
|
| 383 |
+
acc_fig = _val_acc_plot(val_accs) if val_accs else None
|
| 384 |
+
|
| 385 |
+
diff = _diff_html(_state.buggy_code, fixed_code)
|
| 386 |
+
timeline = _step_timeline_html(_state.trajectory, _state.step)
|
| 387 |
+
|
| 388 |
+
if reward >= 0.95:
|
| 389 |
+
emoji = "🎯"
|
| 390 |
+
elif reward >= 0.7:
|
| 391 |
+
emoji = "✅"
|
| 392 |
+
elif len(_state.trajectory) > 1 and reward > _state.trajectory[-2]["reward"]:
|
| 393 |
+
emoji = "📈"
|
| 394 |
+
else:
|
| 395 |
+
emoji = "⚠️"
|
| 396 |
+
done_msg = " — Episode complete!" if _state.done else ""
|
| 397 |
+
status = f"{emoji} Step {_state.step}/3 — Reward: {reward:.2f} (Best: {_state.best_reward:.2f}){done_msg}"
|
| 398 |
+
|
| 399 |
+
error_display = _state.error_log if _state.error_log else "No errors — code ran successfully."
|
| 400 |
+
|
| 401 |
+
return (status, _score_html(_state.best_reward), loss_fig, acc_fig, diff, timeline, error_display)
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def do_run_baseline(base_url: str, task_id: str):
|
| 405 |
+
"""Run the baseline agent on a single task."""
|
| 406 |
+
# First reset
|
| 407 |
+
reset_result = do_reset(base_url, task_id)
|
| 408 |
+
yield reset_result + ("🤖 Resetting environment...",)
|
| 409 |
+
|
| 410 |
+
# Call baseline endpoint
|
| 411 |
+
data = _api(base_url, "GET", "/baseline")
|
| 412 |
+
if "error" in data:
|
| 413 |
+
yield reset_result + (f"❌ Baseline error: {data['error']}",)
|
| 414 |
+
return
|
| 415 |
+
|
| 416 |
+
scores = data.get("baseline_scores", {})
|
| 417 |
+
avg = data.get("average", 0.0)
|
| 418 |
+
|
| 419 |
+
results_md = "### 🤖 Baseline Agent Results\n\n"
|
| 420 |
+
results_md += "| Task | Score |\n|---|---|\n"
|
| 421 |
+
for tid in ["task1", "task2", "task3", "task4", "task5"]:
|
| 422 |
+
s = scores.get(tid, 0.0)
|
| 423 |
+
emoji = "🎯" if s >= 0.9 else ("✅" if s >= 0.7 else ("📈" if s >= 0.4 else "⚠️"))
|
| 424 |
+
results_md += f"| {tid} | {emoji} {s:.4f} |\n"
|
| 425 |
+
results_md += f"\n**Average: {avg:.4f}**"
|
| 426 |
+
|
| 427 |
+
yield reset_result + (results_md,)
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
def load_current_state(base_url: str):
|
| 431 |
+
"""Fetch and format current environment state for UI display."""
|
| 432 |
+
data = _api(base_url, "GET", "/state")
|
| 433 |
+
if "error" in data:
|
| 434 |
+
summary = "⚠️ Could not fetch current state. Start or reset an episode first, then try again."
|
| 435 |
+
return summary, {"error": data["error"]}
|
| 436 |
+
|
| 437 |
+
state_obj = data.get("state", data)
|
| 438 |
+
if not isinstance(state_obj, dict):
|
| 439 |
+
return "⚠️ State endpoint returned an unexpected response format.", {"raw": data}
|
| 440 |
+
|
| 441 |
+
done = bool(state_obj.get("done", False))
|
| 442 |
+
step = state_obj.get("step", 0)
|
| 443 |
+
task_id = state_obj.get("task_id", "-")
|
| 444 |
+
reward = state_obj.get("last_reward", state_obj.get("reward", 0.0))
|
| 445 |
+
summary = (
|
| 446 |
+
f"**Task:** {task_id} | **Step:** {step} | "
|
| 447 |
+
f"**Done:** {'yes' if done else 'no'} | **Last Reward:** {reward}"
|
| 448 |
+
)
|
| 449 |
+
return summary, state_obj
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
# ── Build the UI ───────────────────────────────────────────────────────────
|
| 453 |
+
|
| 454 |
+
def build_ui() -> gr.Blocks:
|
| 455 |
+
theme = gr.themes.Soft(
|
| 456 |
+
primary_hue=gr.themes.colors.blue,
|
| 457 |
+
secondary_hue=gr.themes.colors.purple,
|
| 458 |
+
neutral_hue=gr.themes.colors.slate,
|
| 459 |
+
font=gr.themes.GoogleFont("Inter"),
|
| 460 |
+
).set(
|
| 461 |
+
body_background_fill="#0f172a",
|
| 462 |
+
body_background_fill_dark="#0f172a",
|
| 463 |
+
block_background_fill="#1e293b",
|
| 464 |
+
block_background_fill_dark="#1e293b",
|
| 465 |
+
block_border_color="#334155",
|
| 466 |
+
block_border_color_dark="#334155",
|
| 467 |
+
block_label_text_color="#e2e8f0",
|
| 468 |
+
block_label_text_color_dark="#e2e8f0",
|
| 469 |
+
block_title_text_color="#e2e8f0",
|
| 470 |
+
block_title_text_color_dark="#e2e8f0",
|
| 471 |
+
input_background_fill="#0f172a",
|
| 472 |
+
input_background_fill_dark="#0f172a",
|
| 473 |
+
button_primary_background_fill="#3b82f6",
|
| 474 |
+
button_primary_background_fill_dark="#3b82f6",
|
| 475 |
+
button_primary_text_color="#ffffff",
|
| 476 |
+
)
|
| 477 |
+
|
| 478 |
+
with gr.Blocks(title="WhipStudio — ML Debug Arena") as app:
|
| 479 |
+
|
| 480 |
+
# ── Header ──
|
| 481 |
+
gr.HTML("""
|
| 482 |
+
<div class="arena-header">
|
| 483 |
+
<h1>🔧 WhipStudio — ML Debug Arena</h1>
|
| 484 |
+
<p>An RL environment where agents debug broken PyTorch training scripts</p>
|
| 485 |
+
</div>
|
| 486 |
+
""")
|
| 487 |
+
|
| 488 |
+
base_url = gr.Textbox(label="🌐 API Base URL", value=DEFAULT_BASE_URL, scale=1)
|
| 489 |
+
|
| 490 |
+
with gr.Row(equal_height=False):
|
| 491 |
+
|
| 492 |
+
# ── Left column: Task selector ──
|
| 493 |
+
with gr.Column(scale=1, min_width=280):
|
| 494 |
+
gr.Markdown("### 📋 Task Selector")
|
| 495 |
+
task_id = gr.Radio(
|
| 496 |
+
choices=["task1", "task2", "task3", "task4", "task5"],
|
| 497 |
+
value="task1",
|
| 498 |
+
label="Select Task",
|
| 499 |
+
info="Choose a debugging challenge",
|
| 500 |
+
)
|
| 501 |
+
|
| 502 |
+
task_desc = gr.Markdown(
|
| 503 |
+
value="""### Broken Training Loop 🟢 Easy
|
| 504 |
+
|
| 505 |
+
Fix optimizer order + learning rate bugs in a linear classifier.
|
| 506 |
+
|
| 507 |
+
**💡 Hint:** Look at optimizer.step() / loss.backward() order and the learning rate."""
|
| 508 |
+
)
|
| 509 |
+
|
| 510 |
+
with gr.Row():
|
| 511 |
+
btn_reset = gr.Button("🔄 Reset", variant="primary", size="sm")
|
| 512 |
+
btn_baseline = gr.Button("🤖 Auto-Agent", variant="secondary", size="sm")
|
| 513 |
+
|
| 514 |
+
status = gr.Textbox(label="Status", interactive=False, lines=1)
|
| 515 |
+
timeline = gr.HTML(label="Episode Timeline", value=_step_timeline_html([], 0))
|
| 516 |
+
|
| 517 |
+
# ── Center column: Code editor ──
|
| 518 |
+
with gr.Column(scale=2, min_width=400):
|
| 519 |
+
gr.Markdown("### 💻 Code Editor")
|
| 520 |
+
code_editor = gr.Code(
|
| 521 |
+
label="Your Fix (edit the code below)",
|
| 522 |
+
language="python",
|
| 523 |
+
lines=22,
|
| 524 |
+
)
|
| 525 |
+
|
| 526 |
+
with gr.Row():
|
| 527 |
+
btn_submit = gr.Button("🚀 Submit Fix", variant="primary", size="lg")
|
| 528 |
+
|
| 529 |
+
error_log = gr.Textbox(
|
| 530 |
+
label="📋 Execution Output / Error Log",
|
| 531 |
+
lines=6,
|
| 532 |
+
interactive=False,
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
# ── Right column: Results ──
|
| 536 |
+
with gr.Column(scale=1, min_width=300):
|
| 537 |
+
gr.Markdown("### 📊 Results")
|
| 538 |
+
score_display = gr.HTML(value=_score_html(0.0))
|
| 539 |
+
|
| 540 |
+
with gr.Tabs():
|
| 541 |
+
with gr.Tab("📉 Loss Curve"):
|
| 542 |
+
loss_plot = gr.Plot(label="Loss Curve")
|
| 543 |
+
with gr.Tab("📈 Val Accuracy"):
|
| 544 |
+
acc_plot = gr.Plot(label="Validation Accuracy (Task 3)")
|
| 545 |
+
with gr.Tab("🔀 Code Diff"):
|
| 546 |
+
diff_view = gr.HTML(
|
| 547 |
+
value='<div style="color:#94a3b8;text-align:center;padding:20px;">Submit a fix to see the diff</div>'
|
| 548 |
+
)
|
| 549 |
+
with gr.Tab("🧭 Current State"):
|
| 550 |
+
state_summary = gr.Markdown(
|
| 551 |
+
value="Press Reset to start an episode, then Current State will appear here."
|
| 552 |
+
)
|
| 553 |
+
btn_refresh_state = gr.Button("🔄 Refresh State", variant="secondary", size="sm")
|
| 554 |
+
state_json = gr.JSON(label="/state response", value={})
|
| 555 |
+
|
| 556 |
+
baseline_output = gr.Markdown(label="Baseline Results", visible=False)
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
# ── Bottom: Raw API tab (for developers) ──
|
| 560 |
+
with gr.Accordion("🔧 Developer Tools (Raw API)", open=False):
|
| 561 |
+
with gr.Row():
|
| 562 |
+
with gr.Column():
|
| 563 |
+
dev_method = gr.Radio(["GET", "POST"], value="GET", label="Method")
|
| 564 |
+
dev_path = gr.Textbox(label="Path", value="/health")
|
| 565 |
+
dev_payload = gr.Code(label="Payload (JSON)", language="json", value="{}")
|
| 566 |
+
btn_dev = gr.Button("Send Request", variant="secondary")
|
| 567 |
+
with gr.Column():
|
| 568 |
+
dev_status = gr.Textbox(label="Status")
|
| 569 |
+
dev_response = gr.Code(label="Response", language="json")
|
| 570 |
+
|
| 571 |
+
def dev_call(base, method, path, payload_text):
|
| 572 |
+
base = (base or DEFAULT_BASE_URL).strip().rstrip("/")
|
| 573 |
+
url = f"{base}{path}"
|
| 574 |
+
try:
|
| 575 |
+
payload = json.loads(payload_text) if payload_text.strip() else {}
|
| 576 |
+
except json.JSONDecodeError as e:
|
| 577 |
+
return f"JSON Error", str(e)
|
| 578 |
+
try:
|
| 579 |
+
with httpx.Client(timeout=90.0) as client:
|
| 580 |
+
if method == "GET":
|
| 581 |
+
resp = client.get(url)
|
| 582 |
+
else:
|
| 583 |
+
resp = client.post(url, json=payload)
|
| 584 |
+
ct = resp.headers.get("content-type", "")
|
| 585 |
+
body = json.dumps(resp.json(), indent=2) if "json" in ct else resp.text
|
| 586 |
+
return f"{resp.status_code} {resp.reason_phrase}", body
|
| 587 |
+
except Exception as exc:
|
| 588 |
+
return "Error", f"{exc.__class__.__name__}: {exc}"
|
| 589 |
+
|
| 590 |
+
btn_dev.click(
|
| 591 |
+
fn=dev_call,
|
| 592 |
+
inputs=[base_url, dev_method, dev_path, dev_payload],
|
| 593 |
+
outputs=[dev_status, dev_response],
|
| 594 |
+
)
|
| 595 |
+
|
| 596 |
+
|
| 597 |
+
# ── Event bindings ──
|
| 598 |
+
|
| 599 |
+
# Task selector updates description
|
| 600 |
+
def update_task_desc(tid):
|
| 601 |
+
info = TASK_INFO.get(tid, {})
|
| 602 |
+
return f"""### {info.get('name', tid)} {info.get('difficulty', '')}
|
| 603 |
+
|
| 604 |
+
{info.get('description', '')}
|
| 605 |
+
|
| 606 |
+
**💡 Hint:** {info.get('hints', 'No hints available.')}"""
|
| 607 |
+
|
| 608 |
+
task_id.change(fn=update_task_desc, inputs=[task_id], outputs=[task_desc])
|
| 609 |
+
|
| 610 |
+
# Reset
|
| 611 |
+
btn_reset.click(
|
| 612 |
+
fn=do_reset,
|
| 613 |
+
inputs=[base_url, task_id],
|
| 614 |
+
outputs=[status, code_editor, task_desc, score_display, loss_plot, acc_plot, diff_view, timeline, error_log],
|
| 615 |
+
).then(
|
| 616 |
+
fn=load_current_state,
|
| 617 |
+
inputs=[base_url],
|
| 618 |
+
outputs=[state_summary, state_json],
|
| 619 |
+
)
|
| 620 |
+
|
| 621 |
+
# Submit fix
|
| 622 |
+
btn_submit.click(
|
| 623 |
+
fn=do_step,
|
| 624 |
+
inputs=[base_url, code_editor],
|
| 625 |
+
outputs=[status, score_display, loss_plot, acc_plot, diff_view, timeline, error_log],
|
| 626 |
+
).then(
|
| 627 |
+
fn=load_current_state,
|
| 628 |
+
inputs=[base_url],
|
| 629 |
+
outputs=[state_summary, state_json],
|
| 630 |
+
)
|
| 631 |
+
|
| 632 |
+
btn_refresh_state.click(
|
| 633 |
+
fn=load_current_state,
|
| 634 |
+
inputs=[base_url],
|
| 635 |
+
outputs=[state_summary, state_json],
|
| 636 |
+
)
|
| 637 |
+
|
| 638 |
+
# Baseline (auto-agent) — live streaming per-task
|
| 639 |
+
TASK_NAMES = {
|
| 640 |
+
"task1": "🟢 Broken Training Loop",
|
| 641 |
+
"task2": "🟡 Silent NaN Loss",
|
| 642 |
+
"task3": "🔴 OOM + Data Leakage",
|
| 643 |
+
"task4": "🟡 Wrong Loss Function",
|
| 644 |
+
"task5": "🟡 Frozen Backbone",
|
| 645 |
+
}
|
| 646 |
+
|
| 647 |
+
def run_baseline_live(base_url_val):
|
| 648 |
+
"""Generator that yields live progress as each task completes."""
|
| 649 |
+
base = (base_url_val or DEFAULT_BASE_URL).strip().rstrip("/")
|
| 650 |
+
results = {}
|
| 651 |
+
lines_header = ["### 🤖 Baseline Agent — Live Progress\n"]
|
| 652 |
+
|
| 653 |
+
# Phase 1: Show "starting" state
|
| 654 |
+
yield "\n".join(lines_header + ["⏳ Starting baseline agent..."])
|
| 655 |
+
|
| 656 |
+
for tid in ["task1", "task2", "task3", "task4", "task5"]:
|
| 657 |
+
tname = TASK_NAMES.get(tid, tid)
|
| 658 |
+
|
| 659 |
+
# Show "running this task" update
|
| 660 |
+
progress_lines = list(lines_header)
|
| 661 |
+
# Show completed tasks
|
| 662 |
+
for done_tid, info in results.items():
|
| 663 |
+
s = info["score"]
|
| 664 |
+
emoji = "🎯" if s >= 0.9 else ("✅" if s >= 0.7 else ("📈" if s >= 0.4 else "⚠️"))
|
| 665 |
+
progress_lines.append(f"- {emoji} **{TASK_NAMES.get(done_tid, done_tid)}**: {s:.4f}")
|
| 666 |
+
if info.get("error"):
|
| 667 |
+
progress_lines.append(f" - ⚠️ `{info['error'][:150]}`")
|
| 668 |
+
# Show currently running task
|
| 669 |
+
progress_lines.append(f"\n🤖 **Running {tname}** — agent is analyzing the code and generating a fix...")
|
| 670 |
+
progress_lines.append(f"\n*This may take 30-60 seconds per task (LLM inference + sandbox execution × 3 attempts)*")
|
| 671 |
+
yield "\n".join(progress_lines)
|
| 672 |
+
|
| 673 |
+
# Actually call the per-task endpoint
|
| 674 |
+
try:
|
| 675 |
+
with httpx.Client(timeout=180.0) as client:
|
| 676 |
+
resp = client.get(f"{base}/baseline/task/{tid}")
|
| 677 |
+
resp.raise_for_status()
|
| 678 |
+
data = resp.json()
|
| 679 |
+
except Exception as exc:
|
| 680 |
+
data = {"score": 0.0, "error": f"{exc.__class__.__name__}: {exc}"}
|
| 681 |
+
|
| 682 |
+
results[tid] = {
|
| 683 |
+
"score": float(data.get("score", 0.0)),
|
| 684 |
+
"error": data.get("error", ""),
|
| 685 |
+
"fixed_code": data.get("fixed_code", ""),
|
| 686 |
+
"output": data.get("output", ""),
|
| 687 |
+
}
|
| 688 |
+
|
| 689 |
+
# Final summary
|
| 690 |
+
final_lines = ["### 🤖 Baseline Agent Results\n", "| Task | Score |", "|---|---|"]
|
| 691 |
+
total = 0.0
|
| 692 |
+
has_errors = False
|
| 693 |
+
for tid in ["task1", "task2", "task3", "task4", "task5"]:
|
| 694 |
+
info = results.get(tid, {"score": 0.0})
|
| 695 |
+
s = info["score"]
|
| 696 |
+
total += s
|
| 697 |
+
emoji = "🎯" if s >= 0.9 else ("✅" if s >= 0.7 else ("📈" if s >= 0.4 else "⚠️"))
|
| 698 |
+
final_lines.append(f"| {TASK_NAMES.get(tid, tid)} | {emoji} **{s:.4f}** |")
|
| 699 |
+
if info.get("error"):
|
| 700 |
+
has_errors = True
|
| 701 |
+
final_lines.append(f"\n> ⚠️ `{info['error'][:200]}`\n")
|
| 702 |
+
|
| 703 |
+
avg = total / 5
|
| 704 |
+
final_lines.append(f"\n**Average: {avg:.4f}**")
|
| 705 |
+
if avg >= 0.7:
|
| 706 |
+
final_lines.append("\n🎉 **Agent performed well!** The environment is solvable.")
|
| 707 |
+
elif avg >= 0.3:
|
| 708 |
+
final_lines.append("\n📈 **Agent showed partial progress.** Reward shaping is working.")
|
| 709 |
+
elif not has_errors:
|
| 710 |
+
final_lines.append("\n⚠️ **Agent scored low.** Tasks may be too challenging for zero-shot inference.")
|
| 711 |
+
|
| 712 |
+
if has_errors:
|
| 713 |
+
final_lines.append("\n---\n> [!WARNING]\n> Some tasks failed. Check if `HF_TOKEN` is valid and the model is accessible.")
|
| 714 |
+
|
| 715 |
+
final_lines.append("\n---\n### 🔍 Auto-Agent Generated Code & Execution Logs")
|
| 716 |
+
for tid in ["task1", "task2", "task3", "task4", "task5"]:
|
| 717 |
+
info = results.get(tid, {})
|
| 718 |
+
fixed_code = str(info.get("fixed_code", ""))
|
| 719 |
+
output = str(info.get("output", ""))
|
| 720 |
+
if fixed_code.strip() or output.strip():
|
| 721 |
+
final_lines.append(f"\n#### {TASK_NAMES.get(tid, tid)}")
|
| 722 |
+
if fixed_code.strip():
|
| 723 |
+
final_lines.append("<details><summary><b>Show Generated Code</b></summary>\n\n```python\n" + fixed_code + "\n```\n</details>")
|
| 724 |
+
if output.strip():
|
| 725 |
+
final_lines.append("<details><summary><b>Show Execution Output</b></summary>\n\n```text\n" + output + "\n```\n</details>")
|
| 726 |
+
|
| 727 |
+
yield "\n".join(final_lines)
|
| 728 |
+
|
| 729 |
+
btn_baseline.click(
|
| 730 |
+
fn=lambda: gr.update(visible=True),
|
| 731 |
+
outputs=[baseline_output],
|
| 732 |
+
).then(
|
| 733 |
+
fn=run_baseline_live,
|
| 734 |
+
inputs=[base_url],
|
| 735 |
+
outputs=[baseline_output],
|
| 736 |
+
)
|
| 737 |
+
|
| 738 |
+
# Footer
|
| 739 |
+
gr.HTML("""
|
| 740 |
+
<div style="text-align:center; padding:16px 0; color:#64748b; font-size:0.85em; border-top:1px solid #1e293b; margin-top:16px;">
|
| 741 |
+
WhipStudio v1.0 — OpenEnv ML Debug Environment
|
| 742 |
+
· <a href="/web" style="color:#60a5fa;">OpenEnv Web UI →</a>
|
| 743 |
+
· <a href="/docs" style="color:#60a5fa;">API Docs →</a>
|
| 744 |
+
</div>
|
| 745 |
+
""")
|
| 746 |
+
|
| 747 |
+
return app
|
| 748 |
+
|
| 749 |
+
|
| 750 |
+
def main(host: str = "0.0.0.0", port: int = 7860):
|
| 751 |
+
app = build_ui()
|
| 752 |
+
app.launch(server_name=host, server_port=port, css=CUSTOM_CSS)
|
| 753 |
+
|
| 754 |
+
|
| 755 |
+
if __name__ == "__main__":
|
| 756 |
+
main()
|
models.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from openenv.core.env_server.types import Action, Observation
|
| 2 |
+
from pydantic import Field
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class MLDebugAction(Action):
|
| 6 |
+
"""Agent submits a corrected training script."""
|
| 7 |
+
|
| 8 |
+
fixed_code: str = Field(
|
| 9 |
+
...,
|
| 10 |
+
description="The corrected Python training script. Must be complete runnable code.",
|
| 11 |
+
)
|
| 12 |
+
explanation: str = Field(
|
| 13 |
+
default="",
|
| 14 |
+
description="Optional: agent's explanation of bugs found (not scored, for logging)",
|
| 15 |
+
)
|
| 16 |
+
attempt_number: int = Field(
|
| 17 |
+
default=1,
|
| 18 |
+
ge=1,
|
| 19 |
+
le=3,
|
| 20 |
+
description="Which attempt this is. Max 3 per episode.",
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class MLDebugObservation(Observation):
|
| 25 |
+
"""What the agent sees on reset() and after each step()."""
|
| 26 |
+
|
| 27 |
+
task_id: str = Field(..., description="task1 | task2 | task3")
|
| 28 |
+
task_description: str = Field(..., description="Plain English task instructions")
|
| 29 |
+
buggy_code: str = Field(..., description="The broken training script")
|
| 30 |
+
error_log: str = Field(
|
| 31 |
+
default="",
|
| 32 |
+
description="stdout+stderr from the previous attempt. Empty on first step.",
|
| 33 |
+
)
|
| 34 |
+
last_reward: float = Field(
|
| 35 |
+
default=0.0,
|
| 36 |
+
description="Reward from previous attempt. 0.0 on first step.",
|
| 37 |
+
)
|
| 38 |
+
metrics: dict = Field(
|
| 39 |
+
default_factory=dict,
|
| 40 |
+
description="Structured: {final_loss, nan_count, val_acc, timed_out, exit_code}",
|
| 41 |
+
)
|
openenv.yaml
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
spec_version: 2
|
| 3 |
+
name: whipstudio-env
|
| 4 |
+
type: space
|
| 5 |
+
runtime: fastapi
|
| 6 |
+
app: server.app:app
|
| 7 |
+
port: 7860
|
| 8 |
+
|
| 9 |
+
metadata:
|
| 10 |
+
title: WhipStudio Env
|
| 11 |
+
version: "1.1.0"
|
| 12 |
+
description: >
|
| 13 |
+
OpenEnv-compatible RL environment where agents debug broken PyTorch
|
| 14 |
+
training scripts across five tasks with continuous reward scoring.
|
| 15 |
+
author: Amogh-kal1
|
| 16 |
+
license: Apache-2.0
|
| 17 |
+
repository: Amogh-kal1/whipstudio-env
|
| 18 |
+
tags:
|
| 19 |
+
- openenv
|
| 20 |
+
- rl
|
| 21 |
+
- ml-debugging
|
| 22 |
+
- fastapi
|
| 23 |
+
- gradio
|
| 24 |
+
|
| 25 |
+
compute:
|
| 26 |
+
python_version: "3.11"
|
| 27 |
+
cpu: 2
|
| 28 |
+
memory_gb: 16
|
| 29 |
+
gpu: false
|
| 30 |
+
|
| 31 |
+
service:
|
| 32 |
+
host: 0.0.0.0
|
| 33 |
+
port: 7860
|
| 34 |
+
health_path: /health
|
| 35 |
+
base_path: /
|
| 36 |
+
web_path: /web
|
| 37 |
+
ui_path: /ui
|
| 38 |
+
docs_path: /docs
|
| 39 |
+
endpoints:
|
| 40 |
+
- method: POST
|
| 41 |
+
path: /reset
|
| 42 |
+
- method: POST
|
| 43 |
+
path: /step
|
| 44 |
+
- method: GET
|
| 45 |
+
path: /state
|
| 46 |
+
- method: GET
|
| 47 |
+
path: /tasks
|
| 48 |
+
- method: POST
|
| 49 |
+
path: /grader
|
| 50 |
+
- method: GET
|
| 51 |
+
path: /baseline
|
| 52 |
+
- method: GET
|
| 53 |
+
path: /baseline/task/{task_id}
|
| 54 |
+
- method: GET
|
| 55 |
+
path: /baseline/health
|
| 56 |
+
|
| 57 |
+
tasks:
|
| 58 |
+
- id: task1
|
| 59 |
+
name: Broken training loop
|
| 60 |
+
difficulty: easy
|
| 61 |
+
max_steps: 3
|
| 62 |
+
- id: task2
|
| 63 |
+
name: Silent NaN loss
|
| 64 |
+
difficulty: medium
|
| 65 |
+
max_steps: 3
|
| 66 |
+
- id: task3
|
| 67 |
+
name: OOM and data leakage
|
| 68 |
+
difficulty: hard
|
| 69 |
+
max_steps: 3
|
| 70 |
+
- id: task4
|
| 71 |
+
name: Wrong loss function
|
| 72 |
+
difficulty: medium
|
| 73 |
+
max_steps: 3
|
| 74 |
+
- id: task5
|
| 75 |
+
name: Frozen backbone
|
| 76 |
+
difficulty: medium
|
| 77 |
+
max_steps: 3
|
openenv_whipstudio.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: openenv-whipstudio
|
| 3 |
+
Version: 0.1.0
|
| 4 |
+
Summary: ML Debug environment for OpenEnv
|
| 5 |
+
Requires-Python: >=3.11
|
| 6 |
+
License-File: LICENSE
|
| 7 |
+
Requires-Dist: openenv-core[core]>=0.2.1
|
| 8 |
+
Requires-Dist: fastapi>=0.110.0
|
| 9 |
+
Requires-Dist: uvicorn>=0.27.0
|
| 10 |
+
Requires-Dist: pydantic>=2.0.0
|
| 11 |
+
Requires-Dist: httpx>=0.27.0
|
| 12 |
+
Requires-Dist: torch>=2.2.0
|
| 13 |
+
Requires-Dist: smolagents>=1.0.0
|
| 14 |
+
Provides-Extra: dev
|
| 15 |
+
Requires-Dist: pytest>=8.0.0; extra == "dev"
|
| 16 |
+
Dynamic: license-file
|
openenv_whipstudio.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
LICENSE
|
| 2 |
+
README.md
|
| 3 |
+
__init__.py
|
| 4 |
+
baseline_agent.py
|
| 5 |
+
client.py
|
| 6 |
+
gradio_app.py
|
| 7 |
+
models.py
|
| 8 |
+
pyproject.toml
|
| 9 |
+
openenv_whipstudio.egg-info/PKG-INFO
|
| 10 |
+
openenv_whipstudio.egg-info/SOURCES.txt
|
| 11 |
+
openenv_whipstudio.egg-info/dependency_links.txt
|
| 12 |
+
openenv_whipstudio.egg-info/entry_points.txt
|
| 13 |
+
openenv_whipstudio.egg-info/requires.txt
|
| 14 |
+
openenv_whipstudio.egg-info/top_level.txt
|
| 15 |
+
server/__init__.py
|
| 16 |
+
server/app.py
|
| 17 |
+
server/environment.py
|
| 18 |
+
server/sandbox.py
|
| 19 |
+
server/tasks/__init__.py
|
| 20 |
+
server/tasks/graders.py
|
| 21 |
+
server/tasks/task1_broken_loop.py
|
| 22 |
+
server/tasks/task2_nan_loss.py
|
| 23 |
+
server/tasks/task3_oom_leakage.py
|
| 24 |
+
server/tasks/task4_wrong_loss.py
|
| 25 |
+
server/tasks/task5_frozen_backbone.py
|
openenv_whipstudio.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
openenv_whipstudio.egg-info/entry_points.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[console_scripts]
|
| 2 |
+
server = server.app:main
|
| 3 |
+
whipstudio-server = server.app:main
|
openenv_whipstudio.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv-core[core]>=0.2.1
|
| 2 |
+
fastapi>=0.110.0
|
| 3 |
+
uvicorn>=0.27.0
|
| 4 |
+
pydantic>=2.0.0
|
| 5 |
+
httpx>=0.27.0
|
| 6 |
+
torch>=2.2.0
|
| 7 |
+
smolagents>=1.0.0
|
| 8 |
+
|
| 9 |
+
[dev]
|
| 10 |
+
pytest>=8.0.0
|
openenv_whipstudio.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
server
|
project_status_report_2703.md
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# WhipStudio Project Status Report
|
| 2 |
+
|
| 3 |
+
**Date:** March 27, 2026
|
| 4 |
+
**Project:** WhipStudio — ML Debugging Arena (OpenEnv AI Hackathon)
|
| 5 |
+
**Status:** 🟢 **On Track for Submission**
|
| 6 |
+
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
## Executive Summary
|
| 10 |
+
WhipStudio has successfully transitioned from a conceptual prototype to a fully functional, professional-grade ML Debugging Arena. The core objective of the platform—providing an isolated, standardized environment (Gymnasium) to evaluate and train autonomous AI agents on PyTorch debugging tasks—is now fully operational. The backend conforms exactly to the OpenEnv specification, while the frontend provides a best-in-class spectator and verification experience.
|
| 11 |
+
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
## System Architecture & Technical Specifications
|
| 15 |
+
|
| 16 |
+
### 1. The Core Environment Database (Tasks)
|
| 17 |
+
The platform currently ships with three distinct ML debugging challenges, graded automatically:
|
| 18 |
+
- **Task 1 (Easy): Broken Training Loop.** Fixes a simple linear classifier that has an excessively high learning rate, steps the optimizer out of order, and computes loss incorrectly.
|
| 19 |
+
- **Task 2 (Medium): Silent NaN Loss.** Fixes a CNN training script where the output loss silently turns to `NaN`.
|
| 20 |
+
- **Task 3 (Hard): OOM + Data Leakage.** Fixes a complex PyTorch training loop that creates memory leaks by accumulating computation graphs incorrectly, and introduces data leakage by augmenting before splitting train/val. Graded on a granular scale (0.15 - 0.5 per bug fixed).
|
| 21 |
+
|
| 22 |
+
### 2. The Execution Sandbox API (Backend)
|
| 23 |
+
All agent submissions are processed securely in an isolated environment.
|
| 24 |
+
|
| 25 |
+
**Core Endpoints (OpenEnv Core API):**
|
| 26 |
+
- `POST /reset`: Initialize the environment for a specific task. Returns the `observation` payload, which includes the `buggy_code` string and `task_description`.
|
| 27 |
+
- `POST /step`: Submit a fixed code attempt (the `action`). The system executes the code in the sandbox, parses the `stdout`/`stderr` using regex metrics extractors (e.g., [parse_scalar](file:///home/amogh/Documents/openenv-comp/WhipStudio/server/tasks/graders.py#35-38) for `FINAL_LOSS`), scores the execution, logs the trajectory, and returns the formal RL `reward` (0.0 to 1.0) along with the `done` state.
|
| 28 |
+
|
| 29 |
+
**Security & Isolation ([server/sandbox.py](file:///home/amogh/Documents/openenv-comp/WhipStudio/server/sandbox.py)):**
|
| 30 |
+
- Python execution restricts dangerous operations using `BANNED_PATTERNS` (e.g., `os.system`, `subprocess.`, `open()`, `socket.`, `requests.`).
|
| 31 |
+
- Code runs in `/tmp` natively in a subprocess with a strict 30-second timeout (`TIMEOUT_SECONDS = 30`) and output sizing limits (`MAX_OUTPUT_BYTES = 8192`).
|
| 32 |
+
- Environment dependencies (`PYTHONPATH`) are strictly controlled, ensuring library access (like `torch`) without host system compromise.
|
| 33 |
+
|
| 34 |
+
### 3. AI Agent Integrations
|
| 35 |
+
- `GET /baseline/health`: Checks if the default baseline model (Qwen-32B) is accessible and authenticated via the HuggingFace API using the `.env` configuration.
|
| 36 |
+
- `GET /baseline`: Instructs the backend to run the internal autonomous zero-shot agent across all three tasks simultaneously.
|
| 37 |
+
- `GET /baseline/task/{task_id}`: Streams execution of the baseline agent for a specific task (120s timeout). Used extensively by the Gradio UI for real-time visualization.
|
| 38 |
+
|
| 39 |
+
### 4. User Interfaces (Frontend)
|
| 40 |
+
|
| 41 |
+
WhipStudio provides dual interfaces for varying stakeholder needs:
|
| 42 |
+
|
| 43 |
+
**A. The "ML Debug Arena" Verification Portal (Gradio App)**
|
| 44 |
+
A custom-built, interactive dashboard designed for human verification and hackathon judges.
|
| 45 |
+
- **Code Editor:** Monaco-style interactive Python editor pre-filled with the buggy code payload.
|
| 46 |
+
- **Live Loss Visualization:** Integrates `matplotlib` to plot `LOSS` curves and `VAL_ACCURACY` step-by-step from the agent's debugged submission.
|
| 47 |
+
- **Live Diff Context:** A source-control style diff view showing exactly what code lines the agent inserted or deleted to achieve the fix.
|
| 48 |
+
- **Autonomous Streaming Mode:** The "Auto-Agent" trigger connects to the `/baseline/task/{task_id}` streams to provide a live "Chain-of-Thought" experience, displaying progress and rewards directly to the UI panel in real-time.
|
| 49 |
+
|
| 50 |
+
**B. The OpenEnv Compliance Web UI (`/web`)**
|
| 51 |
+
- Directly embedded via the OpenEnv python framework (`os.environ["ENABLE_WEB_INTERFACE"] = "true"`).
|
| 52 |
+
- Accessed via `http://[host]:8000/web`, this provides the raw, standardized OpenEnv chat-style UI validating 100% adherence to the hackathon's core platform requirements.
|
| 53 |
+
|
| 54 |
+
---
|
| 55 |
+
|
| 56 |
+
## Completed Milestones
|
| 57 |
+
|
| 58 |
+
1. **Bug Remediation:** Fixed critical parsing logic failures where double-escaped regex patterns (`[\\\\d\\\\.\\\\-]+`) prevented the extraction of metrics like `FINAL_LOSS:0.4523`. Also resolved the `ModuleNotFoundError: torch` in the sandbox.
|
| 59 |
+
2. **Environment Token Logic:** Re-wrote authentication management using `python-dotenv` for local caching overrides, ensuring 401 Unauthorized API failures do not disrupt model inference when tokens expire mid-session.
|
| 60 |
+
3. **Task Trajectory Generation:** Enabled persistent trajectory state logging inside `environment.py` for advanced analytics and training extraction.
|
| 61 |
+
|
| 62 |
+
---
|
| 63 |
+
|
| 64 |
+
## Strategic Next Steps
|
| 65 |
+
|
| 66 |
+
**Phase 2: Reinforcement Learning Integration**
|
| 67 |
+
The zero-shot inference pipeline currently scores dynamically. The immediate next phase leverages the Trajectory Logic:
|
| 68 |
+
1. Export the stored trajectory JSON logs containing `buggy_code`, `reward`, and `error_log`.
|
| 69 |
+
2. Integrate the **TRL (Transformer Reinforcement Learning) library**.
|
| 70 |
+
3. Implement Group Relative Policy Optimization (GRPO) to iteratively train a specific lightweight model (like a 7B parameter local model) to perform significantly better on PyTorch debugging tasks than larger, generalized zero-shot models.
|
pyproject.toml
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=45", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "openenv-whipstudio"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "ML Debug environment for OpenEnv"
|
| 9 |
+
requires-python = ">=3.11"
|
| 10 |
+
dependencies = [
|
| 11 |
+
"openenv-core[core]>=0.2.1",
|
| 12 |
+
"fastapi>=0.110.0",
|
| 13 |
+
"uvicorn>=0.27.0",
|
| 14 |
+
"pydantic>=2.0.0",
|
| 15 |
+
"httpx>=0.27.0",
|
| 16 |
+
"torch>=2.2.0",
|
| 17 |
+
"smolagents>=1.0.0",
|
| 18 |
+
]
|
| 19 |
+
|
| 20 |
+
[project.optional-dependencies]
|
| 21 |
+
dev = [
|
| 22 |
+
"pytest>=8.0.0",
|
| 23 |
+
]
|
| 24 |
+
|
| 25 |
+
[project.scripts]
|
| 26 |
+
whipstudio-server = "server.app:main"
|
| 27 |
+
server = "server.app:main"
|
| 28 |
+
|
| 29 |
+
[tool.setuptools]
|
| 30 |
+
include-package-data = true
|
| 31 |
+
packages = ["server", "server.tasks"]
|
server/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""ML debug server package."""
|
server/app.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
|
| 5 |
+
from dotenv import load_dotenv
|
| 6 |
+
load_dotenv(override=True)
|
| 7 |
+
|
| 8 |
+
import httpx
|
| 9 |
+
from fastapi import FastAPI, Request
|
| 10 |
+
from fastapi.responses import RedirectResponse
|
| 11 |
+
from fastapi.responses import HTMLResponse
|
| 12 |
+
|
| 13 |
+
from openenv.core.env_server.http_server import create_app
|
| 14 |
+
|
| 15 |
+
_project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
| 16 |
+
if _project_root not in sys.path:
|
| 17 |
+
sys.path.insert(0, _project_root)
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
from ..models import MLDebugAction, MLDebugObservation
|
| 21 |
+
from .environment import MLDebugEnvironment
|
| 22 |
+
from .tasks.graders import RunResult, score_task
|
| 23 |
+
except ImportError:
|
| 24 |
+
from models import MLDebugAction, MLDebugObservation
|
| 25 |
+
from server.environment import MLDebugEnvironment
|
| 26 |
+
from server.tasks.graders import RunResult, score_task
|
| 27 |
+
|
| 28 |
+
# Disable OpenEnv's default web UI so /web can mirror the custom Gradio UI.
|
| 29 |
+
os.environ["ENABLE_WEB_INTERFACE"] = "false"
|
| 30 |
+
|
| 31 |
+
app: FastAPI = create_app(
|
| 32 |
+
MLDebugEnvironment,
|
| 33 |
+
MLDebugAction,
|
| 34 |
+
MLDebugObservation,
|
| 35 |
+
env_name="whipstudio-env",
|
| 36 |
+
max_concurrent_envs=4,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@app.get("/__build", include_in_schema=False)
|
| 41 |
+
def build_info():
|
| 42 |
+
"""Build/runtime fingerprint to confirm what code is deployed."""
|
| 43 |
+
import platform
|
| 44 |
+
|
| 45 |
+
return {
|
| 46 |
+
"env_name": "whipstudio-env",
|
| 47 |
+
"python": platform.python_version(),
|
| 48 |
+
"platform": platform.platform(),
|
| 49 |
+
"port": os.environ.get("PORT"),
|
| 50 |
+
"enable_web_interface": os.environ.get("ENABLE_WEB_INTERFACE"),
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _has_route(path: str, method: str) -> bool:
|
| 55 |
+
method = method.upper()
|
| 56 |
+
for route in app.router.routes:
|
| 57 |
+
if getattr(route, "path", None) != path:
|
| 58 |
+
continue
|
| 59 |
+
methods = getattr(route, "methods", None)
|
| 60 |
+
if methods and method in methods:
|
| 61 |
+
return True
|
| 62 |
+
return False
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@app.get("/", include_in_schema=False)
|
| 66 |
+
def root_redirect():
|
| 67 |
+
return RedirectResponse(url="/ui", status_code=307)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
if not _has_route("/health", "GET"):
|
| 71 |
+
|
| 72 |
+
@app.get("/health", include_in_schema=False)
|
| 73 |
+
def health_get():
|
| 74 |
+
return {"status": "ok"}
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
if not _has_route("/health", "POST"):
|
| 78 |
+
|
| 79 |
+
@app.post("/health", include_in_schema=False)
|
| 80 |
+
def health_post():
|
| 81 |
+
return {"status": "ok"}
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
@app.get("/reset")
|
| 85 |
+
def reset_liveness():
|
| 86 |
+
return {"status": "ok", "message": "use POST /reset to start an episode"}
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
@app.get("/tasks")
|
| 90 |
+
def list_tasks():
|
| 91 |
+
return {
|
| 92 |
+
"tasks": [
|
| 93 |
+
{"id": "task1", "name": "Broken training loop", "difficulty": "easy"},
|
| 94 |
+
{"id": "task2", "name": "Silent NaN loss", "difficulty": "medium"},
|
| 95 |
+
{"id": "task3", "name": "OOM and data leakage", "difficulty": "hard"},
|
| 96 |
+
{"id": "task4", "name": "Wrong loss function", "difficulty": "medium"},
|
| 97 |
+
{"id": "task5", "name": "Frozen backbone", "difficulty": "medium"},
|
| 98 |
+
],
|
| 99 |
+
"action_schema": {
|
| 100 |
+
"fixed_code": "string (required) — complete runnable Python script",
|
| 101 |
+
"explanation": "string (optional) — description of bugs found",
|
| 102 |
+
"attempt_number": "int 1-3 (optional) — which attempt this is",
|
| 103 |
+
},
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
@app.post("/grader")
|
| 108 |
+
def run_grader(payload: dict):
|
| 109 |
+
task_id = payload.get("task_id", "task1")
|
| 110 |
+
result = RunResult(
|
| 111 |
+
exit_code=payload.get("exit_code", -1),
|
| 112 |
+
stdout=payload.get("stdout", ""),
|
| 113 |
+
stderr=payload.get("stderr", ""),
|
| 114 |
+
elapsed_seconds=payload.get("elapsed", 0.0),
|
| 115 |
+
timed_out=payload.get("timed_out", False),
|
| 116 |
+
fixed_code=payload.get("fixed_code", ""),
|
| 117 |
+
)
|
| 118 |
+
score, breakdown = score_task(task_id, result)
|
| 119 |
+
return {"task_id": task_id, "score": score, "breakdown": breakdown}
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
@app.get("/baseline")
|
| 123 |
+
async def run_baseline(request: Request):
|
| 124 |
+
try:
|
| 125 |
+
from ..baseline_agent import run_single_task
|
| 126 |
+
except ImportError:
|
| 127 |
+
from baseline_agent import run_single_task
|
| 128 |
+
|
| 129 |
+
env_url = str(request.base_url).rstrip("/")
|
| 130 |
+
results = {}
|
| 131 |
+
task_scores = {}
|
| 132 |
+
for task_id in ["task1", "task2", "task3", "task4", "task5"]:
|
| 133 |
+
try:
|
| 134 |
+
score = await asyncio.wait_for(run_single_task(task_id, env_url), timeout=120.0)
|
| 135 |
+
results[task_id] = round(score, 4)
|
| 136 |
+
task_scores[task_id] = round(score, 4)
|
| 137 |
+
except TimeoutError:
|
| 138 |
+
results[task_id] = 0.0
|
| 139 |
+
task_scores[task_id] = 0.0
|
| 140 |
+
results[f"{task_id}_error"] = "timeout: task took longer than 120s"
|
| 141 |
+
except httpx.HTTPError as exc:
|
| 142 |
+
results[task_id] = 0.0
|
| 143 |
+
task_scores[task_id] = 0.0
|
| 144 |
+
results[f"{task_id}_error"] = f"http_error: {exc.__class__.__name__}: {exc}"
|
| 145 |
+
except Exception as exc:
|
| 146 |
+
results[task_id] = 0.0
|
| 147 |
+
task_scores[task_id] = 0.0
|
| 148 |
+
results[f"{task_id}_error"] = f"internal_error: {exc.__class__.__name__}: {exc}"
|
| 149 |
+
avg = round(sum(task_scores.values()) / max(1, len(task_scores)), 4)
|
| 150 |
+
return {"baseline_scores": results, "average": avg, "env_url": env_url}
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
@app.get("/baseline/task/{task_id}")
|
| 154 |
+
async def run_baseline_single(task_id: str, request: Request):
|
| 155 |
+
"""Run the baseline agent on a single task. Returns score + details."""
|
| 156 |
+
try:
|
| 157 |
+
from ..baseline_agent import run_single_task_detailed
|
| 158 |
+
except ImportError:
|
| 159 |
+
from baseline_agent import run_single_task_detailed
|
| 160 |
+
|
| 161 |
+
env_url = str(request.base_url).rstrip("/")
|
| 162 |
+
try:
|
| 163 |
+
result = await asyncio.wait_for(run_single_task_detailed(task_id, env_url), timeout=120.0)
|
| 164 |
+
return {
|
| 165 |
+
"task_id": task_id,
|
| 166 |
+
"score": round(result["score"], 4),
|
| 167 |
+
"status": "ok",
|
| 168 |
+
"fixed_code": result.get("fixed_code", ""),
|
| 169 |
+
"output": result.get("output", ""),
|
| 170 |
+
"attempts": result.get("attempts", []),
|
| 171 |
+
}
|
| 172 |
+
except TimeoutError:
|
| 173 |
+
return {"task_id": task_id, "score": 0.0, "status": "timeout", "error": "Task took longer than 120s"}
|
| 174 |
+
except Exception as exc:
|
| 175 |
+
return {"task_id": task_id, "score": 0.0, "status": "error", "error": f"{exc.__class__.__name__}: {exc}"}
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
@app.get("/baseline/health")
|
| 179 |
+
def baseline_health():
|
| 180 |
+
hf_token_present = bool(os.environ.get("HF_TOKEN"))
|
| 181 |
+
model_ready = False
|
| 182 |
+
model_error = None
|
| 183 |
+
|
| 184 |
+
try:
|
| 185 |
+
try:
|
| 186 |
+
from ..baseline_agent import get_model
|
| 187 |
+
except ImportError:
|
| 188 |
+
from baseline_agent import get_model
|
| 189 |
+
|
| 190 |
+
get_model()
|
| 191 |
+
model_ready = True
|
| 192 |
+
except Exception as exc:
|
| 193 |
+
model_error = f"{exc.__class__.__name__}: {exc}"
|
| 194 |
+
|
| 195 |
+
status = "ok" if hf_token_present and model_ready else "degraded"
|
| 196 |
+
return {
|
| 197 |
+
"status": status,
|
| 198 |
+
"hf_token_present": hf_token_present,
|
| 199 |
+
"model_ready": model_ready,
|
| 200 |
+
"model_error": model_error,
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
_ui_mounted = False
|
| 205 |
+
try:
|
| 206 |
+
import gradio as gr
|
| 207 |
+
try:
|
| 208 |
+
from ..gradio_app import build_ui
|
| 209 |
+
except ImportError:
|
| 210 |
+
from gradio_app import build_ui
|
| 211 |
+
|
| 212 |
+
gradio_ui = build_ui()
|
| 213 |
+
app = gr.mount_gradio_app(app, gradio_ui, path="/ui")
|
| 214 |
+
_ui_mounted = True
|
| 215 |
+
except Exception as e:
|
| 216 |
+
# Don't fail silently in Spaces: return a helpful error page at /ui.
|
| 217 |
+
import traceback
|
| 218 |
+
|
| 219 |
+
print(f"Failed to mount Gradio UI: {e}")
|
| 220 |
+
traceback.print_exc()
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
if not _ui_mounted:
|
| 224 |
+
@app.get("/ui", include_in_schema=False)
|
| 225 |
+
@app.get("/ui/", include_in_schema=False)
|
| 226 |
+
def ui_mount_failed():
|
| 227 |
+
return HTMLResponse(
|
| 228 |
+
"<h2>WhipStudio UI failed to start</h2>"
|
| 229 |
+
"<p>The API server is running, but the Gradio UI could not be mounted.</p>"
|
| 230 |
+
"<p>Check container logs for <code>Failed to mount Gradio UI</code>.</p>",
|
| 231 |
+
status_code=500,
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
@app.api_route("/web", methods=["GET", "POST"], include_in_schema=False)
|
| 236 |
+
def web_redirect_root():
|
| 237 |
+
return RedirectResponse(url="/ui", status_code=307)
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
@app.api_route("/web/{path:path}", methods=["GET", "POST"], include_in_schema=False)
|
| 241 |
+
def web_redirect_path(path: str):
|
| 242 |
+
if path:
|
| 243 |
+
return RedirectResponse(url=f"/ui/{path}", status_code=307)
|
| 244 |
+
return RedirectResponse(url="/ui", status_code=307)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def main(host: str = "0.0.0.0", port: int = 7860):
|
| 248 |
+
import uvicorn
|
| 249 |
+
|
| 250 |
+
uvicorn.run("server.app:app", host=host, port=port, reload=False)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
if __name__ == "__main__":
|
| 254 |
+
main()
|
server/environment.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import math
|
| 3 |
+
import time
|
| 4 |
+
from uuid import uuid4
|
| 5 |
+
|
| 6 |
+
from openenv.core.env_server.interfaces import Environment
|
| 7 |
+
from openenv.core.env_server.types import State
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
from ..models import MLDebugAction, MLDebugObservation
|
| 11 |
+
from .sandbox import execute_code
|
| 12 |
+
from .tasks import task1_broken_loop, task2_nan_loss, task3_oom_leakage, task4_wrong_loss, task5_frozen_backbone
|
| 13 |
+
from .tasks.graders import parse_losses, parse_val_accs, score_task
|
| 14 |
+
except ImportError:
|
| 15 |
+
from models import MLDebugAction, MLDebugObservation
|
| 16 |
+
from server.sandbox import execute_code
|
| 17 |
+
from server.tasks import task1_broken_loop, task2_nan_loss, task3_oom_leakage, task4_wrong_loss, task5_frozen_backbone
|
| 18 |
+
from server.tasks.graders import parse_losses, parse_val_accs, score_task
|
| 19 |
+
|
| 20 |
+
TASKS = {
|
| 21 |
+
"task1": task1_broken_loop,
|
| 22 |
+
"task2": task2_nan_loss,
|
| 23 |
+
"task3": task3_oom_leakage,
|
| 24 |
+
"task4": task4_wrong_loss,
|
| 25 |
+
"task5": task5_frozen_backbone,
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class MLDebugEnvironment(Environment):
|
| 30 |
+
SUPPORTS_CONCURRENT_SESSIONS: bool = True
|
| 31 |
+
|
| 32 |
+
def __init__(self):
|
| 33 |
+
self._state = State(episode_id=str(uuid4()), step_count=0)
|
| 34 |
+
self._task_id = "task1"
|
| 35 |
+
self._best_reward = 0.0
|
| 36 |
+
self._trajectory: list[dict] = []
|
| 37 |
+
|
| 38 |
+
def reset(self, task_id: str = "task1", **kwargs) -> MLDebugObservation: # type: ignore[override]
|
| 39 |
+
if task_id not in TASKS:
|
| 40 |
+
task_id = "task1"
|
| 41 |
+
|
| 42 |
+
self._state = State(episode_id=str(uuid4()), step_count=0)
|
| 43 |
+
self._task_id = task_id
|
| 44 |
+
self._best_reward = 0.0
|
| 45 |
+
self._trajectory = []
|
| 46 |
+
|
| 47 |
+
task = TASKS[self._task_id]
|
| 48 |
+
return MLDebugObservation(
|
| 49 |
+
task_id=self._task_id,
|
| 50 |
+
task_description=task.TASK_DESCRIPTION.strip(),
|
| 51 |
+
buggy_code=task.BUGGY_CODE.strip(),
|
| 52 |
+
error_log="",
|
| 53 |
+
last_reward=0.0,
|
| 54 |
+
metrics={},
|
| 55 |
+
done=False,
|
| 56 |
+
reward=0.0,
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
def step(self, action: MLDebugAction) -> MLDebugObservation: # type: ignore[override]
|
| 60 |
+
self._state.step_count += 1
|
| 61 |
+
|
| 62 |
+
if not action.fixed_code or not action.fixed_code.strip():
|
| 63 |
+
done = True
|
| 64 |
+
metrics = {
|
| 65 |
+
"exit_code": -1,
|
| 66 |
+
"elapsed_seconds": 0.0,
|
| 67 |
+
"timed_out": False,
|
| 68 |
+
"step": self._state.step_count,
|
| 69 |
+
"best_reward_so_far": self._best_reward,
|
| 70 |
+
"error": "empty code submitted",
|
| 71 |
+
}
|
| 72 |
+
task = TASKS[self._task_id]
|
| 73 |
+
return MLDebugObservation(
|
| 74 |
+
task_id=self._task_id,
|
| 75 |
+
task_description=task.TASK_DESCRIPTION.strip(),
|
| 76 |
+
buggy_code=task.BUGGY_CODE.strip(),
|
| 77 |
+
error_log="empty code submitted",
|
| 78 |
+
last_reward=0.0,
|
| 79 |
+
metrics=metrics,
|
| 80 |
+
done=done,
|
| 81 |
+
reward=0.0,
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
run_result1 = execute_code(action.fixed_code)
|
| 85 |
+
reward1, breakdown1 = score_task(self._task_id, run_result1)
|
| 86 |
+
|
| 87 |
+
consistency_flag = False
|
| 88 |
+
reward_variance = 0.0
|
| 89 |
+
final_reward = reward1
|
| 90 |
+
final_breakdown = breakdown1
|
| 91 |
+
run_result = run_result1
|
| 92 |
+
|
| 93 |
+
if reward1 > 0.5:
|
| 94 |
+
run_result2 = execute_code(action.fixed_code)
|
| 95 |
+
reward2, breakdown2 = score_task(self._task_id, run_result2)
|
| 96 |
+
reward_variance = abs(reward1 - reward2)
|
| 97 |
+
if reward_variance > 0.15:
|
| 98 |
+
consistency_flag = True
|
| 99 |
+
final_reward = min(reward1, reward2)
|
| 100 |
+
if reward2 < reward1:
|
| 101 |
+
final_breakdown = breakdown2
|
| 102 |
+
run_result = run_result2
|
| 103 |
+
else:
|
| 104 |
+
consistency_flag = False
|
| 105 |
+
final_reward = (reward1 + reward2) / 2.0
|
| 106 |
+
else:
|
| 107 |
+
consistency_flag = False
|
| 108 |
+
final_reward = reward1
|
| 109 |
+
|
| 110 |
+
self._best_reward = max(self._best_reward, final_reward)
|
| 111 |
+
done = self._state.step_count >= 3 or final_reward >= 0.95
|
| 112 |
+
|
| 113 |
+
losses = parse_losses(run_result.stdout)
|
| 114 |
+
val_accs = parse_val_accs(run_result.stdout)
|
| 115 |
+
final_loss = None
|
| 116 |
+
if losses:
|
| 117 |
+
final_loss = losses[-1]
|
| 118 |
+
else:
|
| 119 |
+
match = re.search(r"FINAL_LOSS:([-\d.]+)", run_result.stdout)
|
| 120 |
+
if match:
|
| 121 |
+
final_loss = float(match.group(1))
|
| 122 |
+
|
| 123 |
+
metrics = {
|
| 124 |
+
"exit_code": run_result.exit_code,
|
| 125 |
+
"elapsed_seconds": run_result.elapsed_seconds,
|
| 126 |
+
"timed_out": run_result.timed_out,
|
| 127 |
+
"step": self._state.step_count,
|
| 128 |
+
"best_reward_so_far": self._best_reward,
|
| 129 |
+
"final_loss": final_loss,
|
| 130 |
+
"nan_count": sum(1 for x in losses if math.isnan(x) or math.isinf(x)) if losses else 0,
|
| 131 |
+
"val_acc": val_accs[-1] if val_accs else None,
|
| 132 |
+
"consistency_flag": consistency_flag,
|
| 133 |
+
"reward_variance": round(reward_variance, 4),
|
| 134 |
+
"reward_breakdown": final_breakdown,
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
task = TASKS[self._task_id]
|
| 138 |
+
|
| 139 |
+
self._trajectory.append({
|
| 140 |
+
"step": self._state.step_count,
|
| 141 |
+
"reward": final_reward,
|
| 142 |
+
"best_reward": self._best_reward,
|
| 143 |
+
"metrics": metrics,
|
| 144 |
+
"done": done,
|
| 145 |
+
"timestamp": time.time(),
|
| 146 |
+
})
|
| 147 |
+
|
| 148 |
+
return MLDebugObservation(
|
| 149 |
+
task_id=self._task_id,
|
| 150 |
+
task_description=task.TASK_DESCRIPTION.strip(),
|
| 151 |
+
buggy_code=task.BUGGY_CODE.strip(),
|
| 152 |
+
error_log=(run_result.stdout + "\n" + run_result.stderr).strip()[:2000],
|
| 153 |
+
last_reward=final_reward,
|
| 154 |
+
metrics=metrics,
|
| 155 |
+
done=done,
|
| 156 |
+
reward=final_reward,
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
@property
|
| 160 |
+
def trajectory(self) -> list[dict]:
|
| 161 |
+
return list(self._trajectory)
|
| 162 |
+
|
| 163 |
+
@property
|
| 164 |
+
def state(self) -> State:
|
| 165 |
+
return self._state
|
server/requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv-core>=0.1.1
|
| 2 |
+
fastapi>=0.110.0
|
| 3 |
+
uvicorn>=0.27.0
|
| 4 |
+
httpx>=0.27.0
|
| 5 |
+
smolagents>=1.0.0
|
| 6 |
+
pydantic>=2.0.0
|
| 7 |
+
python-dotenv>=1.0.0
|
| 8 |
+
gradio>=4.0.0
|
| 9 |
+
tqdm>=4.0.0
|
| 10 |
+
scikit-learn
|
| 11 |
+
matplotlib
|
server/sandbox.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import subprocess
|
| 3 |
+
import sys
|
| 4 |
+
import tempfile
|
| 5 |
+
import time
|
| 6 |
+
|
| 7 |
+
from .tasks.graders import RunResult
|
| 8 |
+
|
| 9 |
+
TIMEOUT_SECONDS = 30
|
| 10 |
+
MAX_OUTPUT_BYTES = 8192
|
| 11 |
+
BANNED_PATTERNS = [
|
| 12 |
+
"os.system",
|
| 13 |
+
"subprocess.",
|
| 14 |
+
"shutil.rmtree",
|
| 15 |
+
"open(",
|
| 16 |
+
"__import__",
|
| 17 |
+
"exec(",
|
| 18 |
+
"socket.",
|
| 19 |
+
"urllib.",
|
| 20 |
+
"requests.",
|
| 21 |
+
]
|
| 22 |
+
SAFE_ENV = {
|
| 23 |
+
"PATH": os.environ.get("PATH", "/usr/bin:/usr/local/bin"),
|
| 24 |
+
"HOME": "/tmp",
|
| 25 |
+
"PYTHONPATH": os.pathsep.join(sys.path),
|
| 26 |
+
"PYTHONDONTWRITEBYTECODE": "1",
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def strip_markdown_code(code: str) -> str:
|
| 31 |
+
if "```python" in code:
|
| 32 |
+
return code.split("```python", 1)[1].split("```", 1)[0].strip()
|
| 33 |
+
if "```" in code:
|
| 34 |
+
return code.split("```", 1)[1].split("```", 1)[0].strip()
|
| 35 |
+
return code.strip()
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def execute_code(code: str) -> RunResult:
|
| 39 |
+
"""Execute agent-submitted code in an isolated subprocess."""
|
| 40 |
+
cleaned_code = strip_markdown_code(code)
|
| 41 |
+
|
| 42 |
+
for pattern in BANNED_PATTERNS:
|
| 43 |
+
if pattern in cleaned_code:
|
| 44 |
+
return RunResult(
|
| 45 |
+
exit_code=-1,
|
| 46 |
+
stdout="",
|
| 47 |
+
stderr=f'Execution blocked: banned pattern "{pattern}" detected.',
|
| 48 |
+
elapsed_seconds=0.0,
|
| 49 |
+
timed_out=False,
|
| 50 |
+
fixed_code=cleaned_code,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, dir="/tmp") as temp_file:
|
| 54 |
+
temp_file.write(cleaned_code)
|
| 55 |
+
tmp_path = temp_file.name
|
| 56 |
+
|
| 57 |
+
start = time.time()
|
| 58 |
+
try:
|
| 59 |
+
proc = subprocess.run(
|
| 60 |
+
["python", tmp_path],
|
| 61 |
+
capture_output=True,
|
| 62 |
+
text=True,
|
| 63 |
+
timeout=TIMEOUT_SECONDS,
|
| 64 |
+
env=SAFE_ENV,
|
| 65 |
+
cwd="/tmp",
|
| 66 |
+
)
|
| 67 |
+
return RunResult(
|
| 68 |
+
exit_code=proc.returncode,
|
| 69 |
+
stdout=proc.stdout[:MAX_OUTPUT_BYTES],
|
| 70 |
+
stderr=proc.stderr[:2048],
|
| 71 |
+
elapsed_seconds=round(time.time() - start, 2),
|
| 72 |
+
timed_out=False,
|
| 73 |
+
fixed_code=cleaned_code,
|
| 74 |
+
)
|
| 75 |
+
except subprocess.TimeoutExpired:
|
| 76 |
+
return RunResult(
|
| 77 |
+
exit_code=-1,
|
| 78 |
+
stdout="",
|
| 79 |
+
stderr="Execution timed out after 30 seconds.",
|
| 80 |
+
elapsed_seconds=TIMEOUT_SECONDS,
|
| 81 |
+
timed_out=True,
|
| 82 |
+
fixed_code=cleaned_code,
|
| 83 |
+
)
|
| 84 |
+
finally:
|
| 85 |
+
try:
|
| 86 |
+
os.unlink(tmp_path)
|
| 87 |
+
except Exception:
|
| 88 |
+
pass
|
server/start.sh
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
set -e
|
| 3 |
+
|
| 4 |
+
export ENABLE_WEB_INTERFACE=false
|
| 5 |
+
|
| 6 |
+
# Single process: FastAPI serves both API and mounted Gradio at /ui
|
| 7 |
+
uvicorn server.app:app --host 0.0.0.0 --port "${PORT:-7860}"
|
server/tasks/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Task definitions for ML debug environment."""
|
server/tasks/graders.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import re
|
| 3 |
+
import ast
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass
|
| 8 |
+
class RunResult:
|
| 9 |
+
exit_code: int
|
| 10 |
+
stdout: str
|
| 11 |
+
stderr: str
|
| 12 |
+
elapsed_seconds: float
|
| 13 |
+
timed_out: bool
|
| 14 |
+
fixed_code: str = ""
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def extract_metrics_block(stdout: str) -> str:
|
| 18 |
+
match = re.search(r"##METRICS_START##(.*?)##METRICS_END##", stdout, re.DOTALL)
|
| 19 |
+
if match:
|
| 20 |
+
return match.group(1)
|
| 21 |
+
return stdout
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def parse_losses(stdout: str) -> list[float]:
|
| 25 |
+
stdout = extract_metrics_block(stdout)
|
| 26 |
+
match = re.search(r"LOSSES:\[([^\]]+)\]", stdout)
|
| 27 |
+
if not match:
|
| 28 |
+
return []
|
| 29 |
+
try:
|
| 30 |
+
return [float(x.strip()) for x in match.group(1).split(",")]
|
| 31 |
+
except Exception:
|
| 32 |
+
return []
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def parse_val_accs(stdout: str) -> list[float]:
|
| 36 |
+
stdout = extract_metrics_block(stdout)
|
| 37 |
+
match = re.search(r"VAL_ACCS:\[([^\]]+)\]", stdout)
|
| 38 |
+
if not match:
|
| 39 |
+
return []
|
| 40 |
+
try:
|
| 41 |
+
return [float(x.strip()) for x in match.group(1).split(",")]
|
| 42 |
+
except Exception:
|
| 43 |
+
return []
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def parse_scalar(stdout: str, key: str) -> float | None:
|
| 47 |
+
stdout = extract_metrics_block(stdout)
|
| 48 |
+
match = re.search(rf"{key}:([-\d.]+)", stdout)
|
| 49 |
+
return float(match.group(1)) if match else None
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def is_valid_submission(code: str, stdout: str, exit_code: int) -> tuple[bool, str]:
|
| 53 |
+
if exit_code == 0:
|
| 54 |
+
if "LOSSES:" not in stdout and "FINAL_LOSS:" not in stdout:
|
| 55 |
+
return False, "No valid metrics output detected"
|
| 56 |
+
if "LOSSES:" in stdout:
|
| 57 |
+
losses = parse_losses(stdout)
|
| 58 |
+
if len(losses) < 5:
|
| 59 |
+
return False, "Fewer than 5 loss values parsed"
|
| 60 |
+
try:
|
| 61 |
+
tree = ast.parse(code)
|
| 62 |
+
if not any(isinstance(node, (ast.For, ast.While)) for node in ast.walk(tree)):
|
| 63 |
+
return False, "No ast.For or ast.While node found"
|
| 64 |
+
except Exception:
|
| 65 |
+
pass
|
| 66 |
+
return True, ""
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def sigmoid_reward(value: float, center: float, steepness: float, invert: bool = False) -> float:
|
| 70 |
+
try:
|
| 71 |
+
if invert:
|
| 72 |
+
x = steepness * (value - center)
|
| 73 |
+
else:
|
| 74 |
+
x = steepness * (center - value)
|
| 75 |
+
return round(1.0 / (1.0 + math.exp(-x)), 4)
|
| 76 |
+
except OverflowError:
|
| 77 |
+
return 0.0 if (invert and value > center) or (not invert and value < center) else 1.0
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def grade_task1(result: RunResult) -> tuple[float, dict]:
|
| 81 |
+
valid, reason = is_valid_submission(result.fixed_code, result.stdout, result.exit_code)
|
| 82 |
+
if not valid:
|
| 83 |
+
return 0.0, {"reason": reason}
|
| 84 |
+
|
| 85 |
+
if result.timed_out:
|
| 86 |
+
return 0.05, {"reason": "timed_out"}
|
| 87 |
+
if result.exit_code != 0:
|
| 88 |
+
return 0.0, {"reason": "crash"}
|
| 89 |
+
|
| 90 |
+
losses = parse_losses(result.stdout)
|
| 91 |
+
if not losses:
|
| 92 |
+
return 0.1, {"reason": "no_losses_parsed"}
|
| 93 |
+
if any(math.isnan(loss) or math.isinf(loss) for loss in losses):
|
| 94 |
+
return 0.15, {"reason": "nan_inf_found"}
|
| 95 |
+
|
| 96 |
+
final = losses[-1]
|
| 97 |
+
base_score = sigmoid_reward(final, center=0.75, steepness=3.0, invert=True)
|
| 98 |
+
|
| 99 |
+
bonus = 0.0
|
| 100 |
+
half = len(losses) // 2
|
| 101 |
+
if half > 0:
|
| 102 |
+
first_half = sum(losses[:half]) / half
|
| 103 |
+
second_half = sum(losses[half:]) / len(losses[half:])
|
| 104 |
+
if second_half < 0.85 * first_half:
|
| 105 |
+
bonus = 0.1
|
| 106 |
+
|
| 107 |
+
final_score = min(1.0, base_score + bonus)
|
| 108 |
+
breakdown = {"base_score": base_score, "monotonicity_bonus": bonus}
|
| 109 |
+
return final_score, breakdown
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def grade_task2(result: RunResult) -> tuple[float, dict]:
|
| 113 |
+
valid, reason = is_valid_submission(result.fixed_code, result.stdout, result.exit_code)
|
| 114 |
+
if not valid:
|
| 115 |
+
return 0.0, {"reason": reason}
|
| 116 |
+
|
| 117 |
+
if result.timed_out:
|
| 118 |
+
return 0.05, {"reason": "timed_out"}
|
| 119 |
+
if result.exit_code != 0:
|
| 120 |
+
return 0.0, {"reason": "crash"}
|
| 121 |
+
|
| 122 |
+
losses = parse_losses(result.stdout)
|
| 123 |
+
if not losses or len(losses) < 30:
|
| 124 |
+
return 0.1, {"reason": "too_few_losses"}
|
| 125 |
+
|
| 126 |
+
nan_count = sum(1 for loss in losses if math.isnan(loss) or math.isinf(loss))
|
| 127 |
+
if nan_count == len(losses):
|
| 128 |
+
return 0.0, {"reason": "all_nans"}
|
| 129 |
+
|
| 130 |
+
nan_ratio = nan_count / len(losses)
|
| 131 |
+
finite_losses = [loss for loss in losses if not math.isnan(loss) and not math.isinf(loss)]
|
| 132 |
+
final_finite_loss = finite_losses[-1] if finite_losses else float('inf')
|
| 133 |
+
|
| 134 |
+
convergence_score = sigmoid_reward(final_finite_loss, center=0.5, steepness=4.0, invert=True)
|
| 135 |
+
convergence_score *= (1.0 - nan_ratio)
|
| 136 |
+
|
| 137 |
+
stability_bonus = 0.0
|
| 138 |
+
if len(finite_losses) >= 20:
|
| 139 |
+
tail = finite_losses[-20:]
|
| 140 |
+
mean_tail = sum(tail) / len(tail)
|
| 141 |
+
tail_variance = sum((x - mean_tail) ** 2 for x in tail) / len(tail)
|
| 142 |
+
stability_bonus = sigmoid_reward(tail_variance, center=0.01, steepness=200.0, invert=True) * 0.1
|
| 143 |
+
|
| 144 |
+
final_score = min(1.0, convergence_score + stability_bonus)
|
| 145 |
+
breakdown = {"convergence_score": convergence_score, "nan_penalty": (1.0 - nan_ratio), "stability_bonus": stability_bonus, "nan_ratio": nan_ratio}
|
| 146 |
+
return final_score, breakdown
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def grade_task3(result: RunResult) -> tuple[float, dict]:
|
| 150 |
+
valid, reason = is_valid_submission(result.fixed_code, result.stdout, result.exit_code)
|
| 151 |
+
if not valid:
|
| 152 |
+
return 0.0, {"reason": reason}
|
| 153 |
+
|
| 154 |
+
if result.timed_out:
|
| 155 |
+
return 0.1, {"reason": "timed_out"}
|
| 156 |
+
|
| 157 |
+
if result.exit_code != 0:
|
| 158 |
+
if "out of memory" in result.stderr.lower():
|
| 159 |
+
return 0.1, {"reason": "oom"}
|
| 160 |
+
return 0.0, {"reason": "crash"}
|
| 161 |
+
|
| 162 |
+
val_accs = parse_val_accs(result.stdout)
|
| 163 |
+
final_loss_val = parse_scalar(result.stdout, "FINAL_LOSS")
|
| 164 |
+
|
| 165 |
+
memory_score = 0.0
|
| 166 |
+
if final_loss_val is not None:
|
| 167 |
+
memory_score = sigmoid_reward(final_loss_val, center=50.0, steepness=0.05, invert=True) * 0.5
|
| 168 |
+
|
| 169 |
+
leakage_score = 0.0
|
| 170 |
+
early_acc = 0.0
|
| 171 |
+
final_acc = 0.0
|
| 172 |
+
if val_accs and len(val_accs) >= 2:
|
| 173 |
+
early_acc = sum(val_accs[:2]) / 2.0
|
| 174 |
+
final_acc = val_accs[-1]
|
| 175 |
+
|
| 176 |
+
leak_p1 = sigmoid_reward(early_acc, center=0.75, steepness=20.0, invert=True) * 0.3
|
| 177 |
+
leak_p2 = sigmoid_reward(final_acc, center=0.68, steepness=15.0, invert=False) * 0.7
|
| 178 |
+
leakage_score = (leak_p1 + leak_p2) * 0.5
|
| 179 |
+
|
| 180 |
+
final_score = min(1.0, memory_score + leakage_score)
|
| 181 |
+
breakdown = {"memory_score": memory_score, "leakage_score": leakage_score, "early_acc": early_acc, "final_acc": final_acc}
|
| 182 |
+
return final_score, breakdown
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def grade_task4(result: RunResult) -> tuple[float, dict]:
|
| 186 |
+
valid, reason = is_valid_submission(result.fixed_code, result.stdout, result.exit_code)
|
| 187 |
+
if not valid:
|
| 188 |
+
return 0.0, {"reason": reason}
|
| 189 |
+
|
| 190 |
+
if result.timed_out:
|
| 191 |
+
return 0.1, {"reason": "timed_out"}
|
| 192 |
+
|
| 193 |
+
if result.exit_code != 0:
|
| 194 |
+
return 0.0, {"reason": "crash"}
|
| 195 |
+
|
| 196 |
+
final_loss = parse_scalar(result.stdout, "FINAL_LOSS")
|
| 197 |
+
avg_labels = parse_scalar(result.stdout, "AVG_LABELS")
|
| 198 |
+
f1 = parse_scalar(result.stdout, "F1_SCORE")
|
| 199 |
+
|
| 200 |
+
loss_score = 0.0
|
| 201 |
+
if final_loss is not None:
|
| 202 |
+
loss_score = sigmoid_reward(final_loss, center=0.5, steepness=4.0, invert=True) * 0.3
|
| 203 |
+
|
| 204 |
+
labels_score = 0.0
|
| 205 |
+
if avg_labels is not None:
|
| 206 |
+
labels_score = sigmoid_reward(avg_labels, center=1.0, steepness=5.0, invert=False) * 0.3
|
| 207 |
+
|
| 208 |
+
f1_s = 0.0
|
| 209 |
+
if f1 is not None:
|
| 210 |
+
f1_s = sigmoid_reward(f1, center=0.6, steepness=10.0, invert=False) * 0.4
|
| 211 |
+
|
| 212 |
+
final_score = min(1.0, loss_score + labels_score + f1_s)
|
| 213 |
+
breakdown = {"loss_score": loss_score, "labels_score": labels_score, "f1_score": f1_s}
|
| 214 |
+
return final_score, breakdown
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def grade_task5(result: RunResult) -> tuple[float, dict]:
|
| 218 |
+
valid, reason = is_valid_submission(result.fixed_code, result.stdout, result.exit_code)
|
| 219 |
+
if not valid:
|
| 220 |
+
return 0.0, {"reason": reason}
|
| 221 |
+
|
| 222 |
+
if result.timed_out:
|
| 223 |
+
return 0.1, {"reason": "timed_out"}
|
| 224 |
+
|
| 225 |
+
if result.exit_code != 0:
|
| 226 |
+
return 0.0, {"reason": "crash"}
|
| 227 |
+
|
| 228 |
+
final_loss = parse_scalar(result.stdout, "FINAL_LOSS")
|
| 229 |
+
grad_norm = parse_scalar(result.stdout, "BACKBONE_GRAD_NORM")
|
| 230 |
+
|
| 231 |
+
loss_score = 0.0
|
| 232 |
+
if final_loss is not None:
|
| 233 |
+
loss_score = sigmoid_reward(final_loss, center=2.2, steepness=3.0, invert=True) * 0.5
|
| 234 |
+
|
| 235 |
+
grad_score = 0.0
|
| 236 |
+
if grad_norm is not None:
|
| 237 |
+
grad_score = sigmoid_reward(grad_norm, center=0.001, steepness=1000.0, invert=False) * 0.5
|
| 238 |
+
|
| 239 |
+
final_score = min(1.0, loss_score + grad_score)
|
| 240 |
+
breakdown = {"loss_score": loss_score, "grad_score": grad_score}
|
| 241 |
+
return final_score, breakdown
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def score_task(task_id: str, result: RunResult) -> tuple[float, dict]:
|
| 245 |
+
graders = {
|
| 246 |
+
"task1": grade_task1,
|
| 247 |
+
"task2": grade_task2,
|
| 248 |
+
"task3": grade_task3,
|
| 249 |
+
"task4": grade_task4,
|
| 250 |
+
"task5": grade_task5,
|
| 251 |
+
}
|
| 252 |
+
if task_id not in graders:
|
| 253 |
+
raise ValueError(f"Unknown task_id: {task_id}")
|
| 254 |
+
|
| 255 |
+
score, breakdown = graders[task_id](result)
|
| 256 |
+
return round(max(0.0, min(1.0, score)), 4), breakdown
|
server/tasks/task1_broken_loop.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK_DESCRIPTION = """
|
| 2 |
+
This 2-class linear classifier training loop has bugs preventing convergence.
|
| 3 |
+
Fix it so that after 50 steps the loss is below 0.75 and decreasing.
|
| 4 |
+
Model: nn.Linear(10, 2), dataset: random 2-class, 32 samples/batch.
|
| 5 |
+
Print losses as: LOSSES:[val1, val2, ...]
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
BUGGY_CODE = """
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
torch.manual_seed(0)
|
| 12 |
+
model = nn.Linear(10, 2)
|
| 13 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=10.0) # BUG 1: lr too high
|
| 14 |
+
criterion = nn.CrossEntropyLoss()
|
| 15 |
+
losses = []
|
| 16 |
+
for step in range(50):
|
| 17 |
+
x = torch.randn(32, 10)
|
| 18 |
+
y = torch.randint(0, 2, (32,))
|
| 19 |
+
optimizer.zero_grad()
|
| 20 |
+
logits = model(x)
|
| 21 |
+
loss = criterion(logits, y)
|
| 22 |
+
optimizer.step() # BUG 2: step before backward
|
| 23 |
+
loss.backward() # BUG 3: backward after step
|
| 24 |
+
losses.append(loss.item())
|
| 25 |
+
print('##METRICS_START##')
|
| 26 |
+
print('LOSSES:' + str(losses))
|
| 27 |
+
print('##METRICS_END##')
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
GROUND_TRUTH_BUGS = [
|
| 31 |
+
"optimizer.step() called before loss.backward()",
|
| 32 |
+
"learning rate 10.0 should be ~0.001",
|
| 33 |
+
]
|
server/tasks/task2_nan_loss.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK_DESCRIPTION = """
|
| 2 |
+
This binary regression trainer produces NaN loss around step 15.
|
| 3 |
+
Fix the numerical instability so loss stays finite for all 60 steps
|
| 4 |
+
and the final loss is below 0.5.
|
| 5 |
+
Print losses as: LOSSES:[val1, val2, ...]
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
BUGGY_CODE = """
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
torch.manual_seed(42)
|
| 12 |
+
model = nn.Linear(16, 1)
|
| 13 |
+
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
| 14 |
+
losses = []
|
| 15 |
+
for step in range(60):
|
| 16 |
+
x = torch.randn(64, 16)
|
| 17 |
+
y = torch.rand(64, 1)
|
| 18 |
+
optimizer.zero_grad()
|
| 19 |
+
pred = torch.sigmoid(model(x))
|
| 20 |
+
# BUG: log(pred) can be -inf when pred rounds to 0.0
|
| 21 |
+
loss = -torch.mean(y * torch.log(pred) + (1 - y) * torch.log(1 - pred))
|
| 22 |
+
loss.backward()
|
| 23 |
+
optimizer.step()
|
| 24 |
+
losses.append(loss.item())
|
| 25 |
+
print('##METRICS_START##')
|
| 26 |
+
print('LOSSES:' + str(losses))
|
| 27 |
+
print('##METRICS_END##')
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
GROUND_TRUTH_BUGS = [
|
| 31 |
+
"torch.log(pred) when pred can be 0.0 after sigmoid — use F.binary_cross_entropy or clamp",
|
| 32 |
+
]
|
server/tasks/task3_oom_leakage.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK_DESCRIPTION = """
|
| 2 |
+
This trainer has TWO independent bugs:
|
| 3 |
+
1. A memory leak causing OOM crash before epoch 5 on CPU.
|
| 4 |
+
2. Data leakage inflating validation accuracy.
|
| 5 |
+
Fix both. After 20 epochs: val_acc > 0.70, no OOM, no suspicious early accuracy spike.
|
| 6 |
+
Print as: VAL_ACCS:[v1,v2,...] and FINAL_LOSS:X.XX
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
BUGGY_CODE = """
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
from torch.utils.data import DataLoader, TensorDataset, random_split
|
| 13 |
+
|
| 14 |
+
torch.manual_seed(42)
|
| 15 |
+
X = torch.randn(1000, 20)
|
| 16 |
+
y = (X[:, 0] > 0).float()
|
| 17 |
+
# BUG 1: augmentation before split — val set gets augmented
|
| 18 |
+
X = X + torch.randn_like(X) * 0.1
|
| 19 |
+
train_ds, val_ds = random_split(TensorDataset(X, y), [800, 200])
|
| 20 |
+
model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 1))
|
| 21 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
|
| 22 |
+
criterion = nn.BCEWithLogitsLoss()
|
| 23 |
+
train_losses, val_accs = [], []
|
| 24 |
+
total_loss = torch.tensor(0.0) # BUG 2: keeps computation graph alive
|
| 25 |
+
for epoch in range(20):
|
| 26 |
+
model.train()
|
| 27 |
+
for xb, yb in DataLoader(train_ds, batch_size=32):
|
| 28 |
+
optimizer.zero_grad()
|
| 29 |
+
out = model(xb).squeeze()
|
| 30 |
+
loss = criterion(out, yb)
|
| 31 |
+
loss.backward()
|
| 32 |
+
optimizer.step()
|
| 33 |
+
total_loss = total_loss + loss # BUG 2: graph accumulates
|
| 34 |
+
model.eval()
|
| 35 |
+
with torch.no_grad():
|
| 36 |
+
idx = val_ds.indices
|
| 37 |
+
xv, yv = X[idx], y[idx]
|
| 38 |
+
preds = (torch.sigmoid(model(xv)) > 0.5).float()
|
| 39 |
+
acc = (preds == yv).float().mean().item()
|
| 40 |
+
val_accs.append(round(acc, 4))
|
| 41 |
+
print('##METRICS_START##')
|
| 42 |
+
print('VAL_ACCS:' + str(val_accs))
|
| 43 |
+
print('FINAL_LOSS:' + str(total_loss.item()))
|
| 44 |
+
print('##METRICS_END##')
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
GROUND_TRUTH_BUGS = [
|
| 48 |
+
"Augmentation applied before split — move after split, apply to train only",
|
| 49 |
+
"total_loss += loss retains graph — use total_loss += loss.item()",
|
| 50 |
+
]
|
server/tasks/task4_wrong_loss.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK_DESCRIPTION = """
|
| 2 |
+
This is a multi-label classification problem where each sample can have multiple active classes.
|
| 3 |
+
However, the model is currently using `CrossEntropyLoss`, which is meant for single-label (mutually exclusive) classes.
|
| 4 |
+
Because of this, the loss trains but the predictions are essentially garbage (treating it as a single-label problem).
|
| 5 |
+
|
| 6 |
+
Fix the loss function so it correctly handles multi-label classification.
|
| 7 |
+
The grader will check:
|
| 8 |
+
1. Loss convergence (loss < 0.5)
|
| 9 |
+
2. Model predictions are actually multi-hot (avg > 1 label/sample)
|
| 10 |
+
3. F1 Score > 0.6
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
BUGGY_CODE = """
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
from sklearn.metrics import f1_score
|
| 17 |
+
|
| 18 |
+
torch.manual_seed(42)
|
| 19 |
+
|
| 20 |
+
# Generate synthetic multi-label data (100 samples, 20 features, 5 classes)
|
| 21 |
+
X = torch.randn(100, 20)
|
| 22 |
+
# Each sample has a 30% chance of having each class active
|
| 23 |
+
y = (torch.rand(100, 5) > 0.7).float()
|
| 24 |
+
|
| 25 |
+
model = nn.Sequential(
|
| 26 |
+
nn.Linear(20, 64),
|
| 27 |
+
nn.ReLU(),
|
| 28 |
+
nn.Linear(64, 5)
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
|
| 32 |
+
|
| 33 |
+
# BUG: CrossEntropyLoss is for single-label classification
|
| 34 |
+
criterion = nn.CrossEntropyLoss()
|
| 35 |
+
|
| 36 |
+
losses = []
|
| 37 |
+
for step in range(100):
|
| 38 |
+
optimizer.zero_grad()
|
| 39 |
+
logits = model(X)
|
| 40 |
+
|
| 41 |
+
# CrossEntropyLoss expects class indices, not one-hot/multi-hot vectors for the target
|
| 42 |
+
loss = criterion(logits, y)
|
| 43 |
+
|
| 44 |
+
loss.backward()
|
| 45 |
+
optimizer.step()
|
| 46 |
+
losses.append(loss.item())
|
| 47 |
+
|
| 48 |
+
# Evaluation
|
| 49 |
+
with torch.no_grad():
|
| 50 |
+
logits = model(X)
|
| 51 |
+
# Using sigmoid and 0.5 threshold for multi-label prediction
|
| 52 |
+
preds = (torch.sigmoid(logits) > 0.5).float()
|
| 53 |
+
|
| 54 |
+
avg_labels = preds.sum(dim=1).mean().item()
|
| 55 |
+
f1 = f1_score(y.numpy(), preds.numpy(), average='micro')
|
| 56 |
+
|
| 57 |
+
print('##METRICS_START##')
|
| 58 |
+
print('FINAL_LOSS:' + str(losses[-1]))
|
| 59 |
+
print('AVG_LABELS:' + str(avg_labels))
|
| 60 |
+
print('F1_SCORE:' + str(f1))
|
| 61 |
+
print('##METRICS_END##')
|
| 62 |
+
"""
|
server/tasks/task5_frozen_backbone.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK_DESCRIPTION = """
|
| 2 |
+
This is a standard transfer learning setup classifying 10 categories.
|
| 3 |
+
The developer froze the backbone during testing, but forgot to unfreeze it while still passing its parameters to the optimizer.
|
| 4 |
+
Fix the code so the backbone actually trains, or only pass the head parameters.
|
| 5 |
+
The grader checks the gradient norm of the backbone from the first backward pass.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
BUGGY_CODE = """
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
|
| 12 |
+
torch.manual_seed(42)
|
| 13 |
+
|
| 14 |
+
# Dummy dataset
|
| 15 |
+
X = torch.randn(32, 512)
|
| 16 |
+
y = torch.randint(0, 10, (32,))
|
| 17 |
+
|
| 18 |
+
# A simulated pre-trained backbone
|
| 19 |
+
backbone = nn.Sequential(
|
| 20 |
+
nn.Linear(512, 512),
|
| 21 |
+
nn.ReLU(),
|
| 22 |
+
nn.Linear(512, 512),
|
| 23 |
+
nn.ReLU()
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
# BUG: backbone is frozen, but passed to optimizer
|
| 27 |
+
backbone.requires_grad_(False)
|
| 28 |
+
|
| 29 |
+
head = nn.Linear(512, 10)
|
| 30 |
+
|
| 31 |
+
# passing both backbone and head to optimizer even though backbone is frozen
|
| 32 |
+
optimizer = torch.optim.Adam(
|
| 33 |
+
list(backbone.parameters()) + list(head.parameters()), lr=0.001
|
| 34 |
+
)
|
| 35 |
+
criterion = nn.CrossEntropyLoss()
|
| 36 |
+
|
| 37 |
+
losses = []
|
| 38 |
+
|
| 39 |
+
# Take one step to check gradients
|
| 40 |
+
optimizer.zero_grad()
|
| 41 |
+
features = backbone(X)
|
| 42 |
+
logits = head(features)
|
| 43 |
+
|
| 44 |
+
loss = criterion(logits, y)
|
| 45 |
+
loss.backward()
|
| 46 |
+
|
| 47 |
+
# Calculate gradient norm on backbone to see if it's training
|
| 48 |
+
backbone_grad_norm = sum(
|
| 49 |
+
p.grad.norm().item() for p in backbone.parameters() if p.grad is not None
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
optimizer.step()
|
| 53 |
+
losses.append(loss.item())
|
| 54 |
+
|
| 55 |
+
# Note: if backbone is properly frozen and only head is passed, backbone_grad_norm will be 0 but optimizer won't complain.
|
| 56 |
+
# If backbone is unfrozen, backbone_grad_norm will be > 0.
|
| 57 |
+
# The grader handles both valid solutions.
|
| 58 |
+
print('##METRICS_START##')
|
| 59 |
+
print('FINAL_LOSS:' + str(losses[-1]))
|
| 60 |
+
print('BACKBONE_GRAD_NORM:' + str(backbone_grad_norm))
|
| 61 |
+
print('##METRICS_END##')
|
| 62 |
+
"""
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|