File size: 13,997 Bytes
72ddcb6
 
 
 
 
ffd85e1
7ac53ec
72ddcb6
 
 
ffd85e1
72ddcb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c735c6
2bf44ab
 
7c735c6
72ddcb6
 
 
 
 
 
 
43923f4
f8f316e
72ddcb6
 
 
 
 
 
 
 
 
43923f4
72ddcb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89c9429
72ddcb6
 
 
 
 
 
ffd85e1
 
 
 
 
 
 
 
 
 
 
 
 
 
72ddcb6
 
 
 
 
 
ffd85e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72ddcb6
 
 
 
 
 
 
 
 
ffd85e1
 
 
 
 
72ddcb6
 
 
 
1d1a2bf
72ddcb6
 
0c28a91
72ddcb6
 
ffd85e1
 
 
 
 
 
 
 
72ddcb6
ffd85e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72ddcb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ffd85e1
72ddcb6
ffd85e1
72ddcb6
 
0c28a91
ffd85e1
 
0c28a91
 
 
 
 
 
72ddcb6
 
0c28a91
ffd85e1
 
 
72ddcb6
0c28a91
 
ffd85e1
0c28a91
72ddcb6
 
 
7c735c6
 
ffd85e1
72ddcb6
7c735c6
 
72ddcb6
 
7c735c6
 
72ddcb6
 
0c28a91
 
 
 
 
ffd85e1
 
0c28a91
72ddcb6
 
 
 
 
 
ffd85e1
 
 
 
 
72ddcb6
ffd85e1
 
 
 
 
72ddcb6
 
0c28a91
ffd85e1
 
0c28a91
 
 
7c735c6
0c28a91
 
 
 
 
ffd85e1
 
 
 
72ddcb6
0c28a91
ffd85e1
 
0c28a91
72ddcb6
 
 
 
0c28a91
ffd85e1
 
72ddcb6
 
 
ffd85e1
72ddcb6
 
7c735c6
72ddcb6
7c735c6
72ddcb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89c9429
 
 
 
 
 
 
 
 
72ddcb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89c9429
72ddcb6
 
 
 
 
 
89c9429
72ddcb6
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
import asyncio
import os
import sys

from dotenv import load_dotenv

load_dotenv()

import httpx
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, RedirectResponse
from openenv.core.env_server.http_server import create_app

_project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if _project_root not in sys.path:
    sys.path.insert(0, _project_root)

try:
    from ..models import MLDebugAction, MLDebugObservation
    from .environment import MLDebugEnvironment
    from .tasks.graders import RunResult, score_task
except ImportError:
    from models import MLDebugAction, MLDebugObservation
    from server.environment import MLDebugEnvironment
    from server.tasks.graders import RunResult, score_task

# Score bounds - the hackathon requires scores strictly inside the open
# interval (0, 1), not the closed interval [0, 1], so neither bound is
# exactly 0 or 1.
# MIN_SCORE is the score the baseline endpoints report for a task that
# errored or timed out.
# NOTE(review): the lower bound (0.1) is far less tight than the upper bound
# (0.9999) - confirm whether 0.0001 was intended.
MIN_SCORE = 0.1
MAX_SCORE = 0.9999

# Disable OpenEnv's default web UI so /web can mirror the custom Gradio UI.
# The variable is set before create_app() is called below.
# NOTE(review): presumably create_app() reads ENABLE_WEB_INTERFACE at call
# time, which is why the assignment comes first - confirm against OpenEnv.
os.environ["ENABLE_WEB_INTERFACE"] = "false"

# Base FastAPI application produced by OpenEnv for the ML-debug environment;
# the routes defined in the rest of this module are added on top of it.
app: FastAPI = create_app(
    MLDebugEnvironment,
    MLDebugAction,
    MLDebugObservation,
    env_name="whipstudio",
    max_concurrent_envs=64,
)


@app.get("/__build", include_in_schema=False)
def build_info():
    """Report a runtime fingerprint for checking which build is deployed.

    Returns:
        A dict with the environment name, the interpreter version, the
        platform string, and the raw ``PORT`` / ``ENABLE_WEB_INTERFACE``
        environment variables (``None`` when unset).
    """
    from platform import platform as platform_string, python_version

    env = os.environ
    return dict(
        env_name="whipstudio",
        python=python_version(),
        platform=platform_string(),
        port=env.get("PORT"),
        enable_web_interface=env.get("ENABLE_WEB_INTERFACE"),
    )


def _has_route(path: str, method: str) -> bool:
    """Tell whether ``app`` already has a route for ``path`` and ``method``.

    Args:
        path: Exact route path to look for (compared with ``==``).
        method: HTTP method name; matched case-insensitively.

    Returns:
        True if any registered route has this path and lists the method,
        otherwise False. Routes without a ``methods`` attribute (or with an
        empty one) never match.
    """
    wanted = method.upper()
    return any(
        wanted in (getattr(candidate, "methods", None) or ())
        for candidate in app.router.routes
        if getattr(candidate, "path", None) == path
    )


@app.get("/", include_in_schema=False)
def root_redirect():
    """Redirect the site root to the UI at ``/ui/`` with HTTP 307."""
    ui_home = "/ui/"
    return RedirectResponse(url=ui_home, status_code=307)


# Only add GET /health when the base app has not already registered one.
if not _has_route("/health", "GET"):

    @app.get("/health", include_in_schema=False)
    def health_get():
        """Return a health payload including the uptime from the metrics object.

        ``uptime_seconds`` falls back to 0 when the metrics snapshot has no
        such key.
        """
        try:
            from .logging_config import get_metrics
        except ImportError:
            from server.logging_config import get_metrics

        snapshot = get_metrics().get_metrics()
        uptime = snapshot.get("uptime_seconds", 0)
        return {
            "status": "ok",
            "version": "1.1.0",
            "tasks_available": 6,
            "tools_available": 6,
            "uptime_seconds": uptime,
            "ready": True,
        }


# Only add POST /health when the base app has not already registered one.
if not _has_route("/health", "POST"):

    @app.post("/health", include_in_schema=False)
    def health_post():
        """Return a minimal status/version payload for POST health checks."""
        return dict(status="ok", version="1.1.0")


@app.get("/metrics")
def get_runtime_metrics():
    """Return runtime metrics for monitoring.

    The payload is whatever ``get_metrics().get_metrics()`` from the
    ``logging_config`` module produces.
    """
    try:
        from .logging_config import get_metrics
    except ImportError:
        from server.logging_config import get_metrics

    collector = get_metrics()
    return collector.get_metrics()


@app.get("/session/state")
def get_session_state(episode_id: str = ""):
    """Get state for a specific session by episode_id.

    Args:
        episode_id: Query parameter identifying the episode to look up.

    Returns:
        A dict describing the session (step count, submitted flag, best
        reward, turn budget), or a dict with an ``error`` key when
        ``episode_id`` is missing or no session is found for it.
    """
    try:
        from .environment import MAX_TURNS_PER_EPISODE, _get_session
    except ImportError:
        from server.environment import MAX_TURNS_PER_EPISODE, _get_session

    if not episode_id:
        return {
            "error": "episode_id query parameter required",
            "usage": "/session/state?episode_id=<your-episode-id>",
        }

    found = _get_session(episode_id)
    if found is None:
        return {
            "error": f"No session found for episode_id '{episode_id}'",
            "hint": "Call POST /reset first to start an episode",
        }

    steps_taken = found.step_count
    is_submitted = found.submitted
    remaining = max(0, MAX_TURNS_PER_EPISODE - steps_taken)
    return {
        "episode_id": found.episode_id,
        "task_id": found.task_id,
        "step": steps_taken,
        "turn": steps_taken,  # Alias for compatibility
        "submitted": is_submitted,
        "done": is_submitted,  # Alias for compatibility
        "best_reward": found.best_reward,
        "max_turns": MAX_TURNS_PER_EPISODE,
        "turns_remaining": remaining,
    }


@app.get("/reset")
def reset_liveness():
    """Answer GET /reset with a hint that episodes are started via POST."""
    hint = "use POST /reset to start an episode"
    return {"status": "ok", "message": hint}


@app.get("/tasks")
def list_tasks():
    """List the available debugging tasks and describe the action payload.

    Returns:
        A dict with:
          * ``tasks`` - id, display name and difficulty of each task,
          * ``action_schema`` - human-readable description of each action
            field,
          * ``tools`` - the ``name`` of every entry in ``TOOL_DEFINITIONS``,
          * ``max_turns_per_episode`` - the per-episode turn limit.
    """
    try:
        from .environment import MAX_TURNS_PER_EPISODE, TOOL_DEFINITIONS
    except ImportError:
        from server.environment import MAX_TURNS_PER_EPISODE, TOOL_DEFINITIONS

    return {
        "tasks": [
            {"id": "task1", "name": "Broken training loop", "difficulty": "easy"},
            {"id": "task2", "name": "Silent NaN loss", "difficulty": "medium"},
            {"id": "task3", "name": "Label inversion", "difficulty": "medium"},
            {"id": "task4", "name": "Wrong loss function", "difficulty": "medium"},
            {"id": "task5", "name": "Frozen backbone", "difficulty": "medium"},
            {"id": "task6", "name": "Input-Output mismatch", "difficulty": "hard"},
        ],
        # The separators below are real em dashes; the previous text carried a
        # mis-decoded byte sequence that rendered as garbage to API clients.
        "action_schema": {
            "action_type": "string — one of: submit_fix, execute_snippet, inspect_tensor, run_training_probe, get_variable_state, inspect_diff",
            "fixed_code": "string — for submit_fix: complete runnable Python script",
            "code": "string — for execute_snippet/run_training_probe: Python code to run",
            "setup_code": "string — for inspect_tensor/get_variable_state: setup code",
            "target_expression": "string — for inspect_tensor: expression to inspect",
            "expressions": "list[string] — for get_variable_state: expressions to evaluate",
            "proposed_code": "string — for inspect_diff: proposed fix to diff",
            "steps": "int 1-10 — for run_training_probe: number of steps",
        },
        "tools": [tool["name"] for tool in TOOL_DEFINITIONS],
        # Report the environment's own limit (as /tools and /session/state do)
        # rather than a hard-coded copy that can drift out of sync.
        "max_turns_per_episode": MAX_TURNS_PER_EPISODE,
    }


@app.get("/tools")
def list_tools():
    """Return available tools and their schemas for agent system prompts.

    The response carries the full ``TOOL_DEFINITIONS``, the per-episode turn
    limit, and a short usage note with an example action.
    """
    try:
        from .environment import MAX_TURNS_PER_EPISODE, TOOL_DEFINITIONS
    except ImportError:
        from server.environment import MAX_TURNS_PER_EPISODE, TOOL_DEFINITIONS

    sample_action = {
        "action_type": "execute_snippet",
        "code": "print('hello world')",
    }
    usage_notes = {
        "description": "Call step() with action_type set to the tool name. Tools return observations with episode_done=False. submit_fix is the terminal action.",
        "example": {"action": sample_action},
    }
    return {
        "tools": TOOL_DEFINITIONS,
        "max_turns_per_episode": MAX_TURNS_PER_EPISODE,
        "usage": usage_notes,
    }


@app.post("/grader")
def run_grader(payload: dict):
    """Score a reported run for a task.

    Args:
        payload: Request body. ``task_id`` defaults to ``"task1"``; the
            remaining keys (``exit_code``, ``stdout``, ``stderr``,
            ``elapsed``, ``timed_out``, ``fixed_code``) fill a ``RunResult``
            and fall back to defaults when absent.

    Returns:
        A dict with the task id plus the score and breakdown returned by
        ``score_task``.
    """
    field = payload.get
    task_id = field("task_id", "task1")
    run = RunResult(
        exit_code=field("exit_code", -1),
        stdout=field("stdout", ""),
        stderr=field("stderr", ""),
        elapsed_seconds=field("elapsed", 0.0),
        timed_out=field("timed_out", False),
        fixed_code=field("fixed_code", ""),
    )
    graded = score_task(task_id, run)
    score, breakdown = graded
    return {"task_id": task_id, "score": score, "breakdown": breakdown}


@app.get("/baseline")
async def run_baseline(request: Request):
    """Run the baseline agent on task1..task6 in turn and report the scores.

    Query parameters:
        model_id: Model to use; must be one of ``SUPPORTED_MODEL_IDS``.
        use_tools: ``"true"`` (default, case-insensitive) selects the longer
            per-task timeout of 180 s instead of 120 s.

    Returns:
        A dict with per-task scores (plus ``<task_id>_error`` entries for
        tasks that failed), their average, and the request settings. A task
        that times out or raises is scored ``MIN_SCORE``. An unsupported
        ``model_id`` yields a dict with an ``error`` key instead.
    """
    try:
        from ..baseline_agent import SUPPORTED_MODEL_IDS, TASK_CONFIG, run_single_task
    except ImportError:
        from baseline_agent import SUPPORTED_MODEL_IDS, TASK_CONFIG, run_single_task

    env_url = str(request.base_url).rstrip("/")
    model_id = request.query_params.get("model_id", "Qwen/Qwen2.5-Coder-32B-Instruct")
    use_tools = request.query_params.get("use_tools", "true").lower() == "true"

    if model_id not in SUPPORTED_MODEL_IDS:
        return {
            "error": f"Unsupported model_id '{model_id}'",
            "supported_model_ids": SUPPORTED_MODEL_IDS,
        }

    # Increase timeout for tool-using agent (more turns = more time).
    # The value does not depend on the task, so compute it once.
    timeout_secs = 180.0 if use_tools else 120.0

    results = {}      # scores plus "<task_id>_error" messages (returned)
    task_scores = {}  # scores only (used for the average)
    for task_id in ["task1", "task2", "task3", "task4", "task5", "task6"]:
        error = None
        try:
            # NOTE(review): use_tools is not forwarded here, unlike the
            # single-task endpoint - confirm whether run_single_task
            # accepts it.
            score = round(
                await asyncio.wait_for(
                    run_single_task(task_id, env_url, model_id=model_id),
                    timeout=timeout_secs,
                ),
                4,
            )
        except (asyncio.TimeoutError, TimeoutError):
            # Before Python 3.11 asyncio.TimeoutError is a separate class
            # from the builtin TimeoutError, so both must be named for the
            # wait_for() timeout to be reported as a timeout.
            error = f"timeout: task took longer than {timeout_secs}s"
        except httpx.HTTPError as exc:
            error = f"http_error: {exc.__class__.__name__}: {exc}"
        except Exception as exc:
            # Deliberately broad: one failing task must not abort the run.
            error = f"internal_error: {exc.__class__.__name__}: {exc}"

        if error is not None:
            score = MIN_SCORE
        results[task_id] = score
        task_scores[task_id] = score
        if error is not None:
            results[f"{task_id}_error"] = error

    avg = round(sum(task_scores.values()) / max(1, len(task_scores)), 4)
    return {
        "baseline_scores": results,
        "average": avg,
        "env_url": env_url,
        "model_id": model_id,
        "use_tools": use_tools,
        "task_config": TASK_CONFIG,
    }


@app.get("/baseline/task/{task_id}")
async def run_baseline_single(task_id: str, request: Request):
    """Run the baseline agent on a single task. Returns score + details.

    Args:
        task_id: Task identifier taken from the URL path.
        request: Incoming request. Query parameters ``model_id`` (must be in
            ``SUPPORTED_MODEL_IDS``) and ``use_tools`` (``"true"`` by
            default, case-insensitive) are read from it.

    Returns:
        On success, a dict with ``status`` ``"ok"``, the rounded score and
        the agent's details (fixed code, output, attempts, tool history).
        On failure, a dict with ``score`` set to ``MIN_SCORE`` and
        ``status`` ``"timeout"`` or ``"error"``.
    """
    try:
        from ..baseline_agent import (
            SUPPORTED_MODEL_IDS,
            TASK_CONFIG,
            run_single_task_detailed,
        )
    except ImportError:
        from baseline_agent import (
            SUPPORTED_MODEL_IDS,
            TASK_CONFIG,
            run_single_task_detailed,
        )

    env_url = str(request.base_url).rstrip("/")
    model_id = request.query_params.get("model_id", "Qwen/Qwen2.5-Coder-32B-Instruct")
    use_tools = request.query_params.get("use_tools", "true").lower() == "true"

    if model_id not in SUPPORTED_MODEL_IDS:
        return {
            "task_id": task_id,
            "score": MIN_SCORE,
            "status": "error",
            "error": f"Unsupported model_id '{model_id}'",
            "supported_model_ids": SUPPORTED_MODEL_IDS,
        }

    task_cfg = TASK_CONFIG.get(task_id, {"max_turns": 10})
    # Longer timeout for tool-using agent
    timeout_secs = 180.0 if use_tools else 120.0

    try:
        result = await asyncio.wait_for(
            run_single_task_detailed(task_id, env_url, model_id=model_id, use_tools=use_tools),
            timeout=timeout_secs,
        )
        return {
            "task_id": task_id,
            "score": round(result["score"], 4),
            "status": "ok",
            "model_id": model_id,
            "use_tools": use_tools,
            "max_turns": task_cfg.get("max_turns", 10),
            "fixed_code": result.get("fixed_code", ""),
            "output": result.get("output", ""),
            "attempts": result.get("attempts", []),
            "tool_history": result.get("tool_history", []),
        }
    except (asyncio.TimeoutError, TimeoutError):
        # Before Python 3.11 asyncio.TimeoutError is a separate class from
        # the builtin TimeoutError; naming both keeps the "timeout" status
        # from degrading to a generic "error" on older interpreters.
        return {"task_id": task_id, "score": MIN_SCORE, "status": "timeout", "error": f"Task took longer than {timeout_secs}s"}
    except Exception as exc:
        # Deliberately broad: any agent failure becomes an error payload.
        return {"task_id": task_id, "score": MIN_SCORE, "status": "error", "error": f"{exc.__class__.__name__}: {exc}"}


@app.get("/baseline/health")
def baseline_health():
    """Report whether the baseline agent's prerequisites are in place.

    Checks that ``HF_TOKEN`` is set to a non-empty value and that
    ``get_model()`` from the baseline agent module can be imported and
    called without raising.

    Returns:
        A dict whose ``status`` is ``"ok"`` only when both checks pass and
        ``"degraded"`` otherwise; ``model_error`` holds the exception text
        when the model check failed, else ``None``.
    """
    token_found = bool(os.environ.get("HF_TOKEN"))

    try:
        try:
            from ..baseline_agent import get_model
        except ImportError:
            from baseline_agent import get_model

        get_model()
    except Exception as exc:
        model_ok = False
        failure = f"{exc.__class__.__name__}: {exc}"
    else:
        model_ok = True
        failure = None

    return {
        "status": "ok" if token_found and model_ok else "degraded",
        "hf_token_present": token_found,
        "model_ready": model_ok,
        "model_error": failure,
    }


# Set to True only once the Gradio UI has been mounted at /ui further down;
# while it stays False a fallback error page is registered instead.
_ui_mounted = False


@app.get("/ui", include_in_schema=False)
def ui_trailing_slash_redirect():
    """Redirect ``/ui`` to ``/ui/`` (HTTP 307).

    Gradio's HTML references assets as `./assets/...`. Without the trailing
    slash, browsers resolve those to `/assets/...` (breaking the UI).
    """
    with_slash = "/ui/"
    return RedirectResponse(url=with_slash, status_code=307)


# Try to build the Gradio UI and mount it under /ui. Any failure (gradio
# missing, the UI module not importable, build_ui() raising, mounting
# raising) is caught so the API itself still starts.
try:
    import gradio as gr
    try:
        # Package-relative import first; fall back to a top-level import.
        from ..gradio_app import build_ui
    except ImportError:
        from gradio_app import build_ui

    gradio_ui = build_ui()
    # mount_gradio_app returns the application object, so `app` is rebound.
    app = gr.mount_gradio_app(app, gradio_ui, path="/ui")
    _ui_mounted = True
except Exception as e:
    # Don't fail silently in Spaces: return a helpful error page at /ui.
    # _ui_mounted stays False here; the reason is printed with a traceback.
    import traceback

    print(f"Failed to mount Gradio UI: {e}")
    traceback.print_exc()


# Fallback page, registered only when the Gradio UI was not mounted.
if not _ui_mounted:
    @app.get("/ui", include_in_schema=False)
    @app.get("/ui/", include_in_schema=False)
    def ui_mount_failed():
        """Return an HTTP 500 HTML page explaining the UI could not start."""
        page = "".join(
            (
                "<h2>WhipStudio UI failed to start</h2>",
                "<p>The API server is running, but the Gradio UI could not be mounted.</p>",
                "<p>Check container logs for <code>Failed to mount Gradio UI</code>.</p>",
            )
        )
        return HTMLResponse(content=page, status_code=500)


@app.api_route("/web", methods=["GET", "POST"], include_in_schema=False)
def web_redirect_root():
    """Redirect GET/POST requests for ``/web`` to ``/ui/`` (HTTP 307)."""
    destination = "/ui/"
    return RedirectResponse(url=destination, status_code=307)


@app.api_route("/web/{path:path}", methods=["GET", "POST"], include_in_schema=False)
def web_redirect_path(path: str):
    """Redirect ``/web/<path>`` to ``/ui/<path>`` (HTTP 307).

    An empty ``path`` redirects to ``/ui/``.
    """
    destination = f"/ui/{path}" if path else "/ui/"
    return RedirectResponse(url=destination, status_code=307)


def main(host: str = "0.0.0.0", port: int = 7860):
    """Serve ``server.app:app`` with uvicorn, auto-reload disabled.

    Args:
        host: Interface to bind to.
        port: TCP port to listen on.
    """
    from uvicorn import run as serve

    serve("server.app:app", host=host, port=port, reload=False)


# Start the server when the module is executed as a script.
if __name__ == "__main__":
    main()