# swebench-ind/client.py
# Last commit: 4dfc26d ("refresh") by YUS200619
"""SWEbench-IN Environment Client.
Provides a Python client for connecting to a running SWEbench-IN server
via HTTP/WebSocket, enabling efficient multi-step RL interactions.
"""
from typing import Dict
from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from openenv.core.env_server.types import State
from .models import SWEbenchINAction, SWEbenchINObservation
class SWEbenchINEnv(
    EnvClient[SWEbenchINAction, SWEbenchINObservation, State]
):
    """
    Client for the SWEbench-IN Environment.

    Maintains a persistent WebSocket connection to the environment server.
    Each client instance has its own dedicated environment session.

    Example:
        >>> with SWEbenchINEnv(base_url="http://localhost:7860") as client:
        ...     result = client.reset()
        ...     print(result.observation.text)
        ...
        ...     action = SWEbenchINAction(type="run_command", args="ls /home/user2")
        ...     result = client.step(action)
        ...     print(result.observation.text)
    """

    # Client-side fallback for ``max_steps`` when the server response does
    # not include it.
    _DEFAULT_MAX_STEPS = 15

    def _step_payload(self, action: SWEbenchINAction) -> Dict:
        """Convert a SWEbenchINAction into the JSON payload for a step message.

        Args:
            action: The action to serialize. Only its ``type`` and ``args``
                fields are sent to the server.

        Returns:
            A dict with exactly the keys ``"type"`` and ``"args"``.
        """
        return {
            "type": action.type,
            "args": action.args,
        }

    def _parse_result(self, payload: Dict) -> StepResult[SWEbenchINObservation]:
        """Parse a server step/reset response into a StepResult.

        ``reward`` and ``done`` are read from the top level of ``payload``.
        All other observation fields are read from ``payload["observation"]``,
        with a default for each missing field.

        Args:
            payload: Decoded JSON response from the server.

        Returns:
            A StepResult wrapping a SWEbenchINObservation. The same ``reward``
            and ``done`` values are set on both the observation and the result.
        """
        # ``dict.get(key, default)`` only falls back when the key is missing.
        # ``or {}`` also covers an explicit JSON ``null``; without it the
        # ``obs_data.get`` calls below would raise AttributeError on None.
        obs_data = payload.get("observation") or {}

        # Read once so the observation and the StepResult cannot disagree.
        reward = payload.get("reward")
        done = payload.get("done", False)

        observation = SWEbenchINObservation(
            text=obs_data.get("text", ""),
            reward=reward,
            done=done,
            step_count=obs_data.get("step_count", 0),
            max_steps=obs_data.get("max_steps", self._DEFAULT_MAX_STEPS),
            tests_passing_ratio=obs_data.get("tests_passing_ratio", 0.0),
            server_running=obs_data.get("server_running", False),
            # Same null-vs-missing concern as ``observation`` above: always
            # hand the model a dict, never None.
            reward_breakdown=obs_data.get("reward_breakdown") or {},
            metadata=obs_data.get("metadata") or {},
        )
        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict) -> State:
        """Parse a server state response into a State object.

        Args:
            payload: Decoded JSON response from the server. Only
                ``episode_id`` and ``step_count`` are read.

        Returns:
            A State with ``episode_id`` (None if absent) and ``step_count``
            (0 if absent).
        """
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )