modelbuilderhq commited on
Commit
df943c4
·
verified ·
1 Parent(s): 726cf7a

Upload folder using huggingface_hub

Browse files
supportdesk_env/server/app.py CHANGED
@@ -3,24 +3,22 @@
3
  from __future__ import annotations
4
 
5
  import os
6
- import threading
7
  from typing import Any
8
 
9
  import uvicorn
10
- from fastapi import Body
11
- from fastapi.routing import APIRoute
12
 
13
  try:
14
- from openenv.core.env_server.http_server import create_app
15
- from openenv.core.env_server.types import ResetRequest, ResetResponse, StepRequest, StepResponse
16
  except ImportError: # pragma: no cover - package name differs across releases
17
- from openenv_core.env_server.http_server import create_app
18
- from openenv_core.env_server.types import ResetRequest, ResetResponse, StepRequest, StepResponse
19
 
20
  from supportdesk_env.models import SupportDeskAction, SupportDeskObservation, SupportDeskState
21
  from supportdesk_env.server.supportdesk_environment import SupportDeskEnvironment
22
  from supportdesk_env.tasks import TASKS
23
 
 
 
 
24
  app = create_app(
25
  SupportDeskEnvironment,
26
  action_cls=SupportDeskAction,
@@ -28,82 +26,6 @@ app = create_app(
28
  env_name="supportdesk_env",
29
  )
30
 
31
- _http_session_lock = threading.RLock()
32
- _http_session_env: SupportDeskEnvironment | None = None
33
-
34
-
35
- def _remove_stateless_http_routes() -> None:
36
- """Replace OpenEnv's stateless HTTP simulation routes with a persistent session."""
37
-
38
- kept_routes = []
39
- for route in app.router.routes:
40
- if isinstance(route, APIRoute):
41
- methods = route.methods or set()
42
- if route.path == "/reset" and "POST" in methods:
43
- continue
44
- if route.path == "/step" and "POST" in methods:
45
- continue
46
- if route.path == "/state" and "GET" in methods:
47
- continue
48
- kept_routes.append(route)
49
- app.router.routes = kept_routes
50
- app.openapi_schema = None
51
-
52
-
53
- def _get_http_session_env() -> SupportDeskEnvironment:
54
- global _http_session_env
55
- if _http_session_env is None:
56
- _http_session_env = SupportDeskEnvironment()
57
- return _http_session_env
58
-
59
-
60
- def _serialize_observation(observation: SupportDeskObservation) -> dict[str, Any]:
61
- return observation.model_dump()
62
-
63
-
64
- _remove_stateless_http_routes()
65
-
66
-
67
- @app.post("/reset", response_model=ResetResponse, tags=["Environment Control"])
68
- def reset_environment(
69
- request: ResetRequest = Body(default_factory=ResetRequest),
70
- ) -> ResetResponse:
71
- """Keep a persistent HTTP environment so reset/step/state share the same case."""
72
-
73
- with _http_session_lock:
74
- env = _get_http_session_env()
75
- observation = env.reset(seed=request.seed, episode_id=request.episode_id)
76
- reward = float(observation.reward) if observation.reward is not None else None
77
- return ResetResponse(
78
- observation=_serialize_observation(observation),
79
- reward=reward,
80
- done=observation.done,
81
- )
82
-
83
-
84
- @app.post("/step", response_model=StepResponse, tags=["Environment Control"])
85
- def step_environment(request: StepRequest) -> StepResponse:
86
- """Advance the current HTTP session instead of stepping a fresh env instance."""
87
-
88
- with _http_session_lock:
89
- env = _get_http_session_env()
90
- action = SupportDeskAction.model_validate(request.action)
91
- observation = env.step(action, timeout_s=request.timeout_s)
92
- reward = float(observation.reward) if observation.reward is not None else None
93
- return StepResponse(
94
- observation=_serialize_observation(observation),
95
- reward=reward,
96
- done=observation.done,
97
- )
98
-
99
-
100
- @app.get("/state", response_model=SupportDeskState, tags=["State Management"])
101
- def get_environment_state() -> SupportDeskState:
102
- """Return the real current HTTP session state for grader-style inspection."""
103
-
104
- with _http_session_lock:
105
- return _get_http_session_env().state
106
-
107
 
108
  @app.get("/tasks")
109
  def list_tasks() -> dict[str, Any]:
 
3
  from __future__ import annotations
4
 
5
  import os
 
6
  from typing import Any
7
 
8
  import uvicorn
 
 
9
 
10
  try:
11
+ from openenv.core.env_server import http_server as openenv_http_server
 
12
  except ImportError: # pragma: no cover - package name differs across releases
13
+ from openenv_core.env_server import http_server as openenv_http_server
 
14
 
15
  from supportdesk_env.models import SupportDeskAction, SupportDeskObservation, SupportDeskState
16
  from supportdesk_env.server.supportdesk_environment import SupportDeskEnvironment
17
  from supportdesk_env.tasks import TASKS
18
 
19
+ openenv_http_server.State = SupportDeskState
20
+ create_app = openenv_http_server.create_app
21
+
22
  app = create_app(
23
  SupportDeskEnvironment,
24
  action_cls=SupportDeskAction,
 
26
  env_name="supportdesk_env",
27
  )
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  @app.get("/tasks")
31
  def list_tasks() -> dict[str, Any]:
supportdesk_env/server/supportdesk_environment.py CHANGED
@@ -3,6 +3,7 @@
3
  from __future__ import annotations
4
 
5
  import os
 
6
  import uuid
7
  from pathlib import Path
8
 
@@ -31,6 +32,18 @@ class SupportDeskEnvironment(
31
  ):
32
  """A realistic customer support triage environment with dense rewards."""
33
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  def __init__(self, task_id: str | None = None):
35
  super().__init__()
36
  requested_task = task_id or os.getenv("SUPPORTDESK_TASK_ID") or list_task_ids()[0]
@@ -46,23 +59,80 @@ class SupportDeskEnvironment(
46
  initial_grade = grade_case(self.task, self._case)
47
  self._score = initial_grade.total_score
48
  self._completed_milestones = list(initial_grade.completed_milestones)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  @property
51
  def state(self) -> SupportDeskState:
52
- return SupportDeskState(
53
- episode_id=self._episode_id,
54
- task_id=self.task.task_id,
55
- difficulty=self.task.difficulty,
56
- step_count=self._step_count,
57
- reward=round(self._reward_total, 4),
58
- done=self._done,
59
- current_score=round(self._score, 4),
60
- max_steps=self._max_steps,
61
- case=self._case.model_copy(deep=True),
62
- action_history=[entry.model_copy(deep=True) for entry in self._history],
63
- completed_milestones=list(self._completed_milestones),
64
- last_feedback=self._last_feedback,
65
- )
 
 
66
 
67
  def reset(
68
  self,
@@ -70,17 +140,13 @@ class SupportDeskEnvironment(
70
  episode_id: str | None = None,
71
  **kwargs,
72
  ) -> SupportDeskObservation:
73
- self._episode_id = episode_id or f"{self.task.task_id}-{uuid.uuid4().hex[:8]}"
74
- self._step_count = 0
75
- self._reward_total = 0.0
76
- self._done = False
77
- self._last_feedback = "New case loaded. Review the ticket and policy snippets before acting."
78
- self._history = []
79
- self._case = SupportCaseProgress()
80
- initial_grade = grade_case(self.task, self._case)
81
- self._score = initial_grade.total_score
82
- self._completed_milestones = list(initial_grade.completed_milestones)
83
- return self._build_observation(reward=0.0, done=False)
84
 
85
  def step(
86
  self,
@@ -88,51 +154,59 @@ class SupportDeskEnvironment(
88
  timeout_s: float | None = None,
89
  **kwargs,
90
  ) -> SupportDeskObservation:
91
- if self._done:
92
- return self._build_observation(
93
- reward=-0.05,
94
- done=True,
95
- feedback="Episode already finished. Call reset() before taking more actions.",
96
- )
97
-
98
- previous_grade = grade_case(self.task, self._case)
99
- self._apply_action(action)
100
- self._step_count += 1
101
 
102
- current_grade = grade_case(self.task, self._case)
103
- reward = current_grade.total_score - previous_grade.total_score
104
- reward += self._action_penalty(action, current_grade.total_score, previous_grade.total_score)
105
- reward = round(reward, 4)
 
 
106
 
107
- self._score = current_grade.total_score
108
- self._completed_milestones = list(current_grade.completed_milestones)
 
109
 
110
- if action.operation == "submit":
111
- self._done = True
112
- self._last_feedback = (
113
- "Case submitted. Final deterministic grade is "
114
- f"{current_grade.total_score:.2f}."
 
115
  )
116
- elif self._step_count >= self._max_steps:
117
- self._done = True
118
- self._last_feedback = (
119
- f"Reached max steps ({self._max_steps}). Final deterministic grade is "
120
- f"{current_grade.total_score:.2f}."
121
- )
122
- else:
123
- self._last_feedback = self._build_feedback(current_grade, reward)
124
-
125
- self._reward_total = round(self._reward_total + reward, 4)
126
- self._history.append(
127
- ActionHistoryEntry(
128
- step=self._step_count,
129
- operation=action.operation,
130
- summary=self._summarize_action(action),
131
- reward_delta=reward,
 
 
 
 
 
 
 
 
 
 
 
 
132
  )
133
- )
134
 
135
- return self._build_observation(reward=reward, done=self._done)
136
 
137
  def close(self) -> None:
138
  """No-op close hook for compatibility with local scripts."""
 
3
  from __future__ import annotations
4
 
5
  import os
6
+ import threading
7
  import uuid
8
  from pathlib import Path
9
 
 
32
  ):
33
  """A realistic customer support triage environment with dense rewards."""
34
 
35
+ _state_lock = threading.RLock()
36
+ _shared_task_id: str | None = None
37
+ _shared_step_count = 0
38
+ _shared_reward_total = 0.0
39
+ _shared_done = False
40
+ _shared_last_feedback = ""
41
+ _shared_history: list[ActionHistoryEntry] = []
42
+ _shared_case = SupportCaseProgress()
43
+ _shared_episode_id: str | None = None
44
+ _shared_score = 0.0
45
+ _shared_completed_milestones: list[str] = []
46
+
47
  def __init__(self, task_id: str | None = None):
48
  super().__init__()
49
  requested_task = task_id or os.getenv("SUPPORTDESK_TASK_ID") or list_task_ids()[0]
 
59
  initial_grade = grade_case(self.task, self._case)
60
  self._score = initial_grade.total_score
61
  self._completed_milestones = list(initial_grade.completed_milestones)
62
+ self._ensure_shared_state(self.task)
63
+
64
+ @classmethod
65
+ def _initialize_shared_state(
66
+ cls,
67
+ task: SupportTaskSpec,
68
+ *,
69
+ episode_id: str | None = None,
70
+ ) -> None:
71
+ initial_case = SupportCaseProgress()
72
+ initial_grade = grade_case(task, initial_case)
73
+ cls._shared_task_id = task.task_id
74
+ cls._shared_step_count = 0
75
+ cls._shared_reward_total = 0.0
76
+ cls._shared_done = False
77
+ cls._shared_last_feedback = (
78
+ "New case loaded. Review the ticket and policy snippets before acting."
79
+ )
80
+ cls._shared_history = []
81
+ cls._shared_case = initial_case
82
+ cls._shared_episode_id = episode_id
83
+ cls._shared_score = initial_grade.total_score
84
+ cls._shared_completed_milestones = list(initial_grade.completed_milestones)
85
+
86
+ @classmethod
87
+ def _ensure_shared_state(cls, task: SupportTaskSpec) -> None:
88
+ with cls._state_lock:
89
+ if cls._shared_task_id is None:
90
+ cls._initialize_shared_state(task)
91
+
92
+ def _sync_from_shared(self) -> None:
93
+ task = get_task(self.__class__._shared_task_id or self.task.task_id)
94
+ self.task = task
95
+ self._max_steps = task.max_steps
96
+ self._step_count = self.__class__._shared_step_count
97
+ self._reward_total = self.__class__._shared_reward_total
98
+ self._done = self.__class__._shared_done
99
+ self._last_feedback = self.__class__._shared_last_feedback
100
+ self._history = [entry.model_copy(deep=True) for entry in self.__class__._shared_history]
101
+ self._case = self.__class__._shared_case.model_copy(deep=True)
102
+ self._episode_id = self.__class__._shared_episode_id
103
+ self._score = self.__class__._shared_score
104
+ self._completed_milestones = list(self.__class__._shared_completed_milestones)
105
+
106
+ def _sync_to_shared(self) -> None:
107
+ self.__class__._shared_task_id = self.task.task_id
108
+ self.__class__._shared_step_count = self._step_count
109
+ self.__class__._shared_reward_total = self._reward_total
110
+ self.__class__._shared_done = self._done
111
+ self.__class__._shared_last_feedback = self._last_feedback
112
+ self.__class__._shared_history = [entry.model_copy(deep=True) for entry in self._history]
113
+ self.__class__._shared_case = self._case.model_copy(deep=True)
114
+ self.__class__._shared_episode_id = self._episode_id
115
+ self.__class__._shared_score = self._score
116
+ self.__class__._shared_completed_milestones = list(self._completed_milestones)
117
 
118
  @property
119
  def state(self) -> SupportDeskState:
120
+ with self.__class__._state_lock:
121
+ self._sync_from_shared()
122
+ return SupportDeskState(
123
+ episode_id=self._episode_id,
124
+ task_id=self.task.task_id,
125
+ difficulty=self.task.difficulty,
126
+ step_count=self._step_count,
127
+ reward=round(self._reward_total, 4),
128
+ done=self._done,
129
+ current_score=round(self._score, 4),
130
+ max_steps=self._max_steps,
131
+ case=self._case.model_copy(deep=True),
132
+ action_history=[entry.model_copy(deep=True) for entry in self._history],
133
+ completed_milestones=list(self._completed_milestones),
134
+ last_feedback=self._last_feedback,
135
+ )
136
 
137
  def reset(
138
  self,
 
140
  episode_id: str | None = None,
141
  **kwargs,
142
  ) -> SupportDeskObservation:
143
+ with self.__class__._state_lock:
144
+ self.__class__._initialize_shared_state(
145
+ self.task,
146
+ episode_id=episode_id or f"{self.task.task_id}-{uuid.uuid4().hex[:8]}",
147
+ )
148
+ self._sync_from_shared()
149
+ return self._build_observation(reward=0.0, done=False)
 
 
 
 
150
 
151
  def step(
152
  self,
 
154
  timeout_s: float | None = None,
155
  **kwargs,
156
  ) -> SupportDeskObservation:
157
+ with self.__class__._state_lock:
158
+ self._sync_from_shared()
 
 
 
 
 
 
 
 
159
 
160
+ if self._done:
161
+ return self._build_observation(
162
+ reward=-0.05,
163
+ done=True,
164
+ feedback="Episode already finished. Call reset() before taking more actions.",
165
+ )
166
 
167
+ previous_grade = grade_case(self.task, self._case)
168
+ self._apply_action(action)
169
+ self._step_count += 1
170
 
171
+ current_grade = grade_case(self.task, self._case)
172
+ reward = current_grade.total_score - previous_grade.total_score
173
+ reward += self._action_penalty(
174
+ action,
175
+ current_grade.total_score,
176
+ previous_grade.total_score,
177
  )
178
+ reward = round(reward, 4)
179
+
180
+ self._score = current_grade.total_score
181
+ self._completed_milestones = list(current_grade.completed_milestones)
182
+
183
+ if action.operation == "submit":
184
+ self._done = True
185
+ self._last_feedback = (
186
+ "Case submitted. Final deterministic grade is "
187
+ f"{current_grade.total_score:.2f}."
188
+ )
189
+ elif self._step_count >= self._max_steps:
190
+ self._done = True
191
+ self._last_feedback = (
192
+ f"Reached max steps ({self._max_steps}). Final deterministic grade is "
193
+ f"{current_grade.total_score:.2f}."
194
+ )
195
+ else:
196
+ self._last_feedback = self._build_feedback(current_grade, reward)
197
+
198
+ self._reward_total = round(self._reward_total + reward, 4)
199
+ self._history.append(
200
+ ActionHistoryEntry(
201
+ step=self._step_count,
202
+ operation=action.operation,
203
+ summary=self._summarize_action(action),
204
+ reward_delta=reward,
205
+ )
206
  )
207
+ self._sync_to_shared()
208
 
209
+ return self._build_observation(reward=reward, done=self._done)
210
 
211
  def close(self) -> None:
212
  """No-op close hook for compatibility with local scripts."""