""" environment.py — OpenEnv-compliant environment wrapper for SWEbench-IN (Dockerless). All Docker container management removed. Each episode runs in a fresh temp directory managed by Simulator. """ import json import random from dataclasses import dataclass, field from tasks import TASKS, Task from simulator import Simulator from rewards import compute_reward, RewardBreakdown @dataclass class State: task_id: int = 0 step_count: int = 0 tests_passing_ratio: float = 0.0 server_running: bool = False files_correct: bool = False action_history: list = field(default_factory=list) reply_texts: list = field(default_factory=list) class SWEbenchINEnvironment: """ Dockerless RL environment for SWEbench-IN. Gym-style: reset() -> observation, step() -> (obs, reward, done, info) """ def __init__(self): self.simulator = Simulator() self.max_steps = 15 self._state = State() self._current_task: Task = None self._done = False # ------------------------------------------------------------------ # Public API # ------------------------------------------------------------------ def reset(self, task_id: int = None) -> dict: if task_id is None: task_id = random.choice(list(TASKS.keys())) if task_id not in TASKS: raise ValueError(f"Invalid task_id: {task_id}. Must be 1–5.") self._current_task = TASKS[task_id] self._done = False self._state = State(task_id=task_id) self.simulator.setup_task(task_id) self.max_steps = self._current_task.max_actions obs_text = self.simulator.get_initial_observation(task_id) return self._make_obs(obs_text) def step(self, action: dict) -> tuple: if self._done: return ( {"text": "Episode done. Call reset().", "step_count": self._state.step_count, "max_steps": self.max_steps, "tests_passing_ratio": 0.0, "server_running": False, "reward_breakdown": {}}, 0.0, True, {"error": "episode_done"}, ) action_type = action.get("type", "") action_args = action.get("args", "") content = action.get("content", "") # for write_file # Snapshot state before action state_before = State( task_id=self._state.task_id, step_count=self._state.step_count, tests_passing_ratio=self._state.tests_passing_ratio, server_running=self._state.server_running, files_correct=self._state.files_correct, action_history=list(self._state.action_history), reply_texts=list(self._state.reply_texts), ) # Execute action obs_text = self._dispatch(action_type, action_args, content) # Update state self._state.action_history.append(f"{action_type}: {action_args}") self._state.step_count += 1 # Only update measurements on state-changing actions (lazy updates) if action_type in ("run_tests", "run_command", "write_file", "check_server", "close_case"): self._update_state() # Check done if action_type == "close_case" or self._state.step_count >= self.max_steps: self._done = True # Compute reward breakdown = compute_reward( container_id=None, action_history=self._state.action_history, state_before=state_before, state_after=self._state, output_dir=self.simulator.output_dir, task_id=self._state.task_id, work_dir=self.simulator.work_dir, ) # Boost technical reward using live state (pytest ratio already updated) adjusted_total = ( breakdown.technical + 0.5 * self._state.tests_passing_ratio # live pytest score + 0.8 * breakdown.boundaries + 0.5 * breakdown.communication + (0.6 * breakdown.leave_protection if self._state.task_id == 5 else 0.0) + 0.3 * breakdown.shaping ) info = { "reward_breakdown": { "technical": breakdown.technical, "boundaries": breakdown.boundaries, "communication": breakdown.communication, "leave_protection": breakdown.leave_protection, "shaping": breakdown.shaping, }, "step_count": self._state.step_count, "max_steps": self.max_steps, "done_reason": ( "close_case" if action_type == "close_case" else "max_steps" if self._state.step_count >= self.max_steps else None ), } return (self._make_obs(obs_text), adjusted_total, self._done, info) def state(self) -> State: return self._state def grade(self) -> dict: """Summary grade for the completed episode.""" return { "task_id": self._state.task_id, "steps_taken": self._state.step_count, "tests_passing_ratio": self._state.tests_passing_ratio, "server_running": self._state.server_running, "files_correct": self._state.files_correct, "total_reward_approx": ( float(self._state.server_running) + self._state.tests_passing_ratio * 0.5 + float(self._state.files_correct) * 0.3 ), } # ------------------------------------------------------------------ # Internal # ------------------------------------------------------------------ ACTION_HANDLERS = { "run_command", "read_file", "write_file", "run_tests", "check_server", "reply_slack", "reply_email", "reply_hr", "close_case", } def _dispatch(self, action_type: str, action_args: str, content: str = "") -> str: if action_type not in self.ACTION_HANDLERS: return ( f"ERROR: Unknown action '{action_type}'. " f"Valid: {sorted(self.ACTION_HANDLERS)}" ) if action_type == "run_command": return self.simulator.run_bash(action_args) if action_type == "read_file": return self.simulator.read_file(action_args) if action_type == "write_file": # Support both "path|content" and separate content field if content: return self.simulator.write_file(action_args, content) if "|" in action_args: path, file_content = action_args.split("|", 1) return self.simulator.write_file(path.strip(), file_content) return "ERROR: write_file needs 'path|content' or a content field." if action_type == "run_tests": r = self.simulator.run_pytest() return ( f"Pytest Results:\n" f" Passed: {r['passed']}\n" f" Failed: {r['failed']}\n" f" Ratio: {r['ratio']:.0%}\n\n" f"Output:\n{r['output']}" ) if action_type == "check_server": r = self.simulator.curl_server() return ( f"Server Check:\n" f" Status Code: {r['status_code']}\n" f" Success: {r['success']}" ) if action_type == "reply_slack": result = self.simulator.write_reply("SLACK", action_args) self._state.reply_texts.append(f"[SLACK]: {action_args}") return result if action_type == "reply_email": result = self.simulator.write_reply("EMAIL", action_args) self._state.reply_texts.append(f"[EMAIL]: {action_args}") return result if action_type == "reply_hr": result = self.simulator.write_reply("HR", action_args) self._state.reply_texts.append(f"[HR]: {action_args}") return result if action_type == "close_case": return "Case closed. Episode ending." return "ERROR: Dispatch failed." def _update_state(self): """Refresh state measurements from live environment (non-blocking with error handling).""" import os # Update server status try: server = self.simulator.curl_server() self._state.server_running = server["success"] except Exception: pass # Update test pass ratio try: tests = self.simulator.run_pytest() self._state.tests_passing_ratio = tests["ratio"] except Exception: pass # Update file correctness try: reply_path = os.path.join(self.simulator.output_dir, "reply.txt") self._state.files_correct = ( os.path.exists(reply_path) and os.path.getsize(reply_path) > 0 ) except Exception: pass @staticmethod def _make_obs(text: str) -> dict: """Wrap observation text in a dict for the REST API.""" return {"text": text}