"""
SWEbench-IN Environment Implementation for OpenEnv server.

Wraps the SWEbench-IN environment logic into the OpenEnv
Environment interface (reset/step/state).

Dockerless: No container management, uses local temp directories.
"""

import os
import random
from dataclasses import dataclass, field
from uuid import uuid4

from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State

from models import SWEbenchINAction, SWEbenchINObservation
from rewards import compute_reward
from simulator import Simulator
from tasks import TASKS


@dataclass
class EnvState:
    """Internal environment state tracking."""
    task_id: int = 0
    step_count: int = 0
    tests_passing_ratio: float = 0.0
    server_running: bool = False
    files_correct: bool = False
    action_history: list[str] = field(default_factory=list)
    reply_texts: list[str] = field(default_factory=list)


class SWEbenchINEnvironment(Environment):
    """
    OpenEnv-compliant SWEbench-IN environment (Dockerless).

    Trains an LLM agent to fix a broken Linux system while managing
    stakeholder communication. Uses local temp directories
    instead of Docker containers.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        """Initialize the SWEbench-IN environment."""
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._env_state = EnvState()
        self._simulator = Simulator()
        self._current_task = None
        self._max_steps = 15
        self._done = False

    def reset(self) -> SWEbenchINObservation:
        """Reset the environment to a new episode."""
        # Sample a random task
        task_id = random.choice(list(TASKS.keys()))

        self._current_task = TASKS[task_id]
        self._done = False
        self._max_steps = self._current_task.max_actions
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._env_state = EnvState(task_id=task_id)
        self._simulator.setup_task(task_id)

        obs_text = self._simulator.get_initial_observation(task_id)

        return SWEbenchINObservation(
            text=obs_text,
            reward=0.0,
            done=False,
            step_count=0,
            max_steps=self._max_steps,
            tests_passing_ratio=0.0,
            server_running=False,
        )

    def step(self, action: SWEbenchINAction) -> SWEbenchINObservation:
        """Execute a step in the environment."""
        if self._done:
            return SWEbenchINObservation(
                text="Episode is done. Call reset() to start a new episode.",
                reward=0.0,
                done=True,
                step_count=self._state.step_count,
                max_steps=self._max_steps,
            )

        action_type = action.type
        action_args = action.args

        # Snapshot the scalar state before the action so compute_reward can compare before/after
        state_before = EnvState(
            task_id=self._env_state.task_id,
            step_count=self._env_state.step_count,
            tests_passing_ratio=self._env_state.tests_passing_ratio,
            server_running=self._env_state.server_running,
            files_correct=self._env_state.files_correct,
        )

        # Dispatch action
        obs_text = self._dispatch_action(action_type, action_args)

        # Update state
        self._env_state.action_history.append(f"{action_type}: {action_args}")
        self._env_state.step_count += 1
        self._state.step_count += 1
        self._update_measurements()

        # Episode ends on close_case or when the step budget is exhausted
        if action_type == "close_case" or self._env_state.step_count >= self._max_steps:
            self._done = True

        # Compute reward
        reward_breakdown = compute_reward(
            container_id=None,  # Dockerless mode: no container to inspect
            action_history=self._env_state.action_history,
            state_before=state_before,
            state_after=self._env_state,
            output_dir=self._simulator.output_dir,
            task_id=self._env_state.task_id,
            work_dir=self._simulator.work_dir,
        )

        return SWEbenchINObservation(
            text=obs_text,
            reward=reward_breakdown.total,
            done=self._done,
            step_count=self._env_state.step_count,
            max_steps=self._max_steps,
            tests_passing_ratio=self._env_state.tests_passing_ratio,
            server_running=self._env_state.server_running,
            reward_breakdown={
                "technical": reward_breakdown.technical,
                "boundaries": reward_breakdown.boundaries,
                "communication": reward_breakdown.communication,
                "leave_protection": reward_breakdown.leave_protection,
                "shaping": reward_breakdown.shaping,
            },
        )

    def state(self) -> State:
        """Get the current environment state."""
        return self._state

    # --- Internal helpers ---

    VALID_ACTIONS = {
        "run_command", "read_file", "write_file", "run_tests",
        "check_server", "reply_slack", "reply_email", "reply_hr", "close_case",
    }
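
    # Args convention (see _dispatch_action below): "write_file" is the only
    # action whose args carry structure, encoded as "path|content", e.g.
    #   SWEbenchINAction(type="write_file", args="app/config.py|DEBUG = False")
    # (path and content here are hypothetical). All other actions treat args
    # as an opaque string.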

    def _dispatch_action(self, action_type: str, action_args: str) -> str:
        """Dispatch an action to the simulator."""
        if action_type not in self.VALID_ACTIONS:
            return f"ERROR: Unknown action '{action_type}'"

        if action_type == "run_command":
            return self._simulator.run_bash(action_args)
        elif action_type == "read_file":
            return self._simulator.read_file(action_args)
        elif action_type == "write_file":
            if "|" in action_args:
                path, content = action_args.split("|", 1)
                return self._simulator.write_file(path.strip(), content)
            return "ERROR: write_file args must be 'path|content'"
        elif action_type == "run_tests":
            result = self._simulator.run_pytest()
            return f"Passed: {result['passed']}, Failed: {result['failed']}, Ratio: {result['ratio']:.0%}\n{result['output']}"
        elif action_type == "check_server":
            result = self._simulator.curl_server()
            return f"Status: {result['status_code']}, Success: {result['success']}"
        elif action_type in ("reply_slack", "reply_email", "reply_hr"):
            recipient = action_type.replace("reply_", "").upper()
            self._env_state.reply_texts.append(f"[{recipient}]: {action_args}")
            return self._simulator.write_reply(recipient, action_args)
        elif action_type == "close_case":
            return "Case closed. Episode ending."
        return "ERROR: dispatch failed"

    def _update_measurements(self):
        """Update state measurements from the live environment."""
        # Probe the simulated server; "success" reflects a healthy response.
        server_result = self._simulator.curl_server()
        self._env_state.server_running = server_result["success"]

        # Re-run the task's test suite and record the passing ratio.
        test_result = self._simulator.run_pytest()
        self._env_state.tests_passing_ratio = test_result["ratio"]

        # A non-empty reply.txt in the output directory marks the reply as written.
        reply_path = os.path.join(self._simulator.output_dir, "reply.txt")
        self._env_state.files_correct = (
            os.path.exists(reply_path) and os.path.getsize(reply_path) > 0
        )
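

# --- Usage sketch (local smoke test) ---
# A minimal sketch, not part of the OpenEnv server wiring. It assumes TASKS
# is non-empty, that the Simulator runs without Docker, and that
# SWEbenchINAction accepts type/args keyword arguments (matching the
# attribute access in step() above).
if __name__ == "__main__":
    env = SWEbenchINEnvironment()
    obs = env.reset()
    print(f"Episode started (max_steps={obs.max_steps})")
    print(obs.text)

    # One diagnostic step: run the task's test suite and inspect the reward.
    obs = env.step(SWEbenchINAction(type="run_tests", args=""))
    print(
        f"reward={obs.reward:.3f} "
        f"tests={obs.tests_passing_ratio:.0%} "
        f"done={obs.done}"
    )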