"""
environment.py — OpenEnv-compliant environment wrapper for SWEbench-IN (Dockerless).
All Docker container management removed. Each episode runs in a fresh
temp directory managed by Simulator.
"""
import os
import random
from dataclasses import dataclass, field, replace
from typing import Optional
from tasks import TASKS, Task
from simulator import Simulator
from rewards import compute_reward, RewardBreakdown
@dataclass
class State:
task_id: int = 0
step_count: int = 0
tests_passing_ratio: float = 0.0
server_running: bool = False
files_correct: bool = False
action_history: list = field(default_factory=list)
reply_texts: list = field(default_factory=list)
class SWEbenchINEnvironment:
"""
Dockerless RL environment for SWEbench-IN.
Gym-style: reset() -> observation, step() -> (obs, reward, done, info)
"""
def __init__(self):
self.simulator = Simulator()
        self.max_steps = 15  # default; reset() overrides this with the task's max_actions
self._state = State()
        self._current_task: Optional[Task] = None
self._done = False
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
    def reset(self, task_id: Optional[int] = None) -> dict:
if task_id is None:
task_id = random.choice(list(TASKS.keys()))
if task_id not in TASKS:
            raise ValueError(f"Invalid task_id: {task_id}. Valid: {sorted(TASKS.keys())}.")
self._current_task = TASKS[task_id]
self._done = False
self._state = State(task_id=task_id)
self.simulator.setup_task(task_id)
self.max_steps = self._current_task.max_actions
obs_text = self.simulator.get_initial_observation(task_id)
return self._make_obs(obs_text)
def step(self, action: dict) -> tuple:
        if self._done:
            obs = self._make_obs("Episode done. Call reset().")
            obs["reward_breakdown"] = {}
            return (obs, 0.0, True, {"error": "episode_done"})
action_type = action.get("type", "")
action_args = action.get("args", "")
content = action.get("content", "") # for write_file
        # Snapshot state before the action; the lists are copied so later
        # mutations do not leak into the snapshot
        state_before = replace(
            self._state,
            action_history=list(self._state.action_history),
            reply_texts=list(self._state.reply_texts),
        )
# Execute action
obs_text = self._dispatch(action_type, action_args, content)
# Update state
self._state.action_history.append(f"{action_type}: {action_args}")
self._state.step_count += 1
# Only update measurements on state-changing actions (lazy updates)
if action_type in ("run_tests", "run_command", "write_file", "check_server", "close_case"):
self._update_state()
# Check done
if action_type == "close_case" or self._state.step_count >= self.max_steps:
self._done = True
# Compute reward
breakdown = compute_reward(
            container_id=None,  # Dockerless: kept only for compute_reward's signature
action_history=self._state.action_history,
state_before=state_before,
state_after=self._state,
output_dir=self.simulator.output_dir,
task_id=self._state.task_id,
work_dir=self.simulator.work_dir,
)
# Boost technical reward using live state (pytest ratio already updated)
adjusted_total = (
breakdown.technical
+ 0.5 * self._state.tests_passing_ratio # live pytest score
+ 0.8 * breakdown.boundaries
+ 0.5 * breakdown.communication
+ (0.6 * breakdown.leave_protection if self._state.task_id == 5 else 0.0)
+ 0.3 * breakdown.shaping
)
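        # Worked example with illustrative component values (task 3, so no
        # leave_protection term): technical=0.4, ratio=0.5, boundaries=0.2,
        # communication=1.0, shaping=0.1
        #   0.4 + 0.5*0.5 + 0.8*0.2 + 0.5*1.0 + 0.3*0.1 = 1.34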
info = {
"reward_breakdown": {
"technical": breakdown.technical,
"boundaries": breakdown.boundaries,
"communication": breakdown.communication,
"leave_protection": breakdown.leave_protection,
"shaping": breakdown.shaping,
},
"step_count": self._state.step_count,
"max_steps": self.max_steps,
"done_reason": (
"close_case" if action_type == "close_case"
else "max_steps" if self._state.step_count >= self.max_steps
else None
),
}
return (self._make_obs(obs_text), adjusted_total, self._done, info)
def state(self) -> State:
return self._state
def grade(self) -> dict:
"""Summary grade for the completed episode."""
return {
"task_id": self._state.task_id,
"steps_taken": self._state.step_count,
"tests_passing_ratio": self._state.tests_passing_ratio,
"server_running": self._state.server_running,
"files_correct": self._state.files_correct,
"total_reward_approx": (
float(self._state.server_running)
+ self._state.tests_passing_ratio * 0.5
+ float(self._state.files_correct) * 0.3
),
}
# ------------------------------------------------------------------
# Internal
# ------------------------------------------------------------------
    # Names of the actions step() accepts; dispatch is the if-chain below.
    VALID_ACTIONS = {
        "run_command", "read_file", "write_file", "run_tests",
        "check_server", "reply_slack", "reply_email", "reply_hr", "close_case",
    }
    def _dispatch(self, action_type: str, action_args: str, content: str = "") -> str:
        if action_type not in self.VALID_ACTIONS:
            return (
                f"ERROR: Unknown action '{action_type}'. "
                f"Valid: {sorted(self.VALID_ACTIONS)}"
            )
if action_type == "run_command":
return self.simulator.run_bash(action_args)
if action_type == "read_file":
return self.simulator.read_file(action_args)
if action_type == "write_file":
# Support both "path|content" and separate content field
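            # Illustrative payloads (hypothetical path/content):
            #   {"type": "write_file", "args": "notes.txt", "content": "hello\n"}
            #   {"type": "write_file", "args": "notes.txt|hello"}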
if content:
return self.simulator.write_file(action_args, content)
if "|" in action_args:
path, file_content = action_args.split("|", 1)
return self.simulator.write_file(path.strip(), file_content)
return "ERROR: write_file needs 'path|content' or a content field."
if action_type == "run_tests":
r = self.simulator.run_pytest()
return (
f"Pytest Results:\n"
f" Passed: {r['passed']}\n"
f" Failed: {r['failed']}\n"
f" Ratio: {r['ratio']:.0%}\n\n"
f"Output:\n{r['output']}"
)
if action_type == "check_server":
r = self.simulator.curl_server()
return (
f"Server Check:\n"
f" Status Code: {r['status_code']}\n"
f" Success: {r['success']}"
)
if action_type == "reply_slack":
result = self.simulator.write_reply("SLACK", action_args)
self._state.reply_texts.append(f"[SLACK]: {action_args}")
return result
if action_type == "reply_email":
result = self.simulator.write_reply("EMAIL", action_args)
self._state.reply_texts.append(f"[EMAIL]: {action_args}")
return result
if action_type == "reply_hr":
result = self.simulator.write_reply("HR", action_args)
self._state.reply_texts.append(f"[HR]: {action_args}")
return result
if action_type == "close_case":
return "Case closed. Episode ending."
return "ERROR: Dispatch failed."
    def _update_state(self):
        """Refresh state measurements from the live environment. Each probe is
        wrapped in try/except so a flaky check degrades the measurement
        gracefully instead of aborting the step (non-blocking)."""
# Update server status
try:
server = self.simulator.curl_server()
self._state.server_running = server["success"]
except Exception:
pass
# Update test pass ratio
try:
tests = self.simulator.run_pytest()
self._state.tests_passing_ratio = tests["ratio"]
except Exception:
pass
# Update file correctness
try:
reply_path = os.path.join(self.simulator.output_dir, "reply.txt")
self._state.files_correct = (
os.path.exists(reply_path) and os.path.getsize(reply_path) > 0
)
except Exception:
pass
    def _make_obs(self, text: str) -> dict:
        """Wrap observation text, plus step bookkeeping, in a dict for the REST API."""
        return {
            "text": text,
            "step_count": self._state.step_count,
            "max_steps": self.max_steps,
            "tests_passing_ratio": self._state.tests_passing_ratio,
            "server_running": self._state.server_running,
        }