Amogh-kal1 committed on
Commit
72ddcb6
·
verified ·
1 Parent(s): 35a09d1

Upload folder using huggingface_hub

Browse files
Dockerfile ADDED
@@ -0,0 +1,20 @@
1
+ FROM python:3.11-slim
2
+
3
+ WORKDIR /app
4
+
5
+ RUN pip install --no-cache-dir \
6
+     torch==2.2.0+cpu torchvision==0.17.0+cpu \
7
+     --index-url https://download.pytorch.org/whl/cpu
8
+
9
+ RUN apt-get update && apt-get install -y curl && rm -rf /var/lib/apt/lists/*
10
+
11
+ COPY server/requirements.txt .
12
+ RUN pip install --no-cache-dir -r requirements.txt
13
+
14
+ COPY . .
15
+
16
+ EXPOSE 7860
17
+ ENV PYTHONUNBUFFERED=1
18
+ ENV ENABLE_WEB_INTERFACE=false
19
+
20
+ CMD ["python", "-m", "uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,10 +1,52 @@
1
  ---
2
- title: Whipstudio
3
- emoji: 📊
4
- colorFrom: indigo
5
- colorTo: pink
6
  sdk: docker
7
- pinned: false
 
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
+ title: WhipStudio Env
3
+ emoji: 🤖
4
+ colorFrom: blue
5
+ colorTo: green
6
  sdk: docker
7
+ app_port: 7860
8
+ base_path: /ui
9
  ---
10
 
11
+ # ML Debug Environment
12
+
13
+ An OpenEnv-compatible RL environment where agents debug broken PyTorch training scripts.
14
+
15
+ ## Environment Description
16
+ The agent receives a broken Python training script and must return a corrected version.
17
+ Five tasks simulate real ML production bugs at Easy, Medium, and Hard difficulty.
18
+
19
+ ## Action Space
20
+ - fixed_code (str, required): Complete corrected Python script
21
+ - explanation (str, optional): Description of bugs found
22
+ - attempt_number (int, 1-3): Which fix attempt this is
23
+
24
+ ## Observation Space
25
+ - task_id: Which task (task1 through task5)
26
+ - task_description: Plain English instructions
27
+ - buggy_code: The broken script
28
+ - error_log: stdout+stderr from previous attempt
29
+ - last_reward: Score from previous attempt (0.0 on first step)
30
+ - metrics: {exit_code, elapsed_seconds, timed_out, step, best_reward_so_far}
31
+
32
+ ## Reward Function
33
+ Continuous score 0.0–1.0. Partial credit for every improvement.
34
+ See `server/tasks/graders.py` for per-task scoring logic.
35
+
36
+ ## Tasks
37
+ | Task | Difficulty | Bug Type |
38
+ |------|-----------|----------|
39
+ | task1 | Easy | Wrong optimizer order + bad learning rate |
40
+ | task2 | Medium | Silent NaN from log(0) numerical instability |
41
+ | task3 | Hard | OOM memory leak + train/val data leakage |
42
+ | task4 | Medium | Wrong loss function |
43
+ | task5 | Medium | Frozen backbone |
44
+
45
+ ## Setup
46
+ ```bash
47
+ pip install openenv-core
48
+ uvicorn server.app:app --host 0.0.0.0 --port 7860
49
+ ```
50
+
51
+ ## Endpoints
52
+ POST /reset, POST /step, GET /state, GET /tasks, POST /grader, GET /baseline
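The Action Space in the README above can be sketched as a small payload builder for `POST /step`. This is a minimal sketch, not the server's schema: the helper name `build_step_payload` is hypothetical, and wrapping the fields under an `"action"` key is an assumption taken from what `baseline_agent.py` sends on its first try.

```python
from typing import Any, Dict, Optional


def build_step_payload(
    fixed_code: str,
    attempt_number: int = 1,
    explanation: Optional[str] = None,
) -> Dict[str, Any]:
    """Build a JSON body for POST /step (hypothetical helper, not part of the repo)."""
    # fixed_code is the only required field in the README's Action Space.
    if not fixed_code.strip():
        raise ValueError("fixed_code must be a non-empty script")
    # The README documents attempt_number as an int in the range 1-3.
    if not 1 <= attempt_number <= 3:
        raise ValueError("attempt_number must be between 1 and 3")

    action: Dict[str, Any] = {
        "fixed_code": fixed_code,
        "attempt_number": attempt_number,
    }
    # explanation is optional, so it is only sent when provided.
    if explanation is not None:
        action["explanation"] = explanation

    # Assumption: the server accepts the action nested under "action".
    return {"action": action}
```

The returned dictionary can be passed as the `json=` argument of an HTTP client call; if the server rejects the nested form, the inner `action` dictionary alone is the other shape the baseline agent tries.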
__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ """ML Debug OpenEnv package."""
2
+
3
+ from .models import MLDebugAction, MLDebugObservation
4
+ from .client import MLDebugEnv
5
+
6
+ __all__ = ["MLDebugAction", "MLDebugObservation", "MLDebugEnv"]
baseline_agent.py ADDED
@@ -0,0 +1,185 @@
1
+ import asyncio
2
+ import os
3
+
4
+ from dotenv import load_dotenv
5
+ load_dotenv(override=True)
6
+
7
+ import httpx
8
+
9
+ SYSTEM_PROMPT = """
10
+ You are an expert PyTorch debugging agent.
11
+ You receive a broken training script and must fix ALL bugs in it.
12
+ Rules:
13
+ - Return ONLY the complete corrected Python code, nothing else.
14
+ - No markdown, no backticks, no explanation text.
15
+ - The script must print losses in format: LOSSES:[v1, v2, ...]
16
+ - For task3, also print: VAL_ACCS:[v1,...] and FINAL_LOSS:X.XX
17
+ - Keep all torch.manual_seed() calls intact.
18
+ """.strip()
19
+
20
+
21
+ def get_model():
22
+ from smolagents import InferenceClientModel
23
+
24
+ hf_token = os.environ.get("HF_TOKEN")
25
+ if not hf_token:
26
+ raise RuntimeError(
27
+ "HF_TOKEN is not set. Set HF_TOKEN to run /baseline with InferenceClientModel."
28
+ )
29
+
30
+ return InferenceClientModel(
31
+ model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
32
+ token=hf_token,
33
+ )
34
+
35
+
36
+ def _generate_fixed_code(model, prompt: str) -> str:
37
+ def _extract_text(response) -> str:
38
+ if isinstance(response, str):
39
+ return response
40
+
41
+ if hasattr(response, "content"):
42
+ content = getattr(response, "content")
43
+ if isinstance(content, str):
44
+ return content
45
+ if isinstance(content, list):
46
+ chunks = []
47
+ for item in content:
48
+ if isinstance(item, str):
49
+ chunks.append(item)
50
+ elif isinstance(item, dict):
51
+ text = item.get("text") or item.get("content")
52
+ if text:
53
+ chunks.append(str(text))
54
+ if chunks:
55
+ return "\n".join(chunks)
56
+
57
+ if isinstance(response, dict):
58
+ text = response.get("content") or response.get("text")
59
+ if isinstance(text, str):
60
+ return text
61
+
62
+ return str(response)
63
+
64
+ if hasattr(model, "generate"):
65
+ generate = getattr(model, "generate")
66
+ messages = [
67
+ {"role": "system", "content": SYSTEM_PROMPT},
68
+ {"role": "user", "content": prompt},
69
+ ]
70
+ try:
71
+ return _extract_text(generate(messages=messages))
72
+ except TypeError:
73
+ return _extract_text(generate(messages))
74
+
75
+ if callable(model):
76
+ try:
77
+ return _extract_text(model(prompt, system_prompt=SYSTEM_PROMPT))
78
+ except TypeError:
79
+ return _extract_text(model(prompt))
80
+
81
+ raise AttributeError("Model does not support callable() or generate() inference APIs")
82
+
83
+
84
+ async def run_single_task(task_id: str, env_url: str = "http://localhost:7860") -> float:
85
+ """Backwards-compatible wrapper that returns just the score."""
86
+ result = await run_single_task_detailed(task_id, env_url)
87
+ return result["score"]
88
+
89
+
90
+ async def run_single_task_detailed(task_id: str, env_url: str = "http://localhost:7860") -> dict:
91
+ """Run the baseline agent on a single task. Returns detailed results."""
92
+ model = get_model()
93
+ timeout = httpx.Timeout(900.0, connect=10.0)
94
+
95
+ attempts_log = []
96
+
97
+ async with httpx.AsyncClient(timeout=timeout) as client:
98
+ reset_resp = await client.post(f"{env_url}/reset", json={"task_id": task_id})
99
+ reset_resp.raise_for_status()
100
+ obs = reset_resp.json().get("observation", reset_resp.json())
101
+
102
+ best_reward = 0.0
103
+ best_code = ""
104
+ best_output = ""
105
+
106
+ for attempt in range(1, 4):
107
+ prompt = f"""
108
+ Task: {obs.get('task_description', '')}
109
+ Buggy code:
110
+ {obs.get('buggy_code', '')}
111
+ Previous execution output (if any):
112
+ {obs.get('error_log', 'None')}
113
+ Previous score: {obs.get('last_reward', 0.0)}
114
+ """.strip()
115
+
116
+ fixed_code = _generate_fixed_code(model, prompt)
117
+ if "```python" in fixed_code:
118
+ fixed_code = fixed_code.split("```python", 1)[1].split("```", 1)[0].strip()
119
+ elif "```" in fixed_code:
120
+ fixed_code = fixed_code.split("```", 1)[1].split("```", 1)[0].strip()
121
+
122
+ step_payload = {
123
+ "action": {
124
+ "fixed_code": fixed_code,
125
+ "attempt_number": attempt,
126
+ }
127
+ }
128
+ step_resp = await client.post(f"{env_url}/step", json=step_payload)
129
+ if step_resp.status_code == 422:
130
+ step_resp = await client.post(
131
+ f"{env_url}/step",
132
+ json={
133
+ "fixed_code": fixed_code,
134
+ "attempt_number": attempt,
135
+ },
136
+ )
137
+ step_resp.raise_for_status()
138
+ result = step_resp.json()
139
+
140
+ reward = float(result.get("reward", 0.0) or 0.0)
141
+ obs = result.get("observation", obs)
142
+ output_log = obs.get("error_log", "") if isinstance(obs, dict) else ""
143
+
144
+ attempts_log.append({
145
+ "attempt": attempt,
146
+ "code": fixed_code,
147
+ "output": output_log[:3000],
148
+ "reward": reward,
149
+ })
150
+
151
+ if reward > best_reward:
152
+ best_reward = reward
153
+ best_code = fixed_code
154
+ best_output = output_log
155
+
156
+ if result.get("done") or reward >= 0.95:
157
+ break
158
+
159
+ return {
160
+ "score": best_reward,
161
+ "fixed_code": best_code,
162
+ "output": best_output[:3000],
163
+ "attempts": attempts_log,
164
+ }
165
+
166
+
167
+ if __name__ == "__main__":
168
+     import argparse
169
+
170
+     parser = argparse.ArgumentParser()
171
+     parser.add_argument("--env-url", default="http://localhost:7860")
172
+     args = parser.parse_args()
173
+
174
+     async def main():
175
+         scores = {}
176
+         for tid in ["task1", "task2", "task3"]:
177
+             try:
178
+                 s = await asyncio.wait_for(run_single_task(tid, args.env_url), timeout=95.0)
179
+             except asyncio.TimeoutError:
180
+                 s = 0.0
181
+             scores[tid] = round(s, 4)
182
+             print(f"{tid}: {s:.4f}")
183
+         print(f"Average: {sum(scores.values()) / len(scores):.4f}")
184
+
185
+     asyncio.run(main())
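The fence-stripping step inside `run_single_task_detailed` can be isolated for testing. The sketch below mirrors the same split logic from `baseline_agent.py`; the function name `strip_code_fences` and the `FENCE` constant are hypothetical, and the constant exists only so the example does not contain a literal triple backtick.

```python
FENCE = "`" * 3  # a Markdown code fence, built indirectly to keep this example readable


def strip_code_fences(text: str) -> str:
    """Return the code inside the first Markdown fence, or the text unchanged."""
    python_fence = FENCE + "python"
    if python_fence in text:
        # Prefer an explicitly tagged Python block when one is present.
        return text.split(python_fence, 1)[1].split(FENCE, 1)[0].strip()
    if FENCE in text:
        # Otherwise take whatever sits between the first pair of fences.
        return text.split(FENCE, 1)[1].split(FENCE, 1)[0].strip()
    # No fence at all: the model followed the "code only" rule.
    return text


if __name__ == "__main__":
    reply = "Here is the fix:\n" + FENCE + "python\nprint(1)\n" + FENCE + "\nDone."
    print(strip_code_fences(reply))
```

As in the original, only the first fenced block is kept, so a reply that splits the script across several blocks would lose everything after the first one.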
client.py ADDED
@@ -0,0 +1,42 @@
1
+ from typing import Dict
2
+
3
+ from openenv.core import EnvClient
4
+ from openenv.core.client_types import StepResult
5
+ from openenv.core.env_server.types import State
6
+
7
+ from .models import MLDebugAction, MLDebugObservation
8
+
9
+
10
+ class MLDebugEnv(EnvClient[MLDebugAction, MLDebugObservation, State]):
11
+     def _step_payload(self, action: MLDebugAction) -> Dict:
12
+         return {
13
+             "fixed_code": action.fixed_code,
14
+             "explanation": action.explanation,
15
+             "attempt_number": action.attempt_number,
16
+         }
17
+
18
+     def _parse_result(self, payload: Dict) -> StepResult[MLDebugObservation]:
19
+         obs_data = payload.get("observation", {})
20
+         observation = MLDebugObservation(
21
+             task_id=obs_data.get("task_id", "task1"),
22
+             task_description=obs_data.get("task_description", ""),
23
+             buggy_code=obs_data.get("buggy_code", ""),
24
+             error_log=obs_data.get("error_log", ""),
25
+             last_reward=obs_data.get("last_reward", 0.0),
26
+             metrics=obs_data.get("metrics", {}),
27
+             done=payload.get("done", False),
28
+             reward=payload.get("reward"),
29
+             metadata=obs_data.get("metadata", {}),
30
+         )
31
+
32
+         return StepResult(
33
+             observation=observation,
34
+             reward=payload.get("reward"),
35
+             done=payload.get("done", False),
36
+         )
37
+
38
+     def _parse_state(self, payload: Dict) -> State:
39
+         return State(
40
+             episode_id=payload.get("episode_id"),
41
+             step_count=payload.get("step_count", 0),
42
+         )
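The defaulting behaviour of `_parse_result` can be tried without installing `openenv-core`. The sketch below copies a subset of the same `.get(...)` defaults into a plain dataclass; `ParsedStep` and `parse_step_response` are hypothetical stand-ins for `MLDebugObservation` and `StepResult`, not the library's types.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class ParsedStep:
    """Simplified stand-in for the observation plus step result (assumed shape)."""
    task_id: str
    error_log: str
    last_reward: float
    metrics: Dict[str, Any] = field(default_factory=dict)
    reward: Optional[float] = None
    done: bool = False


def parse_step_response(payload: Dict[str, Any]) -> ParsedStep:
    # Observation fields live under "observation"; reward and done sit at the top level.
    obs = payload.get("observation", {})
    return ParsedStep(
        task_id=obs.get("task_id", "task1"),
        error_log=obs.get("error_log", ""),
        last_reward=obs.get("last_reward", 0.0),
        metrics=obs.get("metrics", {}),
        reward=payload.get("reward"),
        done=payload.get("done", False),
    )


if __name__ == "__main__":
    step = parse_step_response({"observation": {"task_id": "task2"}, "reward": 0.5})
    print(step.task_id, step.reward, step.done)
```

One consequence of these defaults is that an empty or malformed payload still parses, reporting `task1` with no reward, so callers that need to detect a bad response should check for the `observation` key themselves.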
gradio_app.py ADDED
@@ -0,0 +1,756 @@
1
+ """
2
+ WhipStudio — ML Debug Arena
3
+ A polished Gradio UI for the ML Debugging RL environment.
4
+ Provides code editing, loss curve visualization, diff views, and episode history.
5
+ """
6
+ import difflib
7
+ import json
8
+ import math
9
+ import re
10
+ from typing import Any
11
+
12
+ import gradio as gr
13
+ import httpx
14
+ import os
15
+
16
+ DEFAULT_BASE_URL = os.environ.get("BASE_URL", "http://localhost:7860")
17
+
18
+ # ── Task metadata ──────────────────────────────────────────────────────────
19
+
20
+ TASK_INFO = {
21
+ "task1": {
22
+ "name": "Broken Training Loop",
23
+ "difficulty": "🟢 Easy",
24
+ "description": "Fix optimizer order + learning rate bugs in a linear classifier.",
25
+ "hints": "Look at optimizer.step() / loss.backward() order and the learning rate.",
26
+ },
27
+ "task2": {
28
+ "name": "Silent NaN Loss",
29
+ "difficulty": "🟡 Medium",
30
+ "description": "Fix numerical instability causing NaN loss from log(0).",
31
+ "hints": "The loss computation uses torch.log() without clamping — pred can be 0.",
32
+ },
33
+ "task3": {
34
+ "name": "OOM + Data Leakage",
35
+ "difficulty": "🔴 Hard",
36
+ "description": "Fix memory leak (graph accumulation) AND train/val data leakage.",
37
+ "hints": "Two bugs: total_loss accumulates graph, and augmentation is applied before split.",
38
+ },
39
+ "task4": {
40
+ "name": "Wrong Loss Function",
41
+ "difficulty": "🟡 Medium",
42
+ "description": "Multi-label classification incorrectly using CrossEntropyLoss. Fix loss and eval.",
43
+ "hints": "Use BCEWithLogitsLoss for multi-label. Ensure predictions are multi-hot.",
44
+ },
45
+ "task5": {
46
+ "name": "Frozen Backbone",
47
+ "difficulty": "🟡 Medium",
48
+ "description": "Backbone frozen but its parameters are passed to the optimizer.",
49
+ "hints": "Unfreeze backbone or only pass head parameters to Adam.",
50
+ },
51
+ }
52
+
53
+ # ── Theme ──────────────────────────────────────────────────────────────────
54
+
55
+ CUSTOM_CSS = """
56
+ /* Dark arena theme */
57
+ .gradio-container { max-width: 1400px !important; }
58
+
59
+ /* Score display */
60
+ .score-high { color: #22c55e !important; font-weight: 700; }
61
+ .score-med { color: #eab308 !important; font-weight: 700; }
62
+ .score-low { color: #ef4444 !important; font-weight: 700; }
63
+
64
+ /* Task badges */
65
+ .task-badge { display: inline-block; padding: 2px 8px; border-radius: 4px; font-size: 0.85em; }
66
+ .badge-easy { background: #16a34a22; color: #22c55e; border: 1px solid #22c55e44; }
67
+ .badge-medium { background: #ca8a0422; color: #eab308; border: 1px solid #eab30844; }
68
+ .badge-hard { background: #dc262622; color: #ef4444; border: 1px solid #ef444444; }
69
+
70
+ /* Diff highlighting */
71
+ .diff-add { color: #22c55e; background: #22c55e11; }
72
+ .diff-del { color: #ef4444; background: #ef444411; }
73
+
74
+ /* Score bar */
75
+ .score-bar-container { position: relative; height: 28px; background: #1e293b; border-radius: 6px; overflow: hidden; margin: 8px 0; }
76
+ .score-bar-fill { height: 100%; border-radius: 6px; transition: width 0.6s ease-in-out; }
77
+ .score-bar-label { position: absolute; right: 8px; top: 4px; font-weight: 600; font-size: 0.9em; }
78
+
79
+ /* Episode step indicators */
80
+ .step-indicator { display: inline-flex; align-items: center; gap: 6px; padding: 4px 12px; border-radius: 16px; font-size: 0.85em; margin: 2px; }
81
+ .step-done { background: #16a34a22; border: 1px solid #22c55e44; color: #22c55e; }
82
+ .step-active { background: #3b82f622; border: 1px solid #3b82f644; color: #60a5fa; }
83
+ .step-pending { background: #334155; border: 1px solid #47556966; color: #94a3b8; }
84
+
85
+ /* Header styling */
86
+ .arena-header { text-align: center; padding: 12px 0; }
87
+ .arena-header h1 { margin: 0; font-size: 1.8em; }
88
+ .arena-header p { margin: 4px 0 0 0; color: #94a3b8; }
89
+ """
90
+
91
+
92
+ # ── API helpers ────────────────────────────────────────────────────────────
93
+
94
+ def _api(base_url: str, method: str, path: str, payload: dict | None = None) -> dict:
95
+ """Call the WhipStudio API and return parsed JSON."""
96
+ base_url = (base_url or DEFAULT_BASE_URL).strip().rstrip("/")
97
+ url = f"{base_url}{path}"
98
+ try:
99
+ with httpx.Client(timeout=90.0) as client:
100
+ if method == "GET":
101
+ resp = client.get(url)
102
+ else:
103
+ resp = client.post(url, json=payload or {})
104
+ resp.raise_for_status()
105
+ return resp.json()
106
+ except Exception as exc:
107
+ return {"error": f"{exc.__class__.__name__}: {exc}"}
108
+
109
+
110
+ def _parse_losses_from_log(log: str) -> list[float]:
111
+ """Extract LOSSES:[...] from stdout."""
112
+ match = re.search(r"LOSSES:\[([^\]]+)\]", log)
113
+ if not match:
114
+ return []
115
+ try:
116
+ return [float(x.strip()) for x in match.group(1).split(",")]
117
+ except Exception:
118
+ return []
119
+
120
+
121
+ def _parse_val_accs_from_log(log: str) -> list[float]:
122
+ """Extract VAL_ACCS:[...] from stdout."""
123
+ match = re.search(r"VAL_ACCS:\[([^\]]+)\]", log)
124
+ if not match:
125
+ return []
126
+ try:
127
+ return [float(x.strip()) for x in match.group(1).split(",")]
128
+ except Exception:
129
+ return []
130
+
131
+
132
+ def _score_color(score: float) -> str:
133
+ if score >= 0.7:
134
+ return "#22c55e"
135
+ if score >= 0.4:
136
+ return "#eab308"
137
+ return "#ef4444"
138
+
139
+
140
+ def _score_html(score: float) -> str:
141
+ pct = int(score * 100)
142
+ color = _score_color(score)
143
+ return f"""
144
+ <div style="text-align:center; margin: 8px 0;">
145
+ <div style="font-size: 2.4em; font-weight: 700; color: {color};">{score:.2f}</div>
146
+ <div style="position:relative; height:24px; background:#1e293b; border-radius:6px; overflow:hidden; margin:8px auto; max-width:280px;">
147
+ <div style="height:100%; width:{pct}%; background:linear-gradient(90deg, {color}88, {color}); border-radius:6px; transition:width 0.6s ease;"></div>
148
+ </div>
149
+ <div style="color:#94a3b8; font-size:0.85em;">{pct}% complete</div>
150
+ </div>"""
151
+
152
+
153
+ def _diff_html(original: str, fixed: str) -> str:
154
+ """Generate an HTML diff view between original and fixed code."""
155
+ orig_lines = original.strip().splitlines()
156
+ fixed_lines = fixed.strip().splitlines()
157
+ diff = difflib.unified_diff(orig_lines, fixed_lines, lineterm="", n=3)
158
+ lines = []
159
+ for line in diff:
160
+ if line.startswith("+++") or line.startswith("---"):
161
+ continue
162
+ if line.startswith("@@"):
163
+ lines.append(f'<div style="color:#60a5fa;margin:8px 0 2px 0;font-size:0.85em;">{line}</div>')
164
+ elif line.startswith("+"):
165
+ lines.append(f'<div style="color:#22c55e;background:#22c55e0d;padding:1px 6px;font-family:monospace;font-size:0.85em;">+ {line[1:]}</div>')
166
+ elif line.startswith("-"):
167
+ lines.append(f'<div style="color:#ef4444;background:#ef44440d;padding:1px 6px;font-family:monospace;font-size:0.85em;">- {line[1:]}</div>')
168
+ else:
169
+ lines.append(f'<div style="color:#94a3b8;padding:1px 6px;font-family:monospace;font-size:0.85em;"> {line}</div>')
170
+ if not lines:
171
+ return '<div style="color:#94a3b8;text-align:center;padding:20px;">No changes detected</div>'
172
+ return '<div style="background:#0f172a;border-radius:8px;padding:12px;overflow-x:auto;max-height:500px;overflow-y:auto;">' + "\n".join(lines) + "</div>"
173
+
174
+
175
+ def _step_timeline_html(trajectory: list[dict], current_step: int, max_steps: int = 3) -> str:
176
+ """Render step timeline as HTML."""
177
+ items = []
178
+ for i in range(1, max_steps + 1):
179
+ entry = next((t for t in trajectory if t.get("step") == i), None)
180
+ if entry:
181
+ r = entry["reward"]
182
+ color = _score_color(r)
183
+ items.append(
184
+ f'<span class="step-indicator step-done" style="border-color:{color}44;background:{color}11;color:{color};">'
185
+ f'Step {i} → {r:.2f}</span>'
186
+ )
187
+ elif i == current_step + 1:
188
+ items.append(f'<span class="step-indicator step-active">Step {i} ▶</span>')
189
+ else:
190
+ items.append(f'<span class="step-indicator step-pending">Step {i}</span>')
191
+
192
+ return '<div style="display:flex;gap:8px;justify-content:center;flex-wrap:wrap;padding:8px 0;">' + "".join(items) + "</div>"
193
+
194
+
195
+ def _loss_plot(losses: list[float], title: str = "Loss Curve"):
196
+ """Generate a matplotlib figure for loss curve."""
197
+ import matplotlib
198
+ matplotlib.use("Agg")
199
+ import matplotlib.pyplot as plt
200
+
201
+ fig, ax = plt.subplots(figsize=(5, 3))
202
+ fig.patch.set_facecolor("#0f172a")
203
+ ax.set_facecolor("#1e293b")
204
+
205
+ if losses:
206
+ valid_losses = [(i, l) for i, l in enumerate(losses) if not (math.isnan(l) or math.isinf(l))]
207
+ nan_steps = [i for i, l in enumerate(losses) if math.isnan(l) or math.isinf(l)]
208
+
209
+ if valid_losses:
210
+ steps, vals = zip(*valid_losses)
211
+ ax.plot(steps, vals, color="#60a5fa", linewidth=2, marker="o", markersize=3, zorder=3)
212
+ ax.fill_between(steps, vals, alpha=0.15, color="#60a5fa")
213
+
214
+ if nan_steps:
215
+ ax.scatter(nan_steps, [max(v for _, v in valid_losses) if valid_losses else 1.0] * len(nan_steps),
216
+ color="#ef4444", marker="x", s=60, zorder=4, label="NaN/Inf")
217
+ ax.legend(facecolor="#1e293b", edgecolor="#334155", labelcolor="#94a3b8")
218
+
219
+ ax.set_xlabel("Step", color="#94a3b8", fontsize=9)
220
+ ax.set_ylabel("Loss", color="#94a3b8", fontsize=9)
221
+ ax.set_title(title, color="#e2e8f0", fontsize=11, fontweight="bold")
222
+ ax.tick_params(colors="#64748b", labelsize=8)
223
+ for spine in ax.spines.values():
224
+ spine.set_color("#334155")
225
+ ax.grid(True, alpha=0.15, color="#475569")
226
+
227
+ fig.tight_layout()
228
+ return fig
229
+
230
+
231
+ def _val_acc_plot(accs: list[float]):
232
+ """Generate a matplotlib figure for validation accuracy."""
233
+ import matplotlib
234
+ matplotlib.use("Agg")
235
+ import matplotlib.pyplot as plt
236
+
237
+ fig, ax = plt.subplots(figsize=(5, 3))
238
+ fig.patch.set_facecolor("#0f172a")
239
+ ax.set_facecolor("#1e293b")
240
+
241
+ if accs:
242
+ epochs = list(range(1, len(accs) + 1))
243
+ ax.plot(epochs, accs, color="#a78bfa", linewidth=2, marker="o", markersize=3)
244
+ ax.fill_between(epochs, accs, alpha=0.15, color="#a78bfa")
245
+ ax.axhline(y=0.7, color="#22c55e", linestyle="--", alpha=0.5, label="Target (0.70)")
246
+ ax.legend(facecolor="#1e293b", edgecolor="#334155", labelcolor="#94a3b8")
247
+
248
+ ax.set_xlabel("Epoch", color="#94a3b8", fontsize=9)
249
+ ax.set_ylabel("Accuracy", color="#94a3b8", fontsize=9)
250
+ ax.set_title("Validation Accuracy", color="#e2e8f0", fontsize=11, fontweight="bold")
251
+ ax.set_ylim(0, 1.05)
252
+ ax.tick_params(colors="#64748b", labelsize=8)
253
+ for spine in ax.spines.values():
254
+ spine.set_color("#334155")
255
+ ax.grid(True, alpha=0.15, color="#475569")
256
+
257
+ fig.tight_layout()
258
+ return fig
259
+
260
+
261
+ # ── Episode state ──────────────────────────────────────────────────────────
262
+
263
+ class EpisodeState:
264
+ """Track episode state across Gradio interactions."""
265
+ def __init__(self):
266
+ self.reset()
267
+
268
+ def reset(self):
269
+ self.task_id = ""
270
+ self.buggy_code = ""
271
+ self.task_description = ""
272
+ self.step = 0
273
+ self.best_reward = 0.0
274
+ self.last_reward = 0.0
275
+ self.error_log = ""
276
+ self.trajectory: list[dict] = []
277
+ self.done = False
278
+ self.last_fixed_code = ""
279
+
280
+
281
+ _state = EpisodeState()
282
+
283
+
284
+ # ── Action handlers ────────────────────────────────────────────────────────
285
+
286
+ def do_reset(base_url: str, task_id: str):
287
+ """Reset the environment for a given task."""
288
+ _state.reset()
289
+ data = _api(base_url, "POST", "/reset", {"task_id": task_id})
290
+ if "error" in data:
291
+ return (
292
+ f"❌ Error: {data['error']}", # status
293
+ "", # code editor
294
+ "", # task desc
295
+ _score_html(0.0), # score
296
+ None, # loss plot
297
+ None, # acc plot
298
+ "", # diff
299
+ _step_timeline_html([], 0), # timeline
300
+ "", # error log
301
+ )
302
+
303
+ obs = data.get("observation", data)
304
+ _state.task_id = obs.get("task_id", task_id)
305
+ _state.buggy_code = obs.get("buggy_code", "")
306
+ _state.task_description = obs.get("task_description", "")
307
+ _state.step = 0
308
+ _state.done = False
309
+
310
+ info = TASK_INFO.get(task_id, {})
311
+ task_md = f"""### {info.get('name', task_id)} {info.get('difficulty', '')}
312
+
313
+ {_state.task_description.strip()}
314
+
315
+ **💡 Hint:** {info.get('hints', 'No hints available.')}
316
+ """
317
+
318
+ return (
319
+ f"✅ Episode started — {info.get('name', task_id)}",
320
+ _state.buggy_code.strip(),
321
+ task_md,
322
+ _score_html(0.0),
323
+ _loss_plot([], "Loss Curve — Submit a fix to see results"),
324
+ None,
325
+ '<div style="color:#94a3b8;text-align:center;padding:20px;">Submit a fix to see the diff</div>',
326
+ _step_timeline_html([], 0),
327
+ "",
328
+ )
329
+
330
+
331
+ def do_step(base_url: str, fixed_code: str):
332
+ """Submit a fix attempt."""
333
+ if _state.done:
334
+ return (
335
+ "⚠️ Episode is done. Reset to start a new one.",
336
+ _score_html(_state.best_reward),
337
+ None, None, "", _step_timeline_html(_state.trajectory, _state.step), "",
338
+ )
339
+
340
+ if not fixed_code or not fixed_code.strip():
341
+ return (
342
+ "⚠️ Please enter code before submitting.",
343
+ _score_html(_state.last_reward),
344
+ None, None, "", _step_timeline_html(_state.trajectory, _state.step), "",
345
+ )
346
+
347
+ _state.step += 1
348
+ _state.last_fixed_code = fixed_code
349
+
350
+ payload = {"action": {"fixed_code": fixed_code, "attempt_number": _state.step}}
351
+ data = _api(base_url, "POST", "/step", payload)
352
+
353
+ if "error" in data:
354
+ _state.step -= 1
355
+ return (
356
+ f"❌ Error: {data['error']}",
357
+ _score_html(_state.last_reward),
358
+ None, None, "", _step_timeline_html(_state.trajectory, _state.step), "",
359
+ )
360
+
361
+ reward = float(data.get("reward", 0.0) or 0.0)
362
+ _state.last_reward = reward
363
+ _state.best_reward = max(_state.best_reward, reward)
364
+ _state.done = data.get("done", False)
365
+
366
+ obs = data.get("observation", {})
367
+ _state.error_log = obs.get("error_log", "")
368
+ metrics = obs.get("metrics", {})
369
+
370
+ _state.trajectory.append({
371
+ "step": _state.step,
372
+ "reward": reward,
373
+ "best_reward": _state.best_reward,
374
+ "metrics": metrics,
375
+ "done": _state.done,
376
+ })
377
+
378
+ # Parse outputs for visualization
379
+ losses = _parse_losses_from_log(_state.error_log)
380
+ val_accs = _parse_val_accs_from_log(_state.error_log)
381
+
382
+ loss_fig = _loss_plot(losses, f"Loss Curve — Step {_state.step}")
383
+ acc_fig = _val_acc_plot(val_accs) if val_accs else None
384
+
385
+ diff = _diff_html(_state.buggy_code, fixed_code)
386
+ timeline = _step_timeline_html(_state.trajectory, _state.step)
387
+
388
+ if reward >= 0.95:
389
+ emoji = "🎯"
390
+ elif reward >= 0.7:
391
+ emoji = "✅"
392
+ elif len(_state.trajectory) > 1 and reward > _state.trajectory[-2]["reward"]:
393
+ emoji = "📈"
394
+ else:
395
+ emoji = "⚠️"
396
+ done_msg = " — Episode complete!" if _state.done else ""
397
+ status = f"{emoji} Step {_state.step}/3 — Reward: {reward:.2f} (Best: {_state.best_reward:.2f}){done_msg}"
398
+
399
+ error_display = _state.error_log if _state.error_log else "No errors — code ran successfully."
400
+
401
+ return (status, _score_html(_state.best_reward), loss_fig, acc_fig, diff, timeline, error_display)
402
+
403
+
404
+ def do_run_baseline(base_url: str, task_id: str):
405
+ """Run the baseline agent on a single task."""
406
+ # First reset
407
+ reset_result = do_reset(base_url, task_id)
408
+ yield reset_result + ("🤖 Resetting environment...",)
409
+
410
+ # Call baseline endpoint
411
+ data = _api(base_url, "GET", "/baseline")
412
+ if "error" in data:
413
+ yield reset_result + (f"❌ Baseline error: {data['error']}",)
414
+ return
415
+
416
+ scores = data.get("baseline_scores", {})
417
+ avg = data.get("average", 0.0)
418
+
419
+ results_md = "### 🤖 Baseline Agent Results\n\n"
420
+ results_md += "| Task | Score |\n|---|---|\n"
421
+ for tid in ["task1", "task2", "task3", "task4", "task5"]:
422
+ s = scores.get(tid, 0.0)
423
+ emoji = "🎯" if s >= 0.9 else ("✅" if s >= 0.7 else ("📈" if s >= 0.4 else "⚠️"))
424
+ results_md += f"| {tid} | {emoji} {s:.4f} |\n"
425
+ results_md += f"\n**Average: {avg:.4f}**"
426
+
427
+ yield reset_result + (results_md,)
428
+
429
+
430
+ def load_current_state(base_url: str):
431
+ """Fetch and format current environment state for UI display."""
432
+ data = _api(base_url, "GET", "/state")
433
+ if "error" in data:
434
+ summary = "⚠️ Could not fetch current state. Start or reset an episode first, then try again."
435
+ return summary, {"error": data["error"]}
436
+
437
+ state_obj = data.get("state", data)
438
+ if not isinstance(state_obj, dict):
439
+ return "⚠️ State endpoint returned an unexpected response format.", {"raw": data}
440
+
441
+ done = bool(state_obj.get("done", False))
442
+ step = state_obj.get("step", 0)
443
+ task_id = state_obj.get("task_id", "-")
444
+ reward = state_obj.get("last_reward", state_obj.get("reward", 0.0))
445
+ summary = (
446
+ f"**Task:** {task_id} | **Step:** {step} | "
447
+ f"**Done:** {'yes' if done else 'no'} | **Last Reward:** {reward}"
448
+ )
449
+ return summary, state_obj
450
+
451
+
452
+ # ── Build the UI ───────────────────────────────────────────────────────────
453
+
454
+ def build_ui() -> gr.Blocks:
455
+ theme = gr.themes.Soft(
456
+ primary_hue=gr.themes.colors.blue,
457
+ secondary_hue=gr.themes.colors.purple,
458
+ neutral_hue=gr.themes.colors.slate,
459
+ font=gr.themes.GoogleFont("Inter"),
460
+ ).set(
461
+ body_background_fill="#0f172a",
462
+ body_background_fill_dark="#0f172a",
463
+ block_background_fill="#1e293b",
464
+ block_background_fill_dark="#1e293b",
465
+ block_border_color="#334155",
466
+ block_border_color_dark="#334155",
467
+ block_label_text_color="#e2e8f0",
468
+ block_label_text_color_dark="#e2e8f0",
469
+ block_title_text_color="#e2e8f0",
470
+ block_title_text_color_dark="#e2e8f0",
471
+ input_background_fill="#0f172a",
472
+ input_background_fill_dark="#0f172a",
473
+ button_primary_background_fill="#3b82f6",
474
+ button_primary_background_fill_dark="#3b82f6",
475
+ button_primary_text_color="#ffffff",
476
+ )
477
+
478
+ with gr.Blocks(title="WhipStudio — ML Debug Arena") as app:
479
+
480
+ # ── Header ──
481
+ gr.HTML("""
482
+ <div class="arena-header">
483
+ <h1>🔧 WhipStudio — ML Debug Arena</h1>
484
+ <p>An RL environment where agents debug broken PyTorch training scripts</p>
485
+ </div>
486
+ """)
487
+
488
+ base_url = gr.Textbox(label="🌐 API Base URL", value=DEFAULT_BASE_URL, scale=1)
489
+
490
+ with gr.Row(equal_height=False):
491
+
492
+ # ── Left column: Task selector ──
493
+ with gr.Column(scale=1, min_width=280):
494
+ gr.Markdown("### 📋 Task Selector")
495
+ task_id = gr.Radio(
496
+ choices=["task1", "task2", "task3", "task4", "task5"],
497
+ value="task1",
498
+ label="Select Task",
499
+ info="Choose a debugging challenge",
500
+ )
501
+
502
+ task_desc = gr.Markdown(
503
+ value="""### Broken Training Loop 🟢 Easy
504
+
505
+ Fix optimizer order + learning rate bugs in a linear classifier.
506
+
507
+ **💡 Hint:** Look at optimizer.step() / loss.backward() order and the learning rate."""
508
+ )
509
+
510
+ with gr.Row():
511
+ btn_reset = gr.Button("🔄 Reset", variant="primary", size="sm")
512
+ btn_baseline = gr.Button("🤖 Auto-Agent", variant="secondary", size="sm")
513
+
514
+ status = gr.Textbox(label="Status", interactive=False, lines=1)
515
+ timeline = gr.HTML(label="Episode Timeline", value=_step_timeline_html([], 0))
516
+
517
+ # ── Center column: Code editor ──
518
+ with gr.Column(scale=2, min_width=400):
519
+ gr.Markdown("### 💻 Code Editor")
520
+ code_editor = gr.Code(
521
+ label="Your Fix (edit the code below)",
522
+ language="python",
523
+ lines=22,
524
+ )
525
+
526
+ with gr.Row():
527
+ btn_submit = gr.Button("🚀 Submit Fix", variant="primary", size="lg")
528
+
529
+ error_log = gr.Textbox(
530
+ label="📋 Execution Output / Error Log",
531
+ lines=6,
532
+ interactive=False,
533
+ )
534
+
535
+ # ── Right column: Results ──
536
+ with gr.Column(scale=1, min_width=300):
537
+ gr.Markdown("### 📊 Results")
538
+ score_display = gr.HTML(value=_score_html(0.0))
539
+
540
+ with gr.Tabs():
541
+ with gr.Tab("📉 Loss Curve"):
542
+ loss_plot = gr.Plot(label="Loss Curve")
543
+ with gr.Tab("📈 Val Accuracy"):
544
+ acc_plot = gr.Plot(label="Validation Accuracy (Task 3)")
545
+ with gr.Tab("🔀 Code Diff"):
546
+ diff_view = gr.HTML(
547
+ value='<div style="color:#94a3b8;text-align:center;padding:20px;">Submit a fix to see the diff</div>'
548
+ )
549
+ with gr.Tab("🧭 Current State"):
550
+ state_summary = gr.Markdown(
551
+ value="Press Reset to start an episode, then Current State will appear here."
552
+ )
553
+ btn_refresh_state = gr.Button("🔄 Refresh State", variant="secondary", size="sm")
554
+ state_json = gr.JSON(label="/state response", value={})
555
+
556
+ baseline_output = gr.Markdown(label="Baseline Results", visible=False)
557
+
558
+
559
+ # ── Bottom: Raw API tab (for developers) ──
560
+ with gr.Accordion("🔧 Developer Tools (Raw API)", open=False):
561
+ with gr.Row():
562
+ with gr.Column():
563
+ dev_method = gr.Radio(["GET", "POST"], value="GET", label="Method")
564
+ dev_path = gr.Textbox(label="Path", value="/health")
565
+ dev_payload = gr.Code(label="Payload (JSON)", language="json", value="{}")
566
+ btn_dev = gr.Button("Send Request", variant="secondary")
567
+ with gr.Column():
568
+ dev_status = gr.Textbox(label="Status")
569
+ dev_response = gr.Code(label="Response", language="json")
570
+
571
+ def dev_call(base, method, path, payload_text):
572
+ base = (base or DEFAULT_BASE_URL).strip().rstrip("/")
573
+ url = f"{base}{path}"
574
+ try:
575
+ payload = json.loads(payload_text) if payload_text.strip() else {}
576
+ except json.JSONDecodeError as e:
577
+ return "JSON Error", str(e)
578
+ try:
579
+ with httpx.Client(timeout=90.0) as client:
580
+ if method == "GET":
581
+ resp = client.get(url)
582
+ else:
583
+ resp = client.post(url, json=payload)
584
+ ct = resp.headers.get("content-type", "")
585
+ body = json.dumps(resp.json(), indent=2) if "json" in ct else resp.text
586
+ return f"{resp.status_code} {resp.reason_phrase}", body
587
+ except Exception as exc:
588
+ return "Error", f"{exc.__class__.__name__}: {exc}"
589
+
590
+ btn_dev.click(
591
+ fn=dev_call,
592
+ inputs=[base_url, dev_method, dev_path, dev_payload],
593
+ outputs=[dev_status, dev_response],
594
+ )
595
+
596
+
597
+ # ── Event bindings ──
598
+
599
+ # Task selector updates description
600
+ def update_task_desc(tid):
601
+ info = TASK_INFO.get(tid, {})
602
+ return f"""### {info.get('name', tid)} {info.get('difficulty', '')}
603
+
604
+ {info.get('description', '')}
605
+
606
+ **💡 Hint:** {info.get('hints', 'No hints available.')}"""
607
+
608
+ task_id.change(fn=update_task_desc, inputs=[task_id], outputs=[task_desc])
609
+
610
+ # Reset
611
+ btn_reset.click(
612
+ fn=do_reset,
613
+ inputs=[base_url, task_id],
614
+ outputs=[status, code_editor, task_desc, score_display, loss_plot, acc_plot, diff_view, timeline, error_log],
615
+ ).then(
616
+ fn=load_current_state,
617
+ inputs=[base_url],
618
+ outputs=[state_summary, state_json],
619
+ )
620
+
621
+ # Submit fix
622
+ btn_submit.click(
623
+ fn=do_step,
624
+ inputs=[base_url, code_editor],
625
+ outputs=[status, score_display, loss_plot, acc_plot, diff_view, timeline, error_log],
626
+ ).then(
627
+ fn=load_current_state,
628
+ inputs=[base_url],
629
+ outputs=[state_summary, state_json],
630
+ )
631
+
632
+ btn_refresh_state.click(
633
+ fn=load_current_state,
634
+ inputs=[base_url],
635
+ outputs=[state_summary, state_json],
636
+ )
637
+
638
+ # Baseline (auto-agent) — live streaming per-task
639
+ TASK_NAMES = {
640
+ "task1": "🟢 Broken Training Loop",
641
+ "task2": "🟡 Silent NaN Loss",
642
+ "task3": "🔴 OOM + Data Leakage",
643
+ "task4": "🟡 Wrong Loss Function",
644
+ "task5": "🟡 Frozen Backbone",
645
+ }
646
+
647
+ def run_baseline_live(base_url_val):
648
+ """Generator that yields live progress as each task completes."""
649
+ base = (base_url_val or DEFAULT_BASE_URL).strip().rstrip("/")
650
+ results = {}
651
+ lines_header = ["### 🤖 Baseline Agent — Live Progress\n"]
652
+
653
+ # Phase 1: Show "starting" state
654
+ yield "\n".join(lines_header + ["⏳ Starting baseline agent..."])
655
+
656
+ for tid in ["task1", "task2", "task3", "task4", "task5"]:
657
+ tname = TASK_NAMES.get(tid, tid)
658
+
659
+ # Show "running this task" update
660
+ progress_lines = list(lines_header)
661
+ # Show completed tasks
662
+ for done_tid, info in results.items():
663
+ s = info["score"]
664
+ emoji = "🎯" if s >= 0.9 else ("✅" if s >= 0.7 else ("📈" if s >= 0.4 else "⚠️"))
665
+ progress_lines.append(f"- {emoji} **{TASK_NAMES.get(done_tid, done_tid)}**: {s:.4f}")
666
+ if info.get("error"):
667
+ progress_lines.append(f" - ⚠️ `{info['error'][:150]}`")
668
+ # Show currently running task
669
+ progress_lines.append(f"\n🤖 **Running {tname}** — agent is analyzing the code and generating a fix...")
670
+ progress_lines.append(f"\n*This may take 30-60 seconds per task (LLM inference + sandbox execution × 3 attempts)*")
671
+ yield "\n".join(progress_lines)
672
+
673
+ # Actually call the per-task endpoint
674
+ try:
675
+ with httpx.Client(timeout=180.0) as client:
676
+ resp = client.get(f"{base}/baseline/task/{tid}")
677
+ resp.raise_for_status()
678
+ data = resp.json()
679
+ except Exception as exc:
680
+ data = {"score": 0.0, "error": f"{exc.__class__.__name__}: {exc}"}
681
+
682
+ results[tid] = {
683
+ "score": float(data.get("score", 0.0)),
684
+ "error": data.get("error", ""),
685
+ "fixed_code": data.get("fixed_code", ""),
686
+ "output": data.get("output", ""),
687
+ }
688
+
689
+ # Final summary
690
+ final_lines = ["### 🤖 Baseline Agent Results\n", "| Task | Score |", "|---|---|"]
691
+ total = 0.0
692
+ has_errors = False
693
+ for tid in ["task1", "task2", "task3", "task4", "task5"]:
694
+ info = results.get(tid, {"score": 0.0})
695
+ s = info["score"]
696
+ total += s
697
+ emoji = "🎯" if s >= 0.9 else ("✅" if s >= 0.7 else ("📈" if s >= 0.4 else "⚠️"))
698
+ final_lines.append(f"| {TASK_NAMES.get(tid, tid)} | {emoji} **{s:.4f}** |")
699
+ if info.get("error"):
700
+ has_errors = True
701
+ final_lines.append(f"\n> ⚠️ `{info['error'][:200]}`\n")
702
+
703
+ avg = total / 5
704
+ final_lines.append(f"\n**Average: {avg:.4f}**")
705
+ if avg >= 0.7:
706
+ final_lines.append("\n🎉 **Agent performed well!** The environment is solvable.")
707
+ elif avg >= 0.3:
708
+ final_lines.append("\n📈 **Agent showed partial progress.** Reward shaping is working.")
709
+ elif not has_errors:
710
+ final_lines.append("\n⚠️ **Agent scored low.** Tasks may be too challenging for zero-shot inference.")
711
+
712
+ if has_errors:
713
+ final_lines.append("\n---\n> [!WARNING]\n> Some tasks failed. Check if `HF_TOKEN` is valid and the model is accessible.")
714
+
715
+ final_lines.append("\n---\n### 🔍 Auto-Agent Generated Code & Execution Logs")
716
+ for tid in ["task1", "task2", "task3", "task4", "task5"]:
717
+ info = results.get(tid, {})
718
+ fixed_code = str(info.get("fixed_code", ""))
719
+ output = str(info.get("output", ""))
720
+ if fixed_code.strip() or output.strip():
721
+ final_lines.append(f"\n#### {TASK_NAMES.get(tid, tid)}")
722
+ if fixed_code.strip():
723
+ final_lines.append("<details><summary><b>Show Generated Code</b></summary>\n\n```python\n" + fixed_code + "\n```\n</details>")
724
+ if output.strip():
725
+ final_lines.append("<details><summary><b>Show Execution Output</b></summary>\n\n```text\n" + output + "\n```\n</details>")
726
+
727
+ yield "\n".join(final_lines)
728
+
729
+ btn_baseline.click(
730
+ fn=lambda: gr.update(visible=True),
731
+ outputs=[baseline_output],
732
+ ).then(
733
+ fn=run_baseline_live,
734
+ inputs=[base_url],
735
+ outputs=[baseline_output],
736
+ )
737
+
738
+ # Footer
739
+ gr.HTML("""
740
+ <div style="text-align:center; padding:16px 0; color:#64748b; font-size:0.85em; border-top:1px solid #1e293b; margin-top:16px;">
741
+ WhipStudio v1.0 — OpenEnv ML Debug Environment
742
+ · <a href="/web" style="color:#60a5fa;">OpenEnv Web UI →</a>
743
+ · <a href="/docs" style="color:#60a5fa;">API Docs →</a>
744
+ </div>
745
+ """)
746
+
747
+ return app
748
+
749
+
750
+ def main(host: str = "0.0.0.0", port: int = 7860):
751
+ app = build_ui()
752
+ app.launch(server_name=host, server_port=port, css=CUSTOM_CSS)
753
+
754
+
755
+ if __name__ == "__main__":
756
+ main()
models.py ADDED
@@ -0,0 +1,41 @@
1
+ from openenv.core.env_server.types import Action, Observation
2
+ from pydantic import Field
3
+
4
+
5
+ class MLDebugAction(Action):
6
+ """Agent submits a corrected training script."""
7
+
8
+ fixed_code: str = Field(
9
+ ...,
10
+ description="The corrected Python training script. Must be complete runnable code.",
11
+ )
12
+ explanation: str = Field(
13
+ default="",
14
+ description="Optional: agent's explanation of bugs found (not scored, for logging)",
15
+ )
16
+ attempt_number: int = Field(
17
+ default=1,
18
+ ge=1,
19
+ le=3,
20
+ description="Which attempt this is. Max 3 per episode.",
21
+ )
22
+
23
+
24
+ class MLDebugObservation(Observation):
25
+ """What the agent sees on reset() and after each step()."""
26
+
27
+ task_id: str = Field(..., description="task1 | task2 | task3 | task4 | task5")
28
+ task_description: str = Field(..., description="Plain English task instructions")
29
+ buggy_code: str = Field(..., description="The broken training script")
30
+ error_log: str = Field(
31
+ default="",
32
+ description="stdout+stderr from the previous attempt. Empty on first step.",
33
+ )
34
+ last_reward: float = Field(
35
+ default=0.0,
36
+ description="Reward from previous attempt. 0.0 on first step.",
37
+ )
38
+ metrics: dict = Field(
39
+ default_factory=dict,
40
+ description="Structured: {final_loss, nan_count, val_acc, timed_out, exit_code}",
41
+ )
openenv.yaml ADDED
@@ -0,0 +1,77 @@
1
+
2
+ spec_version: 2
3
+ name: whipstudio-env
4
+ type: space
5
+ runtime: fastapi
6
+ app: server.app:app
7
+ port: 7860
8
+
9
+ metadata:
10
+ title: WhipStudio Env
11
+ version: "1.1.0"
12
+ description: >
13
+ OpenEnv-compatible RL environment where agents debug broken PyTorch
14
+ training scripts across five tasks with continuous reward scoring.
15
+ author: Amogh-kal1
16
+ license: Apache-2.0
17
+ repository: Amogh-kal1/whipstudio-env
18
+ tags:
19
+ - openenv
20
+ - rl
21
+ - ml-debugging
22
+ - fastapi
23
+ - gradio
24
+
25
+ compute:
26
+ python_version: "3.11"
27
+ cpu: 2
28
+ memory_gb: 16
29
+ gpu: false
30
+
31
+ service:
32
+ host: 0.0.0.0
33
+ port: 7860
34
+ health_path: /health
35
+ base_path: /
36
+ web_path: /web
37
+ ui_path: /ui
38
+ docs_path: /docs
39
+ endpoints:
40
+ - method: POST
41
+ path: /reset
42
+ - method: POST
43
+ path: /step
44
+ - method: GET
45
+ path: /state
46
+ - method: GET
47
+ path: /tasks
48
+ - method: POST
49
+ path: /grader
50
+ - method: GET
51
+ path: /baseline
52
+ - method: GET
53
+ path: /baseline/task/{task_id}
54
+ - method: GET
55
+ path: /baseline/health
56
+
57
+ tasks:
58
+ - id: task1
59
+ name: Broken training loop
60
+ difficulty: easy
61
+ max_steps: 3
62
+ - id: task2
63
+ name: Silent NaN loss
64
+ difficulty: medium
65
+ max_steps: 3
66
+ - id: task3
67
+ name: OOM and data leakage
68
+ difficulty: hard
69
+ max_steps: 3
70
+ - id: task4
71
+ name: Wrong loss function
72
+ difficulty: medium
73
+ max_steps: 3
74
+ - id: task5
75
+ name: Frozen backbone
76
+ difficulty: medium
77
+ max_steps: 3
openenv_whipstudio.egg-info/PKG-INFO ADDED
@@ -0,0 +1,16 @@
1
+ Metadata-Version: 2.4
2
+ Name: openenv-whipstudio
3
+ Version: 0.1.0
4
+ Summary: ML Debug environment for OpenEnv
5
+ Requires-Python: >=3.11
6
+ License-File: LICENSE
7
+ Requires-Dist: openenv-core[core]>=0.2.1
8
+ Requires-Dist: fastapi>=0.110.0
9
+ Requires-Dist: uvicorn>=0.27.0
10
+ Requires-Dist: pydantic>=2.0.0
11
+ Requires-Dist: httpx>=0.27.0
12
+ Requires-Dist: torch>=2.2.0
13
+ Requires-Dist: smolagents>=1.0.0
14
+ Provides-Extra: dev
15
+ Requires-Dist: pytest>=8.0.0; extra == "dev"
16
+ Dynamic: license-file
openenv_whipstudio.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,25 @@
1
+ LICENSE
2
+ README.md
3
+ __init__.py
4
+ baseline_agent.py
5
+ client.py
6
+ gradio_app.py
7
+ models.py
8
+ pyproject.toml
9
+ openenv_whipstudio.egg-info/PKG-INFO
10
+ openenv_whipstudio.egg-info/SOURCES.txt
11
+ openenv_whipstudio.egg-info/dependency_links.txt
12
+ openenv_whipstudio.egg-info/entry_points.txt
13
+ openenv_whipstudio.egg-info/requires.txt
14
+ openenv_whipstudio.egg-info/top_level.txt
15
+ server/__init__.py
16
+ server/app.py
17
+ server/environment.py
18
+ server/sandbox.py
19
+ server/tasks/__init__.py
20
+ server/tasks/graders.py
21
+ server/tasks/task1_broken_loop.py
22
+ server/tasks/task2_nan_loss.py
23
+ server/tasks/task3_oom_leakage.py
24
+ server/tasks/task4_wrong_loss.py
25
+ server/tasks/task5_frozen_backbone.py
openenv_whipstudio.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
1
+
openenv_whipstudio.egg-info/entry_points.txt ADDED
@@ -0,0 +1,3 @@
1
+ [console_scripts]
2
+ server = server.app:main
3
+ whipstudio-server = server.app:main
openenv_whipstudio.egg-info/requires.txt ADDED
@@ -0,0 +1,10 @@
1
+ openenv-core[core]>=0.2.1
2
+ fastapi>=0.110.0
3
+ uvicorn>=0.27.0
4
+ pydantic>=2.0.0
5
+ httpx>=0.27.0
6
+ torch>=2.2.0
7
+ smolagents>=1.0.0
8
+
9
+ [dev]
10
+ pytest>=8.0.0
openenv_whipstudio.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
1
+ server
project_status_report_2703.md ADDED
@@ -0,0 +1,70 @@
1
+ # WhipStudio Project Status Report
2
+
3
+ **Date:** March 27, 2026
4
+ **Project:** WhipStudio — ML Debugging Arena (OpenEnv AI Hackathon)
5
+ **Status:** 🟢 **On Track for Submission**
6
+
7
+ ---
8
+
9
+ ## Executive Summary
10
+ WhipStudio has moved from a conceptual prototype to a working ML Debugging Arena. The platform's core objective, providing an isolated, standardized, Gymnasium-style environment for evaluating and training autonomous AI agents on PyTorch debugging tasks, is now operational. The backend follows the OpenEnv specification, and the frontend gives spectators and judges a way to watch and verify agent runs.
11
+
12
+ ---
13
+
14
+ ## System Architecture & Technical Specifications
15
+
16
+ ### 1. The Core Environment Database (Tasks)
17
+ The platform currently ships with three distinct ML debugging challenges, graded automatically:
18
+ - **Task 1 (Easy): Broken Training Loop.** Fixes a simple linear classifier that has an excessively high learning rate, steps the optimizer out of order, and computes loss incorrectly.
19
+ - **Task 2 (Medium): Silent NaN Loss.** Fixes a CNN training script where the output loss silently turns to `NaN`.
20
+ - **Task 3 (Hard): OOM + Data Leakage.** Fixes a complex PyTorch training loop that creates memory leaks by accumulating computation graphs incorrectly, and introduces data leakage by augmenting before splitting train/val. Graded on a granular scale (0.15 - 0.5 per bug fixed).
21
+
22
+ ### 2. The Execution Sandbox API (Backend)
23
+ All agent submissions are processed securely in an isolated environment.
24
+
25
+ **Core Endpoints (OpenEnv Core API):**
26
+ - `POST /reset`: Initialize the environment for a specific task. Returns the `observation` payload, which includes the `buggy_code` string and `task_description`.
27
+ - `POST /step`: Submit a fixed code attempt (the `action`). The system executes the code in the sandbox, parses `stdout`/`stderr` with regex metric extractors (e.g., `parse_scalar` in `server/tasks/graders.py` for `FINAL_LOSS`), scores the run, logs the trajectory, and returns the RL `reward` (0.0 to 1.0) along with the `done` flag.
28
+
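The two core endpoints above can be exercised with a small client. The following is a minimal sketch, not the project's own `client.py`: the request bodies mirror what the Gradio app in this commit sends (`{"task_id": ...}` for `/reset` and `{"action": {...}}` for `/step`), while the helper names `post_json` and `run_single_attempt` are hypothetical and the response shape is assumed to match what the Gradio handlers read.

```python
import json
import urllib.request


def build_reset_payload(task_id: str) -> dict:
    """Body for POST /reset, as sent by the Gradio app in this commit."""
    return {"task_id": task_id}


def build_step_payload(fixed_code: str, attempt_number: int) -> dict:
    """Body for POST /step; MLDebugAction limits attempt_number to 1..3."""
    if not 1 <= attempt_number <= 3:
        raise ValueError("attempt_number must be between 1 and 3")
    return {"action": {"fixed_code": fixed_code, "attempt_number": attempt_number}}


def post_json(base_url: str, path: str, payload: dict, timeout: float = 90.0) -> dict:
    """POST a JSON body and decode the JSON response (hypothetical helper)."""
    request = urllib.request.Request(
        base_url.rstrip("/") + path,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))


def run_single_attempt(base_url: str, task_id: str, fix) -> float:
    """Reset the task, apply `fix` to the buggy code, submit once, return the reward."""
    data = post_json(base_url, "/reset", build_reset_payload(task_id))
    observation = data.get("observation", data)  # assumed response shape
    fixed_code = fix(observation.get("buggy_code", ""))
    result = post_json(base_url, "/step", build_step_payload(fixed_code, 1))
    return float(result.get("reward", 0.0) or 0.0)
```

Against a running server, `run_single_attempt("http://localhost:7860", "task1", my_fix)` would return the reward for one attempt, where `my_fix` is any callable that maps the buggy script to a corrected one.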
29
+ **Security & Isolation (`server/sandbox.py`):**
30
+ - Python execution restricts dangerous operations using `BANNED_PATTERNS` (e.g., `os.system`, `subprocess.`, `open()`, `socket.`, `requests.`).
31
+ - Code runs natively in a subprocess under `/tmp` with a strict 30-second timeout (`TIMEOUT_SECONDS = 30`) and an output size limit (`MAX_OUTPUT_BYTES = 8192`).
32
+ - Environment dependencies (`PYTHONPATH`) are strictly controlled, ensuring library access (like `torch`) without host system compromise.
33
+
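The isolation rules listed above could look roughly like the sketch below. The constant names and values (`BANNED_PATTERNS`, `TIMEOUT_SECONDS`, `MAX_OUTPUT_BYTES`) come from the report; everything else, including the exact pattern strings, the function names `find_banned` and `run_sandboxed`, and the returned dictionary, is an assumption and may differ from `server/sandbox.py`.

```python
import os
import subprocess
import sys
import tempfile

# Names and values follow the report; the exact pattern strings are an approximation.
BANNED_PATTERNS = ["os.system", "subprocess.", "open(", "socket.", "requests."]
TIMEOUT_SECONDS = 30
MAX_OUTPUT_BYTES = 8192


def find_banned(code: str) -> list[str]:
    """Return every banned substring that appears in the submitted code."""
    return [pattern for pattern in BANNED_PATTERNS if pattern in code]


def run_sandboxed(code: str, timeout: float = TIMEOUT_SECONDS) -> dict:
    """Reject banned code, otherwise run it in a subprocess with a timeout."""
    banned = find_banned(code)
    if banned:
        return {"exit_code": None, "timed_out": False, "output": "", "banned": banned}

    with tempfile.TemporaryDirectory() as workdir:
        script = os.path.join(workdir, "submission.py")
        with open(script, "w", encoding="utf-8") as handle:
            handle.write(code)
        try:
            proc = subprocess.run(
                [sys.executable, script],
                cwd=workdir,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,  # merge stderr into stdout, like an error log
                timeout=timeout,
            )
            output, exit_code, timed_out = proc.stdout, proc.returncode, False
        except subprocess.TimeoutExpired as exc:
            output, exit_code, timed_out = exc.stdout or b"", None, True

    # Truncate before decoding so the size limit is enforced in bytes.
    text = output[:MAX_OUTPUT_BYTES].decode("utf-8", errors="replace")
    return {"exit_code": exit_code, "timed_out": timed_out, "output": text, "banned": []}
```

A substring blocklist is easy to bypass on its own, so a sketch like this is best read as a first filter that sits in front of the subprocess timeout and output cap rather than as a complete security boundary.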
34
+ ### 3. AI Agent Integrations
35
+ - `GET /baseline/health`: Checks if the default baseline model (Qwen-32B) is accessible and authenticated via the HuggingFace API using the `.env` configuration.
36
+ - `GET /baseline`: Instructs the backend to run the internal autonomous zero-shot agent across all three tasks simultaneously.
37
+ - `GET /baseline/task/{task_id}`: Streams execution of the baseline agent for a specific task (120s timeout). Used extensively by the Gradio UI for real-time visualization.
38
+
39
+ ### 4. User Interfaces (Frontend)
40
+
41
+ WhipStudio provides dual interfaces for varying stakeholder needs:
42
+
43
+ **A. The "ML Debug Arena" Verification Portal (Gradio App)**
44
+ A custom-built, interactive dashboard designed for human verification and hackathon judges.
45
+ - **Code Editor:** Monaco-style interactive Python editor pre-filled with the buggy code payload.
46
+ - **Live Loss Visualization:** Integrates `matplotlib` to plot `LOSS` curves and `VAL_ACCURACY` step-by-step from the agent's debugged submission.
47
+ - **Live Diff Context:** A source-control style diff view showing exactly what code lines the agent inserted or deleted to achieve the fix.
48
+ - **Autonomous Streaming Mode:** The "Auto-Agent" trigger connects to the `/baseline/task/{task_id}` streams to provide a live "Chain-of-Thought" experience, displaying progress and rewards directly to the UI panel in real-time.
49
+
50
+ **B. The OpenEnv Compliance Web UI (`/web`)**
51
+ - Directly embedded via the OpenEnv python framework (`os.environ["ENABLE_WEB_INTERFACE"] = "true"`).
52
+ - Accessed via `http://[host]:7860/web`, this provides the raw, standardized OpenEnv chat-style UI, demonstrating adherence to the hackathon's core platform requirements.
53
+
54
+ ---
55
+
56
+ ## Completed Milestones
57
+
58
+ 1. **Bug Remediation:** Fixed critical parsing logic failures where double-escaped regex patterns (`[\\\\d\\\\.\\\\-]+`) prevented the extraction of metrics like `FINAL_LOSS:0.4523`. Also resolved the `ModuleNotFoundError: torch` in the sandbox.
59
+ 2. **Environment Token Logic:** Rewrote authentication management using `python-dotenv` for local overrides, so that 401 Unauthorized API failures do not disrupt model inference when tokens expire mid-session.
60
+ 3. **Task Trajectory Generation:** Enabled persistent trajectory state logging inside `environment.py` for advanced analytics and training extraction.
61
+
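The regex issue from the first milestone can be reproduced in a few lines. This is a minimal sketch under stated assumptions: the real `parse_scalar` in `server/tasks/graders.py` is not shown in this commit view, so the signature and the `KEY:<number>` log format used here are hypothetical, inferred from the `FINAL_LOSS:0.4523` example in the report.

```python
import re
from typing import Optional


def parse_scalar(key: str, text: str) -> Optional[float]:
    """Return the last `KEY:<number>` value found in text, or None if absent."""
    # \b keeps "LOSS" from matching inside "FINAL_LOSS".
    pattern = re.compile(rf"\b{re.escape(key)}:\s*([\d.\-]+)")
    matches = pattern.findall(text)
    if not matches:
        return None
    try:
        return float(matches[-1])
    except ValueError:
        return None


# Double-escaped variant: inside the character class, "\\" is a literal
# backslash, so the class only contains backslash, "d", ".", and "-".
# It therefore never matches a digit.
BROKEN = re.compile(r"FINAL_LOSS:([\\d\\.\\-]+)")

log = "LOSS:0.9\nLOSS:0.6\nFINAL_LOSS:0.4523\n"
final_loss = parse_scalar("FINAL_LOSS", log)
broken_match = BROKEN.search(log)
```

Here `final_loss` holds the parsed float while `broken_match` is `None`, which is the failure mode the milestone describes: the metric is present in the log but the grader never sees it.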
62
+ ---
63
+
64
+ ## Strategic Next Steps
65
+
66
+ **Phase 2: Reinforcement Learning Integration**
67
+ The baseline agent currently runs zero-shot and is scored per attempt; no model is trained yet. The next phase builds on the logged trajectories:
68
+ 1. Export the stored trajectory JSON logs containing `buggy_code`, `reward`, and `error_log`.
69
+ 2. Integrate the **TRL (Transformer Reinforcement Learning) library**.
70
+ 3. Implement Group Relative Policy Optimization (GRPO) to iteratively train a specific lightweight model (like a 7B parameter local model) to perform significantly better on PyTorch debugging tasks than larger, generalized zero-shot models.
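
As a rough illustration of the idea behind step 3, GRPO scores each sampled completion relative to the other completions drawn for the same prompt. The sketch below shows only that normalization step in plain Python. It does not use the TRL API, the function name `group_relative_advantages` is hypothetical, and the use of the population standard deviation with a small `eps` is an assumption made for this example.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards: list[float], eps: float = 1e-8) -> list[float]:
    """Normalize each reward against the mean and spread of its own group."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population std; an assumption for this sketch
    return [(r - mu) / (sigma + eps) for r in rewards]


# Hypothetical rewards for four fixes sampled for the same buggy script.
group_rewards = [0.10, 0.55, 0.95, 0.40]
advantages = group_relative_advantages(group_rewards)
print(advantages)
```

Completions that score above their group's mean receive a positive advantage and are reinforced; a group with identical rewards yields all-zero advantages and therefore no learning signal.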
pyproject.toml ADDED
@@ -0,0 +1,32 @@
+ [build-system]
+ requires = ["setuptools>=45", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "openenv-whipstudio"
+ version = "0.1.0"
+ description = "ML Debug environment for OpenEnv"
+ requires-python = ">=3.11"
+ dependencies = [
+     "openenv-core[core]>=0.2.1",
+     "fastapi>=0.110.0",
+     "uvicorn>=0.27.0",
+     "pydantic>=2.0.0",
+     "httpx>=0.27.0",
+     "torch>=2.2.0",
+     "smolagents>=1.0.0",
+     "python-dotenv>=1.0.0",
+ ]
+
+ [project.optional-dependencies]
+ dev = [
+     "pytest>=8.0.0",
+ ]
+
+ [project.scripts]
+ whipstudio-server = "server.app:main"
+ server = "server.app:main"
+
+ [tool.setuptools]
+ include-package-data = true
+ packages = ["server", "server.tasks"]
server/__init__.py ADDED
@@ -0,0 +1 @@
+ """ML debug server package."""
server/app.py ADDED
@@ -0,0 +1,254 @@
+ import asyncio
+ import os
+ import sys
+
+ from dotenv import load_dotenv
+ load_dotenv(override=True)
+
+ import httpx
+ from fastapi import FastAPI, Request
+ from fastapi.responses import RedirectResponse
+ from fastapi.responses import HTMLResponse
+
+ from openenv.core.env_server.http_server import create_app
+
+ _project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+ if _project_root not in sys.path:
+     sys.path.insert(0, _project_root)
+
+ try:
+     from ..models import MLDebugAction, MLDebugObservation
+     from .environment import MLDebugEnvironment
+     from .tasks.graders import RunResult, score_task
+ except ImportError:
+     from models import MLDebugAction, MLDebugObservation
+     from server.environment import MLDebugEnvironment
+     from server.tasks.graders import RunResult, score_task
+
+ # Disable OpenEnv's default web UI so /web can mirror the custom Gradio UI.
+ os.environ["ENABLE_WEB_INTERFACE"] = "false"
+
+ app: FastAPI = create_app(
+     MLDebugEnvironment,
+     MLDebugAction,
+     MLDebugObservation,
+     env_name="whipstudio-env",
+     max_concurrent_envs=4,
+ )
+
+
+ @app.get("/__build", include_in_schema=False)
+ def build_info():
+     """Build/runtime fingerprint to confirm what code is deployed."""
+     import platform
+
+     return {
+         "env_name": "whipstudio-env",
+         "python": platform.python_version(),
+         "platform": platform.platform(),
+         "port": os.environ.get("PORT"),
+         "enable_web_interface": os.environ.get("ENABLE_WEB_INTERFACE"),
+     }
+
+
+ def _has_route(path: str, method: str) -> bool:
+     method = method.upper()
+     for route in app.router.routes:
+         if getattr(route, "path", None) != path:
+             continue
+         methods = getattr(route, "methods", None)
+         if methods and method in methods:
+             return True
+     return False
+
+
+ @app.get("/", include_in_schema=False)
+ def root_redirect():
+     return RedirectResponse(url="/ui", status_code=307)
+
+
+ if not _has_route("/health", "GET"):
+
+     @app.get("/health", include_in_schema=False)
+     def health_get():
+         return {"status": "ok"}
+
+
+ if not _has_route("/health", "POST"):
+
+     @app.post("/health", include_in_schema=False)
+     def health_post():
+         return {"status": "ok"}
+
+
+ @app.get("/reset")
+ def reset_liveness():
+     return {"status": "ok", "message": "use POST /reset to start an episode"}
+
+
+ @app.get("/tasks")
+ def list_tasks():
+     return {
+         "tasks": [
+             {"id": "task1", "name": "Broken training loop", "difficulty": "easy"},
+             {"id": "task2", "name": "Silent NaN loss", "difficulty": "medium"},
+             {"id": "task3", "name": "OOM and data leakage", "difficulty": "hard"},
+             {"id": "task4", "name": "Wrong loss function", "difficulty": "medium"},
+             {"id": "task5", "name": "Frozen backbone", "difficulty": "medium"},
+         ],
+         "action_schema": {
+             "fixed_code": "string (required) — complete runnable Python script",
+             "explanation": "string (optional) — description of bugs found",
+             "attempt_number": "int 1-3 (optional) — which attempt this is",
+         },
+     }
+
+
+ @app.post("/grader")
+ def run_grader(payload: dict):
+     task_id = payload.get("task_id", "task1")
+     result = RunResult(
+         exit_code=payload.get("exit_code", -1),
+         stdout=payload.get("stdout", ""),
+         stderr=payload.get("stderr", ""),
+         elapsed_seconds=payload.get("elapsed", 0.0),
+         timed_out=payload.get("timed_out", False),
+         fixed_code=payload.get("fixed_code", ""),
+     )
+     score, breakdown = score_task(task_id, result)
+     return {"task_id": task_id, "score": score, "breakdown": breakdown}
+
+
+ @app.get("/baseline")
+ async def run_baseline(request: Request):
+     try:
+         from ..baseline_agent import run_single_task
+     except ImportError:
+         from baseline_agent import run_single_task
+
+     env_url = str(request.base_url).rstrip("/")
+     results = {}
+     task_scores = {}
+     for task_id in ["task1", "task2", "task3", "task4", "task5"]:
+         try:
+             score = await asyncio.wait_for(run_single_task(task_id, env_url), timeout=120.0)
+             results[task_id] = round(score, 4)
+             task_scores[task_id] = round(score, 4)
+         except TimeoutError:
+             results[task_id] = 0.0
+             task_scores[task_id] = 0.0
+             results[f"{task_id}_error"] = "timeout: task took longer than 120s"
+         except httpx.HTTPError as exc:
+             results[task_id] = 0.0
+             task_scores[task_id] = 0.0
+             results[f"{task_id}_error"] = f"http_error: {exc.__class__.__name__}: {exc}"
+         except Exception as exc:
+             results[task_id] = 0.0
+             task_scores[task_id] = 0.0
+             results[f"{task_id}_error"] = f"internal_error: {exc.__class__.__name__}: {exc}"
+     avg = round(sum(task_scores.values()) / max(1, len(task_scores)), 4)
+     return {"baseline_scores": results, "average": avg, "env_url": env_url}
+
+
+ @app.get("/baseline/task/{task_id}")
+ async def run_baseline_single(task_id: str, request: Request):
+     """Run the baseline agent on a single task. Returns score + details."""
+     try:
+         from ..baseline_agent import run_single_task_detailed
+     except ImportError:
+         from baseline_agent import run_single_task_detailed
+
+     env_url = str(request.base_url).rstrip("/")
+     try:
+         result = await asyncio.wait_for(run_single_task_detailed(task_id, env_url), timeout=120.0)
+         return {
+             "task_id": task_id,
+             "score": round(result["score"], 4),
+             "status": "ok",
+             "fixed_code": result.get("fixed_code", ""),
+             "output": result.get("output", ""),
+             "attempts": result.get("attempts", []),
+         }
+     except TimeoutError:
+         return {"task_id": task_id, "score": 0.0, "status": "timeout", "error": "Task took longer than 120s"}
+     except Exception as exc:
+         return {"task_id": task_id, "score": 0.0, "status": "error", "error": f"{exc.__class__.__name__}: {exc}"}
+
+
+ @app.get("/baseline/health")
+ def baseline_health():
+     hf_token_present = bool(os.environ.get("HF_TOKEN"))
+     model_ready = False
+     model_error = None
+
+     try:
+         try:
+             from ..baseline_agent import get_model
+         except ImportError:
+             from baseline_agent import get_model
+
+         get_model()
+         model_ready = True
+     except Exception as exc:
+         model_error = f"{exc.__class__.__name__}: {exc}"
+
+     status = "ok" if hf_token_present and model_ready else "degraded"
+     return {
+         "status": status,
+         "hf_token_present": hf_token_present,
+         "model_ready": model_ready,
+         "model_error": model_error,
+     }
+
+
+ _ui_mounted = False
+ try:
+     import gradio as gr
+     try:
+         from ..gradio_app import build_ui
+     except ImportError:
+         from gradio_app import build_ui
+
+     gradio_ui = build_ui()
+     app = gr.mount_gradio_app(app, gradio_ui, path="/ui")
+     _ui_mounted = True
+ except Exception as e:
+     # Don't fail silently in Spaces: return a helpful error page at /ui.
+     import traceback
+
+     print(f"Failed to mount Gradio UI: {e}")
+     traceback.print_exc()
+
+
+ if not _ui_mounted:
+     @app.get("/ui", include_in_schema=False)
+     @app.get("/ui/", include_in_schema=False)
+     def ui_mount_failed():
+         return HTMLResponse(
+             "<h2>WhipStudio UI failed to start</h2>"
+             "<p>The API server is running, but the Gradio UI could not be mounted.</p>"
+             "<p>Check container logs for <code>Failed to mount Gradio UI</code>.</p>",
+             status_code=500,
+         )
+
+
+ @app.api_route("/web", methods=["GET", "POST"], include_in_schema=False)
+ def web_redirect_root():
+     return RedirectResponse(url="/ui", status_code=307)
+
+
+ @app.api_route("/web/{path:path}", methods=["GET", "POST"], include_in_schema=False)
+ def web_redirect_path(path: str):
+     if path:
+         return RedirectResponse(url=f"/ui/{path}", status_code=307)
+     return RedirectResponse(url="/ui", status_code=307)
+
+
+ def main(host: str = "0.0.0.0", port: int = 7860):
+     import uvicorn
+
+     uvicorn.run("server.app:app", host=host, port=port, reload=False)
+
+
+ if __name__ == "__main__":
+     main()
server/environment.py ADDED
@@ -0,0 +1,165 @@
+ import re
+ import math
+ import time
+ from uuid import uuid4
+
+ from openenv.core.env_server.interfaces import Environment
+ from openenv.core.env_server.types import State
+
+ try:
+     from ..models import MLDebugAction, MLDebugObservation
+     from .sandbox import execute_code
+     from .tasks import task1_broken_loop, task2_nan_loss, task3_oom_leakage, task4_wrong_loss, task5_frozen_backbone
+     from .tasks.graders import parse_losses, parse_val_accs, score_task
+ except ImportError:
+     from models import MLDebugAction, MLDebugObservation
+     from server.sandbox import execute_code
+     from server.tasks import task1_broken_loop, task2_nan_loss, task3_oom_leakage, task4_wrong_loss, task5_frozen_backbone
+     from server.tasks.graders import parse_losses, parse_val_accs, score_task
+
+ TASKS = {
+     "task1": task1_broken_loop,
+     "task2": task2_nan_loss,
+     "task3": task3_oom_leakage,
+     "task4": task4_wrong_loss,
+     "task5": task5_frozen_backbone,
+ }
+
+
+ class MLDebugEnvironment(Environment):
+     SUPPORTS_CONCURRENT_SESSIONS: bool = True
+
+     def __init__(self):
+         self._state = State(episode_id=str(uuid4()), step_count=0)
+         self._task_id = "task1"
+         self._best_reward = 0.0
+         self._trajectory: list[dict] = []
+
+     def reset(self, task_id: str = "task1", **kwargs) -> MLDebugObservation:  # type: ignore[override]
+         if task_id not in TASKS:
+             task_id = "task1"
+
+         self._state = State(episode_id=str(uuid4()), step_count=0)
+         self._task_id = task_id
+         self._best_reward = 0.0
+         self._trajectory = []
+
+         task = TASKS[self._task_id]
+         return MLDebugObservation(
+             task_id=self._task_id,
+             task_description=task.TASK_DESCRIPTION.strip(),
+             buggy_code=task.BUGGY_CODE.strip(),
+             error_log="",
+             last_reward=0.0,
+             metrics={},
+             done=False,
+             reward=0.0,
+         )
+
+     def step(self, action: MLDebugAction) -> MLDebugObservation:  # type: ignore[override]
+         self._state.step_count += 1
+
+         if not action.fixed_code or not action.fixed_code.strip():
+             done = True
+             metrics = {
+                 "exit_code": -1,
+                 "elapsed_seconds": 0.0,
+                 "timed_out": False,
+                 "step": self._state.step_count,
+                 "best_reward_so_far": self._best_reward,
+                 "error": "empty code submitted",
+             }
+             task = TASKS[self._task_id]
+             return MLDebugObservation(
+                 task_id=self._task_id,
+                 task_description=task.TASK_DESCRIPTION.strip(),
+                 buggy_code=task.BUGGY_CODE.strip(),
+                 error_log="empty code submitted",
+                 last_reward=0.0,
+                 metrics=metrics,
+                 done=done,
+                 reward=0.0,
+             )
+
+         run_result1 = execute_code(action.fixed_code)
+         reward1, breakdown1 = score_task(self._task_id, run_result1)
+
+         consistency_flag = False
+         reward_variance = 0.0
+         final_reward = reward1
+         final_breakdown = breakdown1
+         run_result = run_result1
+
+         if reward1 > 0.5:
+             run_result2 = execute_code(action.fixed_code)
+             reward2, breakdown2 = score_task(self._task_id, run_result2)
+             reward_variance = abs(reward1 - reward2)
+             if reward_variance > 0.15:
+                 consistency_flag = True
+                 final_reward = min(reward1, reward2)
+                 if reward2 < reward1:
+                     final_breakdown = breakdown2
+                     run_result = run_result2
+             else:
+                 consistency_flag = False
+                 final_reward = (reward1 + reward2) / 2.0
+         else:
+             consistency_flag = False
+             final_reward = reward1
+
+         self._best_reward = max(self._best_reward, final_reward)
+         done = self._state.step_count >= 3 or final_reward >= 0.95
+
+         losses = parse_losses(run_result.stdout)
+         val_accs = parse_val_accs(run_result.stdout)
+         final_loss = None
+         if losses:
+             final_loss = losses[-1]
+         else:
+             match = re.search(r"FINAL_LOSS:([-+]?\d*\.?\d+(?:[eE][-+]?\d+)?)", run_result.stdout)
+             if match:
+                 final_loss = float(match.group(1))
+
+         metrics = {
+             "exit_code": run_result.exit_code,
+             "elapsed_seconds": run_result.elapsed_seconds,
+             "timed_out": run_result.timed_out,
+             "step": self._state.step_count,
+             "best_reward_so_far": self._best_reward,
+             "final_loss": final_loss,
+             "nan_count": sum(1 for x in losses if math.isnan(x) or math.isinf(x)) if losses else 0,
+             "val_acc": val_accs[-1] if val_accs else None,
+             "consistency_flag": consistency_flag,
+             "reward_variance": round(reward_variance, 4),
+             "reward_breakdown": final_breakdown,
+         }
+
+         task = TASKS[self._task_id]
+
+         self._trajectory.append({
+             "step": self._state.step_count,
+             "reward": final_reward,
+             "best_reward": self._best_reward,
+             "metrics": metrics,
+             "done": done,
+             "timestamp": time.time(),
+         })
+
+         return MLDebugObservation(
+             task_id=self._task_id,
+             task_description=task.TASK_DESCRIPTION.strip(),
+             buggy_code=task.BUGGY_CODE.strip(),
+             error_log=(run_result.stdout + "\n" + run_result.stderr).strip()[:2000],
+             last_reward=final_reward,
+             metrics=metrics,
+             done=done,
+             reward=final_reward,
+         )
+
+     @property
+     def trajectory(self) -> list[dict]:
+         return list(self._trajectory)
+
+     @property
+     def state(self) -> State:
+         return self._state
server/requirements.txt ADDED
@@ -0,0 +1,11 @@
+ openenv-core>=0.1.1
+ fastapi>=0.110.0
+ uvicorn>=0.27.0
+ httpx>=0.27.0
+ smolagents>=1.0.0
+ pydantic>=2.0.0
+ python-dotenv>=1.0.0
+ gradio>=4.0.0
+ tqdm>=4.0.0
+ scikit-learn
+ matplotlib
server/sandbox.py ADDED
@@ -0,0 +1,88 @@
+ import os
+ import subprocess
+ import sys
+ import tempfile
+ import time
+
+ from .tasks.graders import RunResult
+
+ TIMEOUT_SECONDS = 30
+ MAX_OUTPUT_BYTES = 8192
+ BANNED_PATTERNS = [
+     "os.system",
+     "subprocess.",
+     "shutil.rmtree",
+     "open(",
+     "__import__",
+     "exec(",
+     "socket.",
+     "urllib.",
+     "requests.",
+ ]
+ SAFE_ENV = {
+     "PATH": os.environ.get("PATH", "/usr/bin:/usr/local/bin"),
+     "HOME": "/tmp",
+     "PYTHONPATH": os.pathsep.join(sys.path),
+     "PYTHONDONTWRITEBYTECODE": "1",
+ }
+
+
+ def strip_markdown_code(code: str) -> str:
+     if "```python" in code:
+         return code.split("```python", 1)[1].split("```", 1)[0].strip()
+     if "```" in code:
+         return code.split("```", 1)[1].split("```", 1)[0].strip()
+     return code.strip()
+
+
+ def execute_code(code: str) -> RunResult:
+     """Execute agent-submitted code in an isolated subprocess."""
+     cleaned_code = strip_markdown_code(code)
+
+     for pattern in BANNED_PATTERNS:
+         if pattern in cleaned_code:
+             return RunResult(
+                 exit_code=-1,
+                 stdout="",
+                 stderr=f'Execution blocked: banned pattern "{pattern}" detected.',
+                 elapsed_seconds=0.0,
+                 timed_out=False,
+                 fixed_code=cleaned_code,
+             )
+
+     with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, dir="/tmp") as temp_file:
+         temp_file.write(cleaned_code)
+         tmp_path = temp_file.name
+
+     start = time.time()
+     try:
+         proc = subprocess.run(
+             [sys.executable, tmp_path],  # same interpreter as the server, so installed packages resolve
+             capture_output=True,
+             text=True,
+             timeout=TIMEOUT_SECONDS,
+             env=SAFE_ENV,
+             cwd="/tmp",
+         )
+         return RunResult(
+             exit_code=proc.returncode,
+             stdout=proc.stdout[:MAX_OUTPUT_BYTES],
+             stderr=proc.stderr[:2048],
+             elapsed_seconds=round(time.time() - start, 2),
+             timed_out=False,
+             fixed_code=cleaned_code,
+         )
+     except subprocess.TimeoutExpired:
+         return RunResult(
+             exit_code=-1,
+             stdout="",
+             stderr=f"Execution timed out after {TIMEOUT_SECONDS} seconds.",
+             elapsed_seconds=TIMEOUT_SECONDS,
+             timed_out=True,
+             fixed_code=cleaned_code,
+         )
+     finally:
+         try:
+             os.unlink(tmp_path)
+         except Exception:
+             pass
server/start.sh ADDED
@@ -0,0 +1,7 @@
+ #!/bin/bash
+ set -e
+
+ export ENABLE_WEB_INTERFACE=false
+
+ # Single process: FastAPI serves both API and mounted Gradio at /ui
+ uvicorn server.app:app --host 0.0.0.0 --port "${PORT:-7860}"
server/tasks/__init__.py ADDED
@@ -0,0 +1 @@
+ """Task definitions for ML debug environment."""
server/tasks/graders.py ADDED
@@ -0,0 +1,256 @@
+ import math
+ import re
+ import ast
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class RunResult:
+     exit_code: int
+     stdout: str
+     stderr: str
+     elapsed_seconds: float
+     timed_out: bool
+     fixed_code: str = ""
+
+
+ def extract_metrics_block(stdout: str) -> str:
+     match = re.search(r"##METRICS_START##(.*?)##METRICS_END##", stdout, re.DOTALL)
+     if match:
+         return match.group(1)
+     return stdout
+
+
+ def parse_losses(stdout: str) -> list[float]:
+     stdout = extract_metrics_block(stdout)
+     match = re.search(r"LOSSES:\[([^\]]+)\]", stdout)
+     if not match:
+         return []
+     try:
+         return [float(x.strip()) for x in match.group(1).split(",")]
+     except Exception:
+         return []
+
+
+ def parse_val_accs(stdout: str) -> list[float]:
+     stdout = extract_metrics_block(stdout)
+     match = re.search(r"VAL_ACCS:\[([^\]]+)\]", stdout)
+     if not match:
+         return []
+     try:
+         return [float(x.strip()) for x in match.group(1).split(",")]
+     except Exception:
+         return []
+
+
+ def parse_scalar(stdout: str, key: str) -> float | None:
+     stdout = extract_metrics_block(stdout)
+     match = re.search(rf"{re.escape(key)}:([-+]?\d*\.?\d+(?:[eE][-+]?\d+)?)", stdout)  # also accepts 1e-05
+     return float(match.group(1)) if match else None
+
+
+ def is_valid_submission(code: str, stdout: str, exit_code: int) -> tuple[bool, str]:
+     if exit_code == 0:
+         if "LOSSES:" not in stdout and "FINAL_LOSS:" not in stdout:
+             return False, "No valid metrics output detected"
+         if "LOSSES:" in stdout:
+             losses = parse_losses(stdout)
+             if len(losses) < 5:
+                 return False, "Fewer than 5 loss values parsed"
+     try:
+         tree = ast.parse(code)
+         if not any(isinstance(node, (ast.For, ast.While)) for node in ast.walk(tree)):
+             return False, "No ast.For or ast.While node found"
+     except Exception:
+         pass
+     return True, ""
+
+
+ def sigmoid_reward(value: float, center: float, steepness: float, invert: bool = False) -> float:
+     try:
+         if invert:
+             x = steepness * (center - value)  # invert=True: lower values score higher (losses)
+         else:
+             x = steepness * (value - center)  # invert=False: higher values score higher (accuracy, F1)
+         return round(1.0 / (1.0 + math.exp(-x)), 4)
+     except OverflowError:
+         return 0.0 if (invert and value > center) or (not invert and value < center) else 1.0
+
+
+ def grade_task1(result: RunResult) -> tuple[float, dict]:
+     valid, reason = is_valid_submission(result.fixed_code, result.stdout, result.exit_code)
+     if not valid:
+         return 0.0, {"reason": reason}
+
+     if result.timed_out:
+         return 0.05, {"reason": "timed_out"}
+     if result.exit_code != 0:
+         return 0.0, {"reason": "crash"}
+
+     losses = parse_losses(result.stdout)
+     if not losses:
+         return 0.1, {"reason": "no_losses_parsed"}
+     if any(math.isnan(loss) or math.isinf(loss) for loss in losses):
+         return 0.15, {"reason": "nan_inf_found"}
+
+     final = losses[-1]
+     base_score = sigmoid_reward(final, center=0.75, steepness=3.0, invert=True)
+
+     bonus = 0.0
+     half = len(losses) // 2
+     if half > 0:
+         first_half = sum(losses[:half]) / half
+         second_half = sum(losses[half:]) / len(losses[half:])
+         if second_half < 0.85 * first_half:
+             bonus = 0.1
+
+     final_score = min(1.0, base_score + bonus)
+     breakdown = {"base_score": base_score, "monotonicity_bonus": bonus}
+     return final_score, breakdown
+
+
+ def grade_task2(result: RunResult) -> tuple[float, dict]:
+     valid, reason = is_valid_submission(result.fixed_code, result.stdout, result.exit_code)
+     if not valid:
+         return 0.0, {"reason": reason}
+
+     if result.timed_out:
+         return 0.05, {"reason": "timed_out"}
+     if result.exit_code != 0:
+         return 0.0, {"reason": "crash"}
+
+     losses = parse_losses(result.stdout)
+     if not losses or len(losses) < 30:
+         return 0.1, {"reason": "too_few_losses"}
+
+     nan_count = sum(1 for loss in losses if math.isnan(loss) or math.isinf(loss))
+     if nan_count == len(losses):
+         return 0.0, {"reason": "all_nans"}
+
+     nan_ratio = nan_count / len(losses)
+     finite_losses = [loss for loss in losses if not math.isnan(loss) and not math.isinf(loss)]
+     final_finite_loss = finite_losses[-1] if finite_losses else float('inf')
+
+     convergence_score = sigmoid_reward(final_finite_loss, center=0.5, steepness=4.0, invert=True)
+     convergence_score *= (1.0 - nan_ratio)
+
+     stability_bonus = 0.0
+     if len(finite_losses) >= 20:
+         tail = finite_losses[-20:]
+         mean_tail = sum(tail) / len(tail)
+         tail_variance = sum((x - mean_tail) ** 2 for x in tail) / len(tail)
+         stability_bonus = sigmoid_reward(tail_variance, center=0.01, steepness=200.0, invert=True) * 0.1
+
+     final_score = min(1.0, convergence_score + stability_bonus)
+     breakdown = {"convergence_score": convergence_score, "nan_penalty": (1.0 - nan_ratio), "stability_bonus": stability_bonus, "nan_ratio": nan_ratio}
+     return final_score, breakdown
+
+
+ def grade_task3(result: RunResult) -> tuple[float, dict]:
+     valid, reason = is_valid_submission(result.fixed_code, result.stdout, result.exit_code)
+     if not valid:
+         return 0.0, {"reason": reason}
+
+     if result.timed_out:
+         return 0.1, {"reason": "timed_out"}
+
+     if result.exit_code != 0:
+         if "out of memory" in result.stderr.lower():
+             return 0.1, {"reason": "oom"}
+         return 0.0, {"reason": "crash"}
+
+     val_accs = parse_val_accs(result.stdout)
+     final_loss_val = parse_scalar(result.stdout, "FINAL_LOSS")
+
+     memory_score = 0.0
+     if final_loss_val is not None:
+         memory_score = sigmoid_reward(final_loss_val, center=50.0, steepness=0.05, invert=True) * 0.5
+
+     leakage_score = 0.0
+     early_acc = 0.0
+     final_acc = 0.0
+     if val_accs and len(val_accs) >= 2:
+         early_acc = sum(val_accs[:2]) / 2.0
+         final_acc = val_accs[-1]
+
+         leak_p1 = sigmoid_reward(early_acc, center=0.75, steepness=20.0, invert=True) * 0.3
+         leak_p2 = sigmoid_reward(final_acc, center=0.68, steepness=15.0, invert=False) * 0.7
+         leakage_score = (leak_p1 + leak_p2) * 0.5
+
+     final_score = min(1.0, memory_score + leakage_score)
+     breakdown = {"memory_score": memory_score, "leakage_score": leakage_score, "early_acc": early_acc, "final_acc": final_acc}
+     return final_score, breakdown
+
+
+ def grade_task4(result: RunResult) -> tuple[float, dict]:
+     valid, reason = is_valid_submission(result.fixed_code, result.stdout, result.exit_code)
+     if not valid:
+         return 0.0, {"reason": reason}
+
+     if result.timed_out:
+         return 0.1, {"reason": "timed_out"}
+
+     if result.exit_code != 0:
+         return 0.0, {"reason": "crash"}
+
+     final_loss = parse_scalar(result.stdout, "FINAL_LOSS")
+     avg_labels = parse_scalar(result.stdout, "AVG_LABELS")
+     f1 = parse_scalar(result.stdout, "F1_SCORE")
+
+     loss_score = 0.0
+     if final_loss is not None:
+         loss_score = sigmoid_reward(final_loss, center=0.5, steepness=4.0, invert=True) * 0.3
+
+     labels_score = 0.0
+     if avg_labels is not None:
+         labels_score = sigmoid_reward(avg_labels, center=1.0, steepness=5.0, invert=False) * 0.3
+
+     f1_s = 0.0
+     if f1 is not None:
+         f1_s = sigmoid_reward(f1, center=0.6, steepness=10.0, invert=False) * 0.4
+
+     final_score = min(1.0, loss_score + labels_score + f1_s)
+     breakdown = {"loss_score": loss_score, "labels_score": labels_score, "f1_score": f1_s}
+     return final_score, breakdown
+
+
+ def grade_task5(result: RunResult) -> tuple[float, dict]:
+     valid, reason = is_valid_submission(result.fixed_code, result.stdout, result.exit_code)
+     if not valid:
+         return 0.0, {"reason": reason}
+
+     if result.timed_out:
+         return 0.1, {"reason": "timed_out"}
+
+     if result.exit_code != 0:
+         return 0.0, {"reason": "crash"}
+
+     final_loss = parse_scalar(result.stdout, "FINAL_LOSS")
+     grad_norm = parse_scalar(result.stdout, "BACKBONE_GRAD_NORM")
+
+     loss_score = 0.0
+     if final_loss is not None:
+         loss_score = sigmoid_reward(final_loss, center=2.2, steepness=3.0, invert=True) * 0.5
+
+     grad_score = 0.0
+     if grad_norm is not None:
+         grad_score = sigmoid_reward(grad_norm, center=0.001, steepness=1000.0, invert=False) * 0.5
+
+     final_score = min(1.0, loss_score + grad_score)
+     breakdown = {"loss_score": loss_score, "grad_score": grad_score}
+     return final_score, breakdown
+
+
+ def score_task(task_id: str, result: RunResult) -> tuple[float, dict]:
+     graders = {
+         "task1": grade_task1,
+         "task2": grade_task2,
+         "task3": grade_task3,
+         "task4": grade_task4,
+         "task5": grade_task5,
+     }
+     if task_id not in graders:
+         raise ValueError(f"Unknown task_id: {task_id}")
+
+     score, breakdown = graders[task_id](result)
+     return round(max(0.0, min(1.0, score)), 4), breakdown
server/tasks/task1_broken_loop.py ADDED
@@ -0,0 +1,33 @@
+ TASK_DESCRIPTION = """
+ This 2-class linear classifier training loop has bugs preventing convergence.
+ Fix it so that after 50 steps the loss is below 0.75 and decreasing.
+ Model: nn.Linear(10, 2), dataset: random 2-class, 32 samples/batch.
+ Print losses as: LOSSES:[val1, val2, ...]
+ """
+
+ BUGGY_CODE = """
+ import torch
+ import torch.nn as nn
+ torch.manual_seed(0)
+ model = nn.Linear(10, 2)
+ optimizer = torch.optim.Adam(model.parameters(), lr=10.0)  # BUG 1: lr too high
+ criterion = nn.CrossEntropyLoss()
+ losses = []
+ for step in range(50):
+     x = torch.randn(32, 10)
+     y = torch.randint(0, 2, (32,))
+     optimizer.zero_grad()
+     logits = model(x)
+     loss = criterion(logits, y)
+     optimizer.step()  # BUG 2: step before backward
+     loss.backward()  # BUG 3: backward after step
+     losses.append(loss.item())
+ print('##METRICS_START##')
+ print('LOSSES:' + str(losses))
+ print('##METRICS_END##')
+ """
+
+ GROUND_TRUTH_BUGS = [
+     "optimizer.step() called before loss.backward()",
+     "learning rate 10.0 should be ~0.001",
+ ]
server/tasks/task2_nan_loss.py ADDED
@@ -0,0 +1,32 @@
+ TASK_DESCRIPTION = """
+ This binary regression trainer produces NaN loss around step 15.
+ Fix the numerical instability so loss stays finite for all 60 steps
+ and the final loss is below 0.5.
+ Print losses as: LOSSES:[val1, val2, ...]
+ """
+
+ BUGGY_CODE = """
+ import torch
+ import torch.nn as nn
+ torch.manual_seed(42)
+ model = nn.Linear(16, 1)
+ optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+ losses = []
+ for step in range(60):
+     x = torch.randn(64, 16)
+     y = torch.rand(64, 1)
+     optimizer.zero_grad()
+     pred = torch.sigmoid(model(x))
+     # BUG: log(pred) can be -inf when pred rounds to 0.0
+     loss = -torch.mean(y * torch.log(pred) + (1 - y) * torch.log(1 - pred))
+     loss.backward()
+     optimizer.step()
+     losses.append(loss.item())
+ print('##METRICS_START##')
+ print('LOSSES:' + str(losses))
+ print('##METRICS_END##')
+ """
+
+ GROUND_TRUTH_BUGS = [
+     "torch.log(pred) when pred can be 0.0 after sigmoid — use F.binary_cross_entropy or clamp",
+ ]
server/tasks/task3_oom_leakage.py ADDED
@@ -0,0 +1,50 @@
+ TASK_DESCRIPTION = """
+ This trainer has TWO independent bugs:
+ 1. Data leakage inflating validation accuracy.
+ 2. A memory leak causing OOM crash before epoch 5 on CPU.
+ Fix both. After 20 epochs: val_acc > 0.70, no OOM, no suspicious early accuracy spike.
+ Print as: VAL_ACCS:[v1,v2,...] and FINAL_LOSS:X.XX
+ """
+
+ BUGGY_CODE = """
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import DataLoader, TensorDataset, random_split
+
+ torch.manual_seed(42)
+ X = torch.randn(1000, 20)
+ y = (X[:, 0] > 0).float()
+ # BUG 1: augmentation before split, so the val set gets augmented
+ X = X + torch.randn_like(X) * 0.1
+ train_ds, val_ds = random_split(TensorDataset(X, y), [800, 200])
+ model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 1))
+ optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+ criterion = nn.BCEWithLogitsLoss()
+ train_losses, val_accs = [], []
+ total_loss = torch.tensor(0.0)  # BUG 2: keeps computation graph alive
+ for epoch in range(20):
+     model.train()
+     for xb, yb in DataLoader(train_ds, batch_size=32):
+         optimizer.zero_grad()
+         out = model(xb).squeeze()
+         loss = criterion(out, yb)
+         loss.backward()
+         optimizer.step()
+         total_loss = total_loss + loss  # BUG 2: graph accumulates
+     model.eval()
+     with torch.no_grad():
+         idx = val_ds.indices
+         xv, yv = X[idx], y[idx]
+         preds = (torch.sigmoid(model(xv)).squeeze(1) > 0.5).float()
+         acc = (preds == yv).float().mean().item()
+     val_accs.append(round(acc, 4))
+ print('##METRICS_START##')
+ print('VAL_ACCS:' + str(val_accs))
+ print('FINAL_LOSS:' + str(total_loss.item()))
+ print('##METRICS_END##')
+ """
+
+ GROUND_TRUTH_BUGS = [
+     "Augmentation applied before split; move after split, apply to train only",
+     "total_loss += loss retains graph; use total_loss += loss.item()",
+ ]
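Review note on task3_oom_leakage.py: a minimal sketch of the two fixes listed in GROUND_TRUTH_BUGS, assuming the same synthetic data as the diff. It runs a single pass over the training rows rather than 20 epochs, slices batches by hand instead of using DataLoader, and the names X_train, X_val and val_acc are introduced here for illustration. It is one possible shape of a fix, not the grader's reference solution.

```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, random_split

torch.manual_seed(42)
X = torch.randn(1000, 20)
y = (X[:, 0] > 0).float()

# Fix 1: split on the clean data first, then augment the training rows only.
train_ds, val_ds = random_split(TensorDataset(X, y), [800, 200])
train_idx, val_idx = train_ds.indices, val_ds.indices
X_train = X[train_idx] + torch.randn(len(train_idx), 20) * 0.1
y_train = y[train_idx]
X_val, y_val = X[val_idx], y[val_idx]  # validation rows stay untouched

model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 1))
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.BCEWithLogitsLoss()

# Fix 2: accumulate a plain Python float so no autograd graph is retained.
total_loss = 0.0
model.train()
for start in range(0, len(X_train), 32):
    xb, yb = X_train[start:start + 32], y_train[start:start + 32]
    optimizer.zero_grad()
    loss = criterion(model(xb).squeeze(1), yb)
    loss.backward()
    optimizer.step()
    total_loss += loss.item()  # .item() copies the value out of the graph

model.eval()
with torch.no_grad():
    # squeeze(1) keeps preds at shape (200,), matching y_val elementwise
    preds = (torch.sigmoid(model(X_val)).squeeze(1) > 0.5).float()
    val_acc = (preds == y_val).float().mean().item()

print('VAL_ACC:' + str(round(val_acc, 4)))
print('FINAL_LOSS:' + str(total_loss))
```

Accumulating `loss.item()` rather than `loss` means each batch's graph can be freed as soon as `backward()` and `step()` finish, which is the behaviour the second ground-truth entry asks for.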
server/tasks/task4_wrong_loss.py ADDED
@@ -0,0 +1,62 @@
+ TASK_DESCRIPTION = """
+ This is a multi-label classification problem where each sample can have multiple active classes.
+ However, the model is currently using `CrossEntropyLoss`, which is meant for single-label (mutually exclusive) classes.
+ Because of this, the loss trains but the predictions are essentially garbage (treating it as a single-label problem).
+
+ Fix the loss function so it correctly handles multi-label classification.
+ The grader will check:
+ 1. Loss convergence (loss < 0.5)
+ 2. Model predictions are actually multi-hot (avg > 1 label/sample)
+ 3. F1 Score > 0.6
+ """
+
+ BUGGY_CODE = """
+ import torch
+ import torch.nn as nn
+ from sklearn.metrics import f1_score
+
+ torch.manual_seed(42)
+
+ # Generate synthetic multi-label data (100 samples, 20 features, 5 classes)
+ X = torch.randn(100, 20)
+ # Each sample has a 30% chance of having each class active
+ y = (torch.rand(100, 5) > 0.7).float()
+
+ model = nn.Sequential(
+     nn.Linear(20, 64),
+     nn.ReLU(),
+     nn.Linear(64, 5)
+ )
+
+ optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+
+ # BUG: CrossEntropyLoss is for single-label classification
+ criterion = nn.CrossEntropyLoss()
+
+ losses = []
+ for step in range(100):
+     optimizer.zero_grad()
+     logits = model(X)
+
+     # CrossEntropyLoss reads a float target of shape (N, C) as softmax class probabilities, not independent multi-hot labels
+     loss = criterion(logits, y)
+
+     loss.backward()
+     optimizer.step()
+     losses.append(loss.item())
+
+ # Evaluation
+ with torch.no_grad():
+     logits = model(X)
+     # Using sigmoid and 0.5 threshold for multi-label prediction
+     preds = (torch.sigmoid(logits) > 0.5).float()
+
+ avg_labels = preds.sum(dim=1).mean().item()
+ f1 = f1_score(y.numpy(), preds.numpy(), average='micro')
+
+ print('##METRICS_START##')
+ print('FINAL_LOSS:' + str(losses[-1]))
+ print('AVG_LABELS:' + str(avg_labels))
+ print('F1_SCORE:' + str(f1))
+ print('##METRICS_END##')
+ """
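Review note on task4_wrong_loss.py: a minimal sketch of the kind of change the task description asks for, assuming the same synthetic data and model as the diff. It swaps in `nn.BCEWithLogitsLoss`, which scores each of the 5 outputs as an independent binary decision and therefore matches the sigmoid-plus-threshold evaluation already present in the file. The sklearn F1 computation is left out to keep the sketch dependency-free, and no claim is made here about which grader thresholds the result meets.

```python
import torch
import torch.nn as nn

torch.manual_seed(42)

# Same synthetic multi-label data as the diff: 100 samples, 20 features, 5 classes.
X = torch.randn(100, 20)
y = (torch.rand(100, 5) > 0.7).float()

model = nn.Sequential(
    nn.Linear(20, 64),
    nn.ReLU(),
    nn.Linear(64, 5),
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# One sigmoid/binary cross-entropy term per class, so labels are not mutually exclusive.
criterion = nn.BCEWithLogitsLoss()

losses = []
for step in range(100):
    optimizer.zero_grad()
    loss = criterion(model(X), y)  # raw logits and float multi-hot targets
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

with torch.no_grad():
    preds = (torch.sigmoid(model(X)) > 0.5).float()

avg_labels = preds.sum(dim=1).mean().item()
print('FINAL_LOSS:' + str(losses[-1]))
print('AVG_LABELS:' + str(avg_labels))
```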
server/tasks/task5_frozen_backbone.py ADDED
@@ -0,0 +1,62 @@
+ TASK_DESCRIPTION = """
+ This is a standard transfer learning setup classifying 10 categories.
+ The developer froze the backbone during testing, but forgot to unfreeze it while still passing its parameters to the optimizer.
+ Fix the code so the backbone actually trains, or only pass the head parameters.
+ The grader checks the gradient norm of the backbone from the first backward pass.
+ """
+
+ BUGGY_CODE = """
+ import torch
+ import torch.nn as nn
+
+ torch.manual_seed(42)
+
+ # Dummy dataset
+ X = torch.randn(32, 512)
+ y = torch.randint(0, 10, (32,))
+
+ # A simulated pre-trained backbone
+ backbone = nn.Sequential(
+     nn.Linear(512, 512),
+     nn.ReLU(),
+     nn.Linear(512, 512),
+     nn.ReLU()
+ )
+
+ # BUG: backbone is frozen, but passed to optimizer
+ backbone.requires_grad_(False)
+
+ head = nn.Linear(512, 10)
+
+ # passing both backbone and head to optimizer even though backbone is frozen
+ optimizer = torch.optim.Adam(
+     list(backbone.parameters()) + list(head.parameters()), lr=0.001
+ )
+ criterion = nn.CrossEntropyLoss()
+
+ losses = []
+
+ # Take one step to check gradients
+ optimizer.zero_grad()
+ features = backbone(X)
+ logits = head(features)
+
+ loss = criterion(logits, y)
+ loss.backward()
+
+ # Calculate gradient norm on backbone to see if it's training
+ backbone_grad_norm = sum(
+     p.grad.norm().item() for p in backbone.parameters() if p.grad is not None
+ )
+
+ optimizer.step()
+ losses.append(loss.item())
+
+ # Note: if backbone is properly frozen and only head is passed, backbone_grad_norm will be 0 but optimizer won't complain.
+ # If backbone is unfrozen, backbone_grad_norm will be > 0.
+ # The grader handles both valid solutions.
+ print('##METRICS_START##')
+ print('FINAL_LOSS:' + str(losses[-1]))
+ print('BACKBONE_GRAD_NORM:' + str(backbone_grad_norm))
+ print('##METRICS_END##')
+ """
+ """
uv.lock ADDED
The diff for this file is too large to render. See raw diff