# swarm-os / server / app.py
# aryxn323: Initial Space deployment with llama-cpp + React dashboard (commit 8892a6c)
from __future__ import annotations
import os
from pathlib import Path
from typing import Any, Optional
import uvicorn
from fastapi import Body, FastAPI, HTTPException
from pydantic import BaseModel, Field, ValidationError
from openenv.core.env_server.http_server import (
HealthResponse,
JsonRpcErrorCode,
JsonRpcResponse,
ResetRequest,
ResetResponse,
SchemaResponse,
StepRequest,
StepResponse,
)
from swarm_openenv_env.environment import IncidentResponseEnv
from swarm_openenv_env.models import (
IncidentAction,
IncidentObservation,
IncidentState,
)
from swarm_openenv_env.tasks import get_task
from inference import (
ALLOW_SCRIPTED_BASELINE,
MODEL_NAME,
TRAINED_GGUF_PATH,
LOCAL_OPENAI_BASE_URL,
available_provider_names,
clean_error_message,
compact_action_dict,
create_clients,
detect_provider,
effective_prompt_for_task,
get_recommended_prompts,
llm_action,
select_task_id,
)
# ASGI application; every route decorator in this module registers on it.
app = FastAPI(
    description=(
        "HTTP server for a real-world incident-response OpenEnv "
        "benchmark with reset/step/state endpoints."
    ),
    version="1.0.0",
    title="Swarm Incident Response OpenEnv",
)

# Module-level environment instance used by the metadata, state, reset and step routes.
_env = IncidentResponseEnv()
def _serialize_observation(observation: IncidentObservation) -> dict:
    """Return *observation* as a plain dict via its ``model_dump()`` method."""
    payload = observation.model_dump()
    return payload
class PromptRunRequest(BaseModel):
    # Request body accepted by POST /run.

    # Free-text brief for the run; an omitted value becomes the empty string.
    prompt: str = Field(
        default="",
        description="High-level mission brief or incident prompt for the live runner.",
    )
    # When set to a non-empty value, /run uses this task id instead of
    # asking select_task_id() to pick one from the prompt.
    task_id: Optional[str] = Field(
        default=None,
        description="Optional explicit task override.",
    )
class PromptRunStep(BaseModel):
    # One executed step in the trajectory returned by POST /run.
    step: int  # 1-based step number (loop index + 1)
    action: dict[str, Any]  # compact_action_dict() output for the chosen action
    reward: float  # observation.reward after the step; 0.0 when it is falsy
    done: bool  # whether the observation after the step reported done
    feedback: str  # observation.last_feedback after the step
class PromptRunResponse(BaseModel):
    # Result of one prompt-driven episode, returned by POST /run.
    provider: str  # detect_provider() result, or "none" when it is falsy
    model: str  # MODEL_NAME as imported from inference
    task_id: str  # id of the task that was actually run
    prompt: str  # prompt returned by effective_prompt_for_task()
    success: bool  # score reached task.success_threshold and env.state.resolved is truthy
    score: float  # env.state.current_score rounded to 3 decimals
    steps: int  # number of entries in trajectory
    error: Optional[str] = None  # cleaned message of the exception that ended the run, if any
    trajectory: list[PromptRunStep]  # per-step records in execution order
@app.get("/health", response_model=HealthResponse)
def health() -> HealthResponse:
    """Health probe: reply with a default-constructed ``HealthResponse``."""
    response = HealthResponse()
    return response
@app.get("/metadata")
def metadata() -> dict:
    """Describe the environment, the model runtime and the validator runtime.

    Returns:
        The dict produced by ``_env.get_metadata().model_dump()`` extended
        with ``recommended_prompts``, ``model_runtime``, ``validator_runtime``
        and, when the README can be read, ``readme_content``.
    """
    # Named ``payload`` so the local does not shadow this handler's own name.
    payload = _env.get_metadata().model_dump()
    payload["recommended_prompts"] = get_recommended_prompts()
    payload["model_runtime"] = {
        "provider": detect_provider() or "none",
        "provider_chain": available_provider_names(),
        "model": MODEL_NAME,
        "local_openai_base_url": LOCAL_OPENAI_BASE_URL,
        "trained_model_path": str(TRAINED_GGUF_PATH),
        "trained_model_present": TRAINED_GGUF_PATH.exists(),
    }
    payload["validator_runtime"] = _env.get_validator_runtime()

    # README.md is looked up one directory above this file's directory.
    # It is an optional extra, so read it EAFP-style: a missing, unreadable
    # or non-UTF-8 file leaves the key out instead of failing the request
    # (the previous exists() check still raised for directories, permission
    # errors and decode errors).
    readme_path = Path(__file__).resolve().parent.parent / "README.md"
    try:
        payload["readme_content"] = readme_path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError):
        pass
    return payload
@app.get("/schema", response_model=SchemaResponse)
def schema() -> SchemaResponse:
    """Publish the JSON schemas of the action, observation and state models."""
    action_schema = IncidentAction.model_json_schema()
    observation_schema = IncidentObservation.model_json_schema()
    state_schema = IncidentState.model_json_schema()
    return SchemaResponse(
        action=action_schema,
        observation=observation_schema,
        state=state_schema,
    )
@app.get("/demo-prompt")
def demo_prompt() -> dict:
    """Summarise the model runtime, validator runtime and recommended prompts."""
    # Values are gathered in the same order the response keys are emitted.
    provider = detect_provider() or "none"
    provider_chain = available_provider_names()
    gguf_path = str(TRAINED_GGUF_PATH)
    gguf_present = TRAINED_GGUF_PATH.exists()
    validator_runtime = _env.get_validator_runtime()
    prompts = get_recommended_prompts()
    return {
        "provider": provider,
        "provider_chain": provider_chain,
        "model": MODEL_NAME,
        "trained_model_path": gguf_path,
        "trained_model_present": gguf_present,
        "local_openai_base_url": LOCAL_OPENAI_BASE_URL,
        "validator_runtime": validator_runtime,
        "recommended_prompts": prompts,
    }
@app.get("/state", response_model=IncidentState)
def state() -> IncidentState:
    """Expose the ``state`` attribute of the module-level environment."""
    current_state = _env.state
    return current_state
@app.post("/reset", response_model=ResetResponse)
def reset(request: ResetRequest = Body(default_factory=ResetRequest)) -> ResetResponse:
    """Reset the module-level environment and return the resulting observation.

    Every request field that is not ``None`` is forwarded to ``_env.reset``
    as a keyword argument.
    """
    reset_kwargs = request.model_dump(exclude_none=True)
    first_observation = _env.reset(**reset_kwargs)
    serialized = _serialize_observation(first_observation)
    # A missing/None reward is reported as 0.0.
    reward = float(first_observation.reward or 0.0)
    finished = bool(first_observation.done)
    return ResetResponse(observation=serialized, reward=reward, done=finished)
@app.post("/step", response_model=StepResponse)
def step(request: StepRequest) -> StepResponse:
    """Apply one action to the module-level environment.

    Raises:
        HTTPException: 422 carrying the pydantic error list when
            ``request.action`` does not validate as an ``IncidentAction``.
    """
    try:
        parsed_action = IncidentAction.model_validate(request.action)
    except ValidationError as validation_error:
        raise HTTPException(
            status_code=422,
            detail=validation_error.errors(),
        ) from validation_error
    else:
        next_observation = _env.step(parsed_action, timeout_s=request.timeout_s)

    serialized = _serialize_observation(next_observation)
    # A missing/None reward is reported as 0.0.
    reward = float(next_observation.reward or 0.0)
    finished = bool(next_observation.done)
    return StepResponse(observation=serialized, reward=reward, done=finished)
def _run_episode(
    env: IncidentResponseEnv,
    clients: Any,
    task_id: str,
    observation: IncidentObservation,
    mission_prompt: str,
    max_steps: int,
) -> tuple[list[PromptRunStep], Optional[str]]:
    """Drive *env* with model-chosen actions for at most *max_steps* steps.

    Args:
        env: Environment that has already been reset for this episode.
        clients: Provider clients as returned by ``create_clients()``.
        task_id: Task id passed through to ``llm_action``.
        observation: Observation returned by the reset.
        mission_prompt: Prompt passed through to ``llm_action``.
        max_steps: Upper bound on the number of steps attempted.

    Returns:
        ``(trajectory, error)``. ``error`` is the cleaned message of the
        exception that stopped the loop, or ``None`` when the loop ended
        because the episode was done or the step budget ran out.
    """
    history: list[str] = []
    trajectory: list[PromptRunStep] = []
    for step_index in range(max_steps):
        step_no = step_index + 1
        # Deliberately broad: any failure while choosing, applying or
        # recording a step ends the run and is reported in the response,
        # keeping the steps recorded so far.
        try:
            action = llm_action(
                clients,
                task_id,
                observation,
                history,
                step_index,
                mission_prompt,
            )
            action_payload = compact_action_dict(action)
            observation = env.step(action)
            reward = float(observation.reward or 0.0)
            done = bool(observation.done)
            trajectory.append(
                PromptRunStep(
                    step=step_no,
                    action=action_payload,
                    reward=reward,
                    done=done,
                    feedback=observation.last_feedback,
                )
            )
            # The textual history is handed back to llm_action on the next step.
            history.append(
                f"step={step_no} action={action_payload} score={reward:.3f} "
                f"feedback={observation.last_feedback}"
            )
        except Exception as exc:
            return trajectory, clean_error_message(exc)
        if done:
            break
    return trajectory, None


@app.post("/run", response_model=PromptRunResponse)
def run_prompt(request: PromptRunRequest) -> PromptRunResponse:
    """Run one full episode for the requested prompt and report the outcome.

    Returns:
        The provider/model used, the resolved task id and prompt, the final
        score and success flag, and the per-step trajectory. ``error`` holds
        the cleaned message of the exception that cut the run short, if any.

    Raises:
        HTTPException: 503 when ``create_clients()`` yields nothing and
            ``ALLOW_SCRIPTED_BASELINE`` is falsy.
    """
    clients = create_clients()
    if not clients and not ALLOW_SCRIPTED_BASELINE:
        raise HTTPException(
            status_code=503,
            detail=(
                "No live provider configured. Start your local OpenAI-compatible runtime "
                "for the trained model, or set OPENAI_API_KEY / GEMINI_API_KEY for a real "
                "prompt-driven run."
            ),
        )

    # Precedence: explicit request value, then select_task_id()'s choice,
    # then the module-level environment's default task.
    resolved_task_id = (
        request.task_id
        or select_task_id(clients, request.prompt)
        or _env.default_task_id
    )
    mission_prompt = effective_prompt_for_task(resolved_task_id, request.prompt)
    task = get_task(resolved_task_id)

    # The run uses a fresh environment instance rather than the module-level one.
    env = IncidentResponseEnv(default_task_id=resolved_task_id)
    observation = env.reset(task_id=resolved_task_id, prompt=mission_prompt)
    trajectory, runtime_error = _run_episode(
        env,
        clients,
        resolved_task_id,
        observation,
        mission_prompt,
        task.max_steps,
    )

    score = round(float(env.state.current_score), 3)
    success = score >= task.success_threshold and bool(env.state.resolved)
    return PromptRunResponse(
        provider=detect_provider() or "none",
        model=MODEL_NAME,
        task_id=resolved_task_id,
        prompt=mission_prompt,
        success=success,
        score=score,
        steps=len(trajectory),
        error=runtime_error,
        trajectory=trajectory,
    )
@app.post("/mcp")
def mcp() -> dict:
    """Answer any POST to ``/mcp`` with a JSON-RPC ``METHOD_NOT_FOUND`` error."""
    rejection = JsonRpcResponse.error_response(
        JsonRpcErrorCode.METHOD_NOT_FOUND,
        "This environment exposes the standard OpenEnv HTTP API only.",
    )
    return rejection.model_dump()
def main() -> None:
    """Serve ``app`` with uvicorn on all interfaces.

    The port comes from the ``PORT`` environment variable and falls back to
    7860 when the variable is unset or empty.

    Raises:
        ValueError: If ``PORT`` is set to a non-empty, non-integer value.
    """
    # ``or`` rather than a getenv default so PORT="" also falls back;
    # int("") would raise.
    port = int(os.getenv("PORT") or "7860")
    # Pass the app object instead of the "server.app:app" import string.
    # With reload disabled the string is not needed, and it made uvicorn
    # import this module a second time when run as __main__ (building a
    # second app and environment) and required ``server`` to be importable.
    uvicorn.run(app, host="0.0.0.0", port=port, reload=False)
# Start the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()