| from __future__ import annotations |
|
|
| import os |
| from pathlib import Path |
| from typing import Any, Optional |
|
|
| import uvicorn |
| from fastapi import Body, FastAPI, HTTPException |
| from pydantic import BaseModel, Field, ValidationError |
|
|
| from openenv.core.env_server.http_server import ( |
| HealthResponse, |
| JsonRpcErrorCode, |
| JsonRpcResponse, |
| ResetRequest, |
| ResetResponse, |
| SchemaResponse, |
| StepRequest, |
| StepResponse, |
| ) |
|
|
| from swarm_openenv_env.environment import IncidentResponseEnv |
| from swarm_openenv_env.models import ( |
| IncidentAction, |
| IncidentObservation, |
| IncidentState, |
| ) |
| from swarm_openenv_env.tasks import get_task |
|
|
| from inference import ( |
| ALLOW_SCRIPTED_BASELINE, |
| MODEL_NAME, |
| TRAINED_GGUF_PATH, |
| LOCAL_OPENAI_BASE_URL, |
| available_provider_names, |
| clean_error_message, |
| compact_action_dict, |
| create_clients, |
| detect_provider, |
| effective_prompt_for_task, |
| get_recommended_prompts, |
| llm_action, |
| select_task_id, |
| ) |
|
|
|
|
| app = FastAPI( |
| title="Swarm Incident Response OpenEnv", |
| version="1.0.0", |
| description=( |
| "HTTP server for a real-world incident-response OpenEnv benchmark with " |
| "reset/step/state endpoints." |
| ), |
| ) |
|
|
| _env = IncidentResponseEnv() |
|
|
|
|
| def _serialize_observation(observation: IncidentObservation) -> dict: |
| return observation.model_dump() |
|
|
|
|
| class PromptRunRequest(BaseModel): |
| prompt: str = Field( |
| default="", |
| description="High-level mission brief or incident prompt for the live runner.", |
| ) |
| task_id: Optional[str] = Field( |
| default=None, |
| description="Optional explicit task override.", |
| ) |
|
|
|
|
| class PromptRunStep(BaseModel): |
| step: int |
| action: dict[str, Any] |
| reward: float |
| done: bool |
| feedback: str |
|
|
|
|
| class PromptRunResponse(BaseModel): |
| provider: str |
| model: str |
| task_id: str |
| prompt: str |
| success: bool |
| score: float |
| steps: int |
| error: Optional[str] = None |
| trajectory: list[PromptRunStep] |
|
|
|
|
| @app.get("/health", response_model=HealthResponse) |
| def health() -> HealthResponse: |
| return HealthResponse() |
|
|
|
|
| @app.get("/metadata") |
| def metadata() -> dict: |
| metadata = _env.get_metadata().model_dump() |
| metadata["recommended_prompts"] = get_recommended_prompts() |
| metadata["model_runtime"] = { |
| "provider": detect_provider() or "none", |
| "provider_chain": available_provider_names(), |
| "model": MODEL_NAME, |
| "local_openai_base_url": LOCAL_OPENAI_BASE_URL, |
| "trained_model_path": str(TRAINED_GGUF_PATH), |
| "trained_model_present": TRAINED_GGUF_PATH.exists(), |
| } |
| metadata["validator_runtime"] = _env.get_validator_runtime() |
| readme_path = Path(__file__).resolve().parent.parent / "README.md" |
| if readme_path.exists(): |
| metadata["readme_content"] = readme_path.read_text(encoding="utf-8") |
| return metadata |
|
|
|
|
| @app.get("/schema", response_model=SchemaResponse) |
| def schema() -> SchemaResponse: |
| return SchemaResponse( |
| action=IncidentAction.model_json_schema(), |
| observation=IncidentObservation.model_json_schema(), |
| state=IncidentState.model_json_schema(), |
| ) |
|
|
|
|
| @app.get("/demo-prompt") |
| def demo_prompt() -> dict: |
| return { |
| "provider": detect_provider() or "none", |
| "provider_chain": available_provider_names(), |
| "model": MODEL_NAME, |
| "trained_model_path": str(TRAINED_GGUF_PATH), |
| "trained_model_present": TRAINED_GGUF_PATH.exists(), |
| "local_openai_base_url": LOCAL_OPENAI_BASE_URL, |
| "validator_runtime": _env.get_validator_runtime(), |
| "recommended_prompts": get_recommended_prompts(), |
| } |
|
|
|
|
| @app.get("/state", response_model=IncidentState) |
| def state() -> IncidentState: |
| return _env.state |
|
|
|
|
| @app.post("/reset", response_model=ResetResponse) |
| def reset(request: ResetRequest = Body(default_factory=ResetRequest)) -> ResetResponse: |
| payload = request.model_dump(exclude_none=True) |
| observation = _env.reset(**payload) |
| return ResetResponse( |
| observation=_serialize_observation(observation), |
| reward=float(observation.reward or 0.0), |
| done=bool(observation.done), |
| ) |
|
|
|
|
| @app.post("/step", response_model=StepResponse) |
| def step(request: StepRequest) -> StepResponse: |
| try: |
| action = IncidentAction.model_validate(request.action) |
| except ValidationError as exc: |
| raise HTTPException(status_code=422, detail=exc.errors()) from exc |
|
|
| observation = _env.step(action, timeout_s=request.timeout_s) |
| return StepResponse( |
| observation=_serialize_observation(observation), |
| reward=float(observation.reward or 0.0), |
| done=bool(observation.done), |
| ) |
|
|
|
|
| @app.post("/run", response_model=PromptRunResponse) |
| def run_prompt(request: PromptRunRequest) -> PromptRunResponse: |
| clients = create_clients() |
| if not clients and not ALLOW_SCRIPTED_BASELINE: |
| raise HTTPException( |
| status_code=503, |
| detail=( |
| "No live provider configured. Start your local OpenAI-compatible runtime " |
| "for the trained model, or set OPENAI_API_KEY / GEMINI_API_KEY for a real " |
| "prompt-driven run." |
| ), |
| ) |
|
|
| resolved_task_id = request.task_id or select_task_id(clients, request.prompt) |
| if not resolved_task_id: |
| resolved_task_id = _env.default_task_id |
|
|
| mission_prompt = effective_prompt_for_task(resolved_task_id, request.prompt) |
| task = get_task(resolved_task_id) |
| env = IncidentResponseEnv(default_task_id=resolved_task_id) |
| observation = env.reset(task_id=resolved_task_id, prompt=mission_prompt) |
| history: list[str] = [] |
| trajectory: list[PromptRunStep] = [] |
| runtime_error: Optional[str] = None |
|
|
| for step_index in range(task.max_steps): |
| try: |
| action = llm_action( |
| clients, |
| resolved_task_id, |
| observation, |
| history, |
| step_index, |
| mission_prompt, |
| ) |
| action_payload = compact_action_dict(action) |
| observation = env.step(action) |
| reward = float(observation.reward or 0.0) |
| step_no = step_index + 1 |
| trajectory.append( |
| PromptRunStep( |
| step=step_no, |
| action=action_payload, |
| reward=reward, |
| done=bool(observation.done), |
| feedback=observation.last_feedback, |
| ) |
| ) |
| history.append( |
| f"step={step_no} action={action_payload} score={reward:.3f} " |
| f"feedback={observation.last_feedback}" |
| ) |
| if observation.done: |
| break |
| except Exception as exc: |
| runtime_error = clean_error_message(exc) |
| break |
|
|
| score = round(float(env.state.current_score), 3) |
| success = score >= task.success_threshold and bool(env.state.resolved) |
| return PromptRunResponse( |
| provider=detect_provider() or "none", |
| model=MODEL_NAME, |
| task_id=resolved_task_id, |
| prompt=mission_prompt, |
| success=success, |
| score=score, |
| steps=len(trajectory), |
| error=runtime_error, |
| trajectory=trajectory, |
| ) |
|
|
|
|
| @app.post("/mcp") |
| def mcp() -> dict: |
| return JsonRpcResponse.error_response( |
| JsonRpcErrorCode.METHOD_NOT_FOUND, |
| "This environment exposes the standard OpenEnv HTTP API only.", |
| ).model_dump() |
|
|
|
|
| def main() -> None: |
| port = int(os.getenv("PORT", "7860")) |
| uvicorn.run("server.app:app", host="0.0.0.0", port=port, reload=False) |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|