File size: 3,797 Bytes
ced8fd0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
efeaffa
 
ced8fd0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
efeaffa
 
 
ced8fd0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""
FastAPI server exposing the CodeReviewEnv as an HTTP API.
Endpoints mirror the OpenEnv interface: /reset, /step, /state, /tasks.
"""
from __future__ import annotations

import sys
import os
sys.path.insert(0, os.path.dirname(__file__))

from typing import Any, Dict
import uuid

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from env import CodeReviewEnv, TASK_IDS
from models import ReviewAction, Observation, StepReward, EnvironmentState

app = FastAPI(
    title="CodeReviewEnv",
    description="OpenEnv-compliant environment for AI code review agents.",
    version="1.0.0",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# ── Session store (in-memory, single process) ────────────────────────────────
_sessions: Dict[str, CodeReviewEnv] = {}


# ── Request / Response models ─────────────────────────────────────────────────

class ResetRequest(BaseModel):
    task_id: str = "task_1_easy_bug_hunt"
    session_id: str | None = None

    model_config = {"extra": "allow"}


class ResetResponse(BaseModel):
    session_id: str
    observation: Observation


class StepRequest(BaseModel):
    session_id: str
    action: ReviewAction


class StepResponse(BaseModel):
    observation: Observation
    reward: StepReward
    done: bool
    info: Dict[str, Any]


# ── Routes ────────────────────────────────────────────────────────────────────

@app.get("/")
def root():
    return {
        "name": "CodeReviewEnv",
        "version": "1.0.0",
        "tasks": TASK_IDS,
        "spec": "OpenEnv v1",
    }


@app.get("/tasks")
def list_tasks():
    from tasks.task1_easy import get_task_config as t1
    from tasks.task2_medium import get_task_config as t2
    from tasks.task3_hard import get_task_config as t3
    tasks = []
    for fn in (t1, t2, t3):
        cfg = fn()
        tasks.append({
            "task_id": cfg["task_id"],
            "difficulty": cfg["difficulty"],
            "description": cfg["description"],
        })
    return {"tasks": tasks}

@app.post("/reset", response_model=ResetResponse)
def reset(req: ResetRequest | None = None):
    if req is None:
        req = ResetRequest()
    if req.task_id not in TASK_IDS:
        raise HTTPException(400, f"Unknown task_id {req.task_id!r}. Choose from {TASK_IDS}")
    session_id = req.session_id or str(uuid.uuid4())
    env = CodeReviewEnv()
    obs = env.reset(req.task_id)
    _sessions[session_id] = env
    return ResetResponse(session_id=session_id, observation=obs)


@app.post("/step", response_model=StepResponse)
def step(req: StepRequest):
    env = _sessions.get(req.session_id)
    if env is None:
        raise HTTPException(404, f"Session {req.session_id!r} not found. Call /reset first.")
    obs, reward, done, info = env.step(req.action)
    return StepResponse(observation=obs, reward=reward, done=done, info=info)


@app.get("/state/{session_id}", response_model=EnvironmentState)
def get_state(session_id: str):
    env = _sessions.get(session_id)
    if env is None:
        raise HTTPException(404, f"Session {session_id!r} not found.")
    return env.state()


@app.delete("/session/{session_id}")
def delete_session(session_id: str):
    _sessions.pop(session_id, None)
    return {"deleted": session_id}


if __name__ == "__main__":
    import uvicorn
    uvicorn.run("server:app", host="0.0.0.0", port=7860, reload=False)