modelbuilderhq commited on
Commit
726cf7a
·
verified ·
1 Parent(s): d33da97

Upload folder using huggingface_hub

Browse files
supportdesk_env/models.py CHANGED
@@ -90,6 +90,7 @@ class SupportDeskObservation(Observation):
90
  class SupportDeskState(State):
91
  """Current environment state returned by the OpenEnv state() API."""
92
 
 
93
  task_id: str
94
  difficulty: Literal["easy", "medium", "hard"]
95
  step_count: int = 0
 
90
  class SupportDeskState(State):
91
  """Current environment state returned by the OpenEnv state() API."""
92
 
93
+ episode_id: str | None = None
94
  task_id: str
95
  difficulty: Literal["easy", "medium", "hard"]
96
  step_count: int = 0
supportdesk_env/server/app.py CHANGED
@@ -3,16 +3,21 @@
3
  from __future__ import annotations
4
 
5
  import os
 
6
  from typing import Any
7
 
8
  import uvicorn
 
 
9
 
10
  try:
11
  from openenv.core.env_server.http_server import create_app
 
12
  except ImportError: # pragma: no cover - package name differs across releases
13
  from openenv_core.env_server.http_server import create_app
 
14
 
15
- from supportdesk_env.models import SupportDeskAction, SupportDeskObservation
16
  from supportdesk_env.server.supportdesk_environment import SupportDeskEnvironment
17
  from supportdesk_env.tasks import TASKS
18
 
@@ -23,6 +28,82 @@ app = create_app(
23
  env_name="supportdesk_env",
24
  )
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  @app.get("/tasks")
28
  def list_tasks() -> dict[str, Any]:
 
3
  from __future__ import annotations
4
 
5
  import os
6
+ import threading
7
  from typing import Any
8
 
9
  import uvicorn
10
+ from fastapi import Body
11
+ from fastapi.routing import APIRoute
12
 
13
  try:
14
  from openenv.core.env_server.http_server import create_app
15
+ from openenv.core.env_server.types import ResetRequest, ResetResponse, StepRequest, StepResponse
16
  except ImportError: # pragma: no cover - package name differs across releases
17
  from openenv_core.env_server.http_server import create_app
18
+ from openenv_core.env_server.types import ResetRequest, ResetResponse, StepRequest, StepResponse
19
 
20
+ from supportdesk_env.models import SupportDeskAction, SupportDeskObservation, SupportDeskState
21
  from supportdesk_env.server.supportdesk_environment import SupportDeskEnvironment
22
  from supportdesk_env.tasks import TASKS
23
 
 
28
  env_name="supportdesk_env",
29
  )
30
 
31
+ _http_session_lock = threading.RLock()
32
+ _http_session_env: SupportDeskEnvironment | None = None
33
+
34
+
35
+ def _remove_stateless_http_routes() -> None:
36
+ """Replace OpenEnv's stateless HTTP simulation routes with a persistent session."""
37
+
38
+ kept_routes = []
39
+ for route in app.router.routes:
40
+ if isinstance(route, APIRoute):
41
+ methods = route.methods or set()
42
+ if route.path == "/reset" and "POST" in methods:
43
+ continue
44
+ if route.path == "/step" and "POST" in methods:
45
+ continue
46
+ if route.path == "/state" and "GET" in methods:
47
+ continue
48
+ kept_routes.append(route)
49
+ app.router.routes = kept_routes
50
+ app.openapi_schema = None
51
+
52
+
53
+ def _get_http_session_env() -> SupportDeskEnvironment:
54
+ global _http_session_env
55
+ if _http_session_env is None:
56
+ _http_session_env = SupportDeskEnvironment()
57
+ return _http_session_env
58
+
59
+
60
+ def _serialize_observation(observation: SupportDeskObservation) -> dict[str, Any]:
61
+ return observation.model_dump()
62
+
63
+
64
+ _remove_stateless_http_routes()
65
+
66
+
67
+ @app.post("/reset", response_model=ResetResponse, tags=["Environment Control"])
68
+ def reset_environment(
69
+ request: ResetRequest = Body(default_factory=ResetRequest),
70
+ ) -> ResetResponse:
71
+ """Keep a persistent HTTP environment so reset/step/state share the same case."""
72
+
73
+ with _http_session_lock:
74
+ env = _get_http_session_env()
75
+ observation = env.reset(seed=request.seed, episode_id=request.episode_id)
76
+ reward = float(observation.reward) if observation.reward is not None else None
77
+ return ResetResponse(
78
+ observation=_serialize_observation(observation),
79
+ reward=reward,
80
+ done=observation.done,
81
+ )
82
+
83
+
84
+ @app.post("/step", response_model=StepResponse, tags=["Environment Control"])
85
+ def step_environment(request: StepRequest) -> StepResponse:
86
+ """Advance the current HTTP session instead of stepping a fresh env instance."""
87
+
88
+ with _http_session_lock:
89
+ env = _get_http_session_env()
90
+ action = SupportDeskAction.model_validate(request.action)
91
+ observation = env.step(action, timeout_s=request.timeout_s)
92
+ reward = float(observation.reward) if observation.reward is not None else None
93
+ return StepResponse(
94
+ observation=_serialize_observation(observation),
95
+ reward=reward,
96
+ done=observation.done,
97
+ )
98
+
99
+
100
+ @app.get("/state", response_model=SupportDeskState, tags=["State Management"])
101
+ def get_environment_state() -> SupportDeskState:
102
+ """Return the real current HTTP session state for grader-style inspection."""
103
+
104
+ with _http_session_lock:
105
+ return _get_http_session_env().state
106
+
107
 
108
  @app.get("/tasks")
109
  def list_tasks() -> dict[str, Any]:
supportdesk_env/server/supportdesk_environment.py CHANGED
@@ -3,6 +3,7 @@
3
  from __future__ import annotations
4
 
5
  import os
 
6
  from pathlib import Path
7
 
8
  from supportdesk_env.graders import grade_case
@@ -41,6 +42,7 @@ class SupportDeskEnvironment(
41
  self._last_feedback = ""
42
  self._history: list[ActionHistoryEntry] = []
43
  self._case = SupportCaseProgress()
 
44
  initial_grade = grade_case(self.task, self._case)
45
  self._score = initial_grade.total_score
46
  self._completed_milestones = list(initial_grade.completed_milestones)
@@ -48,6 +50,7 @@ class SupportDeskEnvironment(
48
  @property
49
  def state(self) -> SupportDeskState:
50
  return SupportDeskState(
 
51
  task_id=self.task.task_id,
52
  difficulty=self.task.difficulty,
53
  step_count=self._step_count,
@@ -67,6 +70,7 @@ class SupportDeskEnvironment(
67
  episode_id: str | None = None,
68
  **kwargs,
69
  ) -> SupportDeskObservation:
 
70
  self._step_count = 0
71
  self._reward_total = 0.0
72
  self._done = False
 
3
  from __future__ import annotations
4
 
5
  import os
6
+ import uuid
7
  from pathlib import Path
8
 
9
  from supportdesk_env.graders import grade_case
 
42
  self._last_feedback = ""
43
  self._history: list[ActionHistoryEntry] = []
44
  self._case = SupportCaseProgress()
45
+ self._episode_id: str | None = None
46
  initial_grade = grade_case(self.task, self._case)
47
  self._score = initial_grade.total_score
48
  self._completed_milestones = list(initial_grade.completed_milestones)
 
50
  @property
51
  def state(self) -> SupportDeskState:
52
  return SupportDeskState(
53
+ episode_id=self._episode_id,
54
  task_id=self.task.task_id,
55
  difficulty=self.task.difficulty,
56
  step_count=self._step_count,
 
70
  episode_id: str | None = None,
71
  **kwargs,
72
  ) -> SupportDeskObservation:
73
+ self._episode_id = episode_id or f"{self.task.task_id}-{uuid.uuid4().hex[:8]}"
74
  self._step_count = 0
75
  self._reward_total = 0.0
76
  self._done = False
tests/test_supportdesk.py CHANGED
@@ -1,5 +1,12 @@
1
  """Smoke tests for the SupportDesk environment."""
2
 
 
 
 
 
 
 
 
3
  from supportdesk_env.graders import grade_case
4
  from supportdesk_env.models import SupportDeskAction
5
  from supportdesk_env.server.supportdesk_environment import SupportDeskEnvironment
@@ -73,3 +80,46 @@ def test_grade_is_bounded_between_zero_and_one():
73
  env.reset()
74
  breakdown = grade_case(task, env.state.case)
75
  assert 0.0 <= breakdown.total_score <= 1.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """Smoke tests for the SupportDesk environment."""
2
 
3
+ import pytest
4
+
5
+ try:
6
+ from fastapi.testclient import TestClient
7
+ except RuntimeError:
8
+ TestClient = None # type: ignore[assignment]
9
+
10
  from supportdesk_env.graders import grade_case
11
  from supportdesk_env.models import SupportDeskAction
12
  from supportdesk_env.server.supportdesk_environment import SupportDeskEnvironment
 
80
  env.reset()
81
  breakdown = grade_case(task, env.state.case)
82
  assert 0.0 <= breakdown.total_score <= 1.0
83
+
84
+
85
+ def test_state_includes_episode_id_after_reset():
86
+ env = SupportDeskEnvironment(task_id="billing_refund_easy")
87
+ env.reset(episode_id="episode-123")
88
+ assert env.state.episode_id == "episode-123"
89
+
90
+
91
+ @pytest.mark.skipif(TestClient is None, reason="httpx is not installed for FastAPI TestClient")
92
+ def test_http_reset_step_state_are_session_consistent():
93
+ from supportdesk_env.server.app import app
94
+
95
+ client = TestClient(app)
96
+
97
+ reset_response = client.post("/reset", json={"episode_id": "http-episode"})
98
+ assert reset_response.status_code == 200
99
+
100
+ step_response = client.post(
101
+ "/step",
102
+ json={
103
+ "action": {
104
+ "operation": "classify",
105
+ "queue": "billing_ops",
106
+ "priority": "high",
107
+ "issue_type": "duplicate_charge",
108
+ "status": "new",
109
+ "requested_fields": [],
110
+ "reply": "",
111
+ "internal_note": "",
112
+ }
113
+ },
114
+ )
115
+ assert step_response.status_code == 200
116
+
117
+ state_response = client.get("/state")
118
+ assert state_response.status_code == 200
119
+ state_payload = state_response.json()
120
+
121
+ assert state_payload["episode_id"] == "http-episode"
122
+ assert state_payload["step_count"] == 1
123
+ assert state_payload["case"]["queue"] == "billing_ops"
124
+ assert state_payload["case"]["priority"] == "high"
125
+ assert state_payload["case"]["issue_type"] == "duplicate_charge"
uvicorn-test.err.log ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ INFO: Started server process [5720]
2
+ INFO: Waiting for application startup.
3
+ INFO: Application startup complete.
4
+ INFO: Uvicorn running on http://127.0.0.1:8001 (Press CTRL+C to quit)
uvicorn-test.out.log ADDED
File without changes