Spaces:
Sleeping
Sleeping
| """ | |
| environment.py — OpenEnv-compliant environment wrapper for SWEbench-IN (Dockerless). | |
| All Docker container management removed. Each episode runs in a fresh | |
| temp directory managed by Simulator. | |
| """ | |
import json
import os
import random
from dataclasses import dataclass, field
from typing import Optional

from rewards import compute_reward, RewardBreakdown
from simulator import Simulator
from tasks import TASKS, Task
@dataclass
class State:
    """Mutable per-episode state tracked by the environment.

    BUG FIX: the original class used ``field(default_factory=list)`` and was
    constructed with keyword arguments (``State(task_id=...)``), but the
    ``@dataclass`` decorator was missing — so instantiation raised TypeError
    and the ``field(...)`` objects were shared class attributes. The decorator
    restores the intended generated ``__init__`` and per-instance lists.
    """

    task_id: int = 0                    # which TASKS entry this episode runs
    step_count: int = 0                 # actions taken so far
    tests_passing_ratio: float = 0.0    # fraction of pytest tests passing (0..1)
    server_running: bool = False        # last observed server health
    files_correct: bool = False         # reply.txt exists and is non-empty
    action_history: list = field(default_factory=list)  # "type: args" strings
    reply_texts: list = field(default_factory=list)     # "[CHANNEL]: text" strings
| class SWEbenchINEnvironment: | |
| """ | |
| Dockerless RL environment for SWEbench-IN. | |
| Gym-style: reset() -> observation, step() -> (obs, reward, done, info) | |
| """ | |
| def __init__(self): | |
| self.simulator = Simulator() | |
| self.max_steps = 15 | |
| self._state = State() | |
| self._current_task: Task = None | |
| self._done = False | |
| # ------------------------------------------------------------------ | |
| # Public API | |
| # ------------------------------------------------------------------ | |
| def reset(self, task_id: int = None) -> dict: | |
| if task_id is None: | |
| task_id = random.choice(list(TASKS.keys())) | |
| if task_id not in TASKS: | |
| raise ValueError(f"Invalid task_id: {task_id}. Must be 1–5.") | |
| self._current_task = TASKS[task_id] | |
| self._done = False | |
| self._state = State(task_id=task_id) | |
| self.simulator.setup_task(task_id) | |
| self.max_steps = self._current_task.max_actions | |
| obs_text = self.simulator.get_initial_observation(task_id) | |
| return self._make_obs(obs_text) | |
| def step(self, action: dict) -> tuple: | |
| if self._done: | |
| return ( | |
| {"text": "Episode done. Call reset().", "step_count": self._state.step_count, | |
| "max_steps": self.max_steps, "tests_passing_ratio": 0.0, | |
| "server_running": False, "reward_breakdown": {}}, | |
| 0.0, True, {"error": "episode_done"}, | |
| ) | |
| action_type = action.get("type", "") | |
| action_args = action.get("args", "") | |
| content = action.get("content", "") # for write_file | |
| # Snapshot state before action | |
| state_before = State( | |
| task_id=self._state.task_id, | |
| step_count=self._state.step_count, | |
| tests_passing_ratio=self._state.tests_passing_ratio, | |
| server_running=self._state.server_running, | |
| files_correct=self._state.files_correct, | |
| action_history=list(self._state.action_history), | |
| reply_texts=list(self._state.reply_texts), | |
| ) | |
| # Execute action | |
| obs_text = self._dispatch(action_type, action_args, content) | |
| # Update state | |
| self._state.action_history.append(f"{action_type}: {action_args}") | |
| self._state.step_count += 1 | |
| # Only update measurements on state-changing actions (lazy updates) | |
| if action_type in ("run_tests", "run_command", "write_file", "check_server", "close_case"): | |
| self._update_state() | |
| # Check done | |
| if action_type == "close_case" or self._state.step_count >= self.max_steps: | |
| self._done = True | |
| # Compute reward | |
| breakdown = compute_reward( | |
| container_id=None, | |
| action_history=self._state.action_history, | |
| state_before=state_before, | |
| state_after=self._state, | |
| output_dir=self.simulator.output_dir, | |
| task_id=self._state.task_id, | |
| work_dir=self.simulator.work_dir, | |
| ) | |
| # Boost technical reward using live state (pytest ratio already updated) | |
| adjusted_total = ( | |
| breakdown.technical | |
| + 0.5 * self._state.tests_passing_ratio # live pytest score | |
| + 0.8 * breakdown.boundaries | |
| + 0.5 * breakdown.communication | |
| + (0.6 * breakdown.leave_protection if self._state.task_id == 5 else 0.0) | |
| + 0.3 * breakdown.shaping | |
| ) | |
| info = { | |
| "reward_breakdown": { | |
| "technical": breakdown.technical, | |
| "boundaries": breakdown.boundaries, | |
| "communication": breakdown.communication, | |
| "leave_protection": breakdown.leave_protection, | |
| "shaping": breakdown.shaping, | |
| }, | |
| "step_count": self._state.step_count, | |
| "max_steps": self.max_steps, | |
| "done_reason": ( | |
| "close_case" if action_type == "close_case" | |
| else "max_steps" if self._state.step_count >= self.max_steps | |
| else None | |
| ), | |
| } | |
| return (self._make_obs(obs_text), adjusted_total, self._done, info) | |
| def state(self) -> State: | |
| return self._state | |
| def grade(self) -> dict: | |
| """Summary grade for the completed episode.""" | |
| return { | |
| "task_id": self._state.task_id, | |
| "steps_taken": self._state.step_count, | |
| "tests_passing_ratio": self._state.tests_passing_ratio, | |
| "server_running": self._state.server_running, | |
| "files_correct": self._state.files_correct, | |
| "total_reward_approx": ( | |
| float(self._state.server_running) | |
| + self._state.tests_passing_ratio * 0.5 | |
| + float(self._state.files_correct) * 0.3 | |
| ), | |
| } | |
| # ------------------------------------------------------------------ | |
| # Internal | |
| # ------------------------------------------------------------------ | |
| ACTION_HANDLERS = { | |
| "run_command", "read_file", "write_file", "run_tests", | |
| "check_server", "reply_slack", "reply_email", "reply_hr", "close_case", | |
| } | |
| def _dispatch(self, action_type: str, action_args: str, content: str = "") -> str: | |
| if action_type not in self.ACTION_HANDLERS: | |
| return ( | |
| f"ERROR: Unknown action '{action_type}'. " | |
| f"Valid: {sorted(self.ACTION_HANDLERS)}" | |
| ) | |
| if action_type == "run_command": | |
| return self.simulator.run_bash(action_args) | |
| if action_type == "read_file": | |
| return self.simulator.read_file(action_args) | |
| if action_type == "write_file": | |
| # Support both "path|content" and separate content field | |
| if content: | |
| return self.simulator.write_file(action_args, content) | |
| if "|" in action_args: | |
| path, file_content = action_args.split("|", 1) | |
| return self.simulator.write_file(path.strip(), file_content) | |
| return "ERROR: write_file needs 'path|content' or a content field." | |
| if action_type == "run_tests": | |
| r = self.simulator.run_pytest() | |
| return ( | |
| f"Pytest Results:\n" | |
| f" Passed: {r['passed']}\n" | |
| f" Failed: {r['failed']}\n" | |
| f" Ratio: {r['ratio']:.0%}\n\n" | |
| f"Output:\n{r['output']}" | |
| ) | |
| if action_type == "check_server": | |
| r = self.simulator.curl_server() | |
| return ( | |
| f"Server Check:\n" | |
| f" Status Code: {r['status_code']}\n" | |
| f" Success: {r['success']}" | |
| ) | |
| if action_type == "reply_slack": | |
| result = self.simulator.write_reply("SLACK", action_args) | |
| self._state.reply_texts.append(f"[SLACK]: {action_args}") | |
| return result | |
| if action_type == "reply_email": | |
| result = self.simulator.write_reply("EMAIL", action_args) | |
| self._state.reply_texts.append(f"[EMAIL]: {action_args}") | |
| return result | |
| if action_type == "reply_hr": | |
| result = self.simulator.write_reply("HR", action_args) | |
| self._state.reply_texts.append(f"[HR]: {action_args}") | |
| return result | |
| if action_type == "close_case": | |
| return "Case closed. Episode ending." | |
| return "ERROR: Dispatch failed." | |
| def _update_state(self): | |
| """Refresh state measurements from live environment (non-blocking with error handling).""" | |
| import os | |
| # Update server status | |
| try: | |
| server = self.simulator.curl_server() | |
| self._state.server_running = server["success"] | |
| except Exception: | |
| pass | |
| # Update test pass ratio | |
| try: | |
| tests = self.simulator.run_pytest() | |
| self._state.tests_passing_ratio = tests["ratio"] | |
| except Exception: | |
| pass | |
| # Update file correctness | |
| try: | |
| reply_path = os.path.join(self.simulator.output_dir, "reply.txt") | |
| self._state.files_correct = ( | |
| os.path.exists(reply_path) and os.path.getsize(reply_path) > 0 | |
| ) | |
| except Exception: | |
| pass | |
| def _make_obs(text: str) -> dict: | |
| """Wrap observation text in a dict for the REST API.""" | |
| return {"text": text} | |