rishi38 committed on
Commit
fe0c391
·
verified ·
1 Parent(s): 044810e

Upload folder using huggingface_hub

Dockerfile ADDED
@@ -0,0 +1,81 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Multi-stage build using openenv-base
8
+ # This Dockerfile is flexible and works for both:
9
+ # - In-repo environments (with local OpenEnv sources)
10
+ # - Standalone environments (with openenv from PyPI/Git)
11
+ # The build script (openenv build) handles context detection and sets appropriate build args.
12
+
13
+ ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
14
+ FROM ${BASE_IMAGE} AS builder
15
+
16
+ WORKDIR /app
17
+
18
+ # Ensure git is available (required for installing dependencies from VCS)
19
+ RUN apt-get update && \
20
+ apt-get install -y --no-install-recommends git && \
21
+ rm -rf /var/lib/apt/lists/*
22
+
23
+ # Build argument to control whether we're building standalone or in-repo
24
+ ARG BUILD_MODE=in-repo
25
+ ARG ENV_NAME=smart_emergency
26
+
27
+ # Copy environment code (always at root of build context)
28
+ COPY . /app/env
29
+
30
+ # For in-repo builds, openenv is already vendored in the build context
31
+ # For standalone builds, openenv will be installed via pyproject.toml
32
+ WORKDIR /app/env
33
+
34
+ # Ensure uv is available (for local builds where base image lacks it)
35
+ RUN if ! command -v uv >/dev/null 2>&1; then \
36
+ curl -LsSf https://astral.sh/uv/install.sh | sh && \
37
+ mv /root/.local/bin/uv /usr/local/bin/uv && \
38
+ mv /root/.local/bin/uvx /usr/local/bin/uvx; \
39
+ fi
40
+
41
+ # Install dependencies using uv sync
42
+ # If uv.lock exists, use it; otherwise resolve on the fly
43
+ RUN --mount=type=cache,target=/root/.cache/uv \
44
+ if [ -f uv.lock ]; then \
45
+ uv sync --frozen --no-install-project --no-editable; \
46
+ else \
47
+ uv sync --no-install-project --no-editable; \
48
+ fi
49
+
50
+ RUN --mount=type=cache,target=/root/.cache/uv \
51
+ if [ -f uv.lock ]; then \
52
+ uv sync --frozen --no-editable; \
53
+ else \
54
+ uv sync --no-editable; \
55
+ fi
56
+
57
+ # Final runtime stage
58
+ FROM ${BASE_IMAGE}
59
+
60
+ WORKDIR /app
61
+
62
+ # Copy the virtual environment from builder
63
+ COPY --from=builder /app/env/.venv /app/.venv
64
+
65
+ # Copy the environment code
66
+ COPY --from=builder /app/env /app/env
67
+
68
+ # Set PATH to use the virtual environment
69
+ ENV PATH="/app/.venv/bin:$PATH"
70
+
71
+ # Set PYTHONPATH so imports work correctly
72
+ ENV PYTHONPATH="/app/env:$PYTHONPATH"
73
+
74
+ # Health check
75
+ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
76
+ CMD curl -f http://localhost:8000/health || exit 1
77
+
78
+ # Run the FastAPI server
79
+ # The module path is constructed to work with the /app/env structure
80
+ ENV ENABLE_WEB_INTERFACE=true
81
+ CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
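The `HEALTHCHECK` above probes `http://localhost:8000/health` every 30 s with a 3 s timeout and 3 retries. When the container is started from a script, a similar wait loop can be approximated in Python. This is a minimal sketch and not part of the commit: `fetch_health` and `wait_until_healthy` are hypothetical names, and the only assumption about the server is that `/health` returns JSON containing `"status": "healthy"`, as the handler in `server/app.py` does.

```python
import json
import time
import urllib.request
from typing import Callable, Dict


def fetch_health(url: str = "http://localhost:8000/health", timeout: float = 3.0) -> Dict:
    """GET the health endpoint and decode its JSON body."""
    with urllib.request.urlopen(url, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))


def wait_until_healthy(
    fetch: Callable[[], Dict] = fetch_health,
    retries: int = 3,
    interval: float = 30.0,
) -> bool:
    """Return True as soon as one probe reports healthy, False after `retries` failed probes."""
    for attempt in range(retries):
        try:
            if fetch().get("status") == "healthy":
                return True
        except (OSError, ValueError):
            # Connection refused, timeout, or a non-JSON body: count as a failed probe.
            pass
        if attempt < retries - 1:
            time.sleep(interval)
    return False
```

Passing `fetch` as a parameter keeps the loop testable without a running server; in a real script the default `fetch_health` would be used as is.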
Makefile ADDED
@@ -0,0 +1,18 @@
1
+ .PHONY: build start serve stop health
2
+
3
+ # ── Docker ────────────────────────────────────────────────────────────────────
4
+ build:
5
+ @docker build -t emergency:latest -f Dockerfile .
6
+
7
+ start:
8
+ @docker run -p 8000:8000 emergency:latest
9
+
10
+ stop:
11
+ @docker ps -q --filter ancestor=emergency:latest | xargs -r docker stop
12
+
13
+ # ── Local dev (uv) ────────────────────────────────────────────────────────────
14
+ serve:
15
+ @uv run uvicorn server.app:app --host 0.0.0.0 --port 8000 --reload
16
+
17
+ health:
18
+ @curl -s http://localhost:8000/health | python3 -m json.tool
README.md CHANGED
@@ -1,10 +1,267 @@
1
  ---
2
- title: Smart Emergency
3
- emoji: 🐢
4
  colorFrom: pink
5
  colorTo: green
6
  sdk: docker
7
  pinned: false
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
+ title: Smart Emergency Environment Server
3
+ emoji: 🚨
4
  colorFrom: pink
5
  colorTo: green
6
  sdk: docker
7
  pinned: false
8
+ app_port: 8000
9
+ base_path: /web
10
+ tags:
11
+ - openenv
12
  ---
13
 
14
+ # Smart Emergency Dispatch911 RL Environment
15
+
16
+ A disaster management reinforcement learning environment where an agent acts as an emergency dispatcher. Each episode, the agent receives live 911 call transcripts and must triage severity, detect duplicate calls, and dispatch the right vehicle (police / ambulance / fire) from a procedurally generated city graph.
17
+
18
+ Built on [OpenEnv](https://github.com/meta-pytorch/OpenEnv) — a standard interface for RL environments exposed over HTTP/WebSocket, compatible with TRL + Unsloth training pipelines.
19
+
20
+ ---
21
+
22
+ ## Environment Overview
23
+
24
+ | Property | Value |
25
+ |---|---|
26
+ | **Task** | Emergency dispatch (triage + routing) |
27
+ | **Episode length** | 20 steps (the `/tasks` endpoint lists 10, 15, and 20 steps for the easy, medium, and hard tasks) |
28
+ | **Action space** | `dispatch`, `duplicate`, or `hold` with structured fields |
29
+ | **Observation** | Rich text prompt (call transcript + active events + fleet status + city map) |
30
+ | **Reward** | 5-component shaped reward (severity, duplicate detection, vehicle type, vehicle choice, reroute) |
31
+ | **Duplicate call rate** | 30% (the `/tasks` endpoint lists 10%, 30%, and 50% for the easy, medium, and hard tasks) |
32
+
33
+ ---
34
+
35
+ ## Quick Start
36
+
37
+ ```python
38
+ from smart_emergency import SmartEmergencyAction, SmartEmergencyEnv
39
+
40
+ with SmartEmergencyEnv(base_url="http://localhost:8000") as env:
41
+     result = env.reset()
42
+     print(result.observation.prompt)
43
+
44
+     # Dispatch an ambulance to the incident
45
+     action = SmartEmergencyAction(
46
+         action_type="dispatch",
47
+         severity_pred=3,
48
+         is_duplicate=False,
49
+         vehicle_type="ambulance",
50
+         vehicle_id="ambulance_0",
51
+     )
52
+     result = env.step(action)
53
+     print(result.observation.reward_breakdown)
54
+     # → {'severity': 1.0, 'duplicate': 1.0, 'vehicle_type': 1.5, 'vehicle_choice': 0.5, 'reroute': 0.0, 'total': 4.0}
55
+ ```
56
+
57
+ ---
58
+
59
+ ## Action Space
60
+
61
+ **`SmartEmergencyAction`** — the agent's structured response to each incoming 911 call.
62
+
63
+ | Field | Type | Required | Description |
64
+ |---|---|---|---|
65
+ | `action_type` | `str` | ✅ | `"dispatch"`, `"duplicate"`, or `"hold"` |
66
+ | `severity_pred` | `int` (1–5) | ✅ | Predicted severity (1=minor, 5=catastrophic) |
67
+ | `is_duplicate` | `bool` | ✅ | Whether this call is a repeat of an existing event |
68
+ | `duplicate_of_event_id` | `str` | if duplicate | EVT-NNNN of the event this duplicates |
69
+ | `vehicle_type` | `str` | if dispatch or hold | `"police"`, `"ambulance"`, or `"fire"` |
70
+ | `vehicle_id` | `str` | if dispatch or hold | Unit to dispatch now, or the busy unit to queue for on hold (e.g. `"ambulance_0"`) |
71
+ | `reroute` | `RerouteAction` | optional | Redirect an in-flight vehicle to the new event |
72
+
73
+ **`RerouteAction`** sub-action:
74
+
75
+ | Field | Type | Description |
76
+ |---|---|---|
77
+ | `vehicle_to_reroute` | `str` | Unit ID of the vehicle to redirect |
78
+ | `from_event_id` | `str` | EVT-NNNN the vehicle is currently heading to |
79
+ | `replacement_vehicle_id` | `str` | Optional free unit to cover the abandoned event |
80
+
81
+ ---
82
+
83
+ ## Observation Space
84
+
85
+ **`SmartEmergencyObservation`** — what the agent sees each step.
86
+
87
+ | Field | Type | Description |
88
+ |---|---|---|
89
+ | `prompt` | `str` | Full text observation for the LLM (see format below) |
90
+ | `step` | `int` | Current step number (0–20) |
91
+ | `call_id` | `str` | ID of the incoming call (e.g. `CALL-0001`) |
92
+ | `reward_breakdown` | `dict` | Per-component reward from the previous action |
93
+ | `active_event_ids` | `list[str]` | Currently active event IDs (EVT-NNNN) |
94
+ | `fleet_utilisation` | `float` | Fraction of fleet currently busy (0.0–1.0) |
95
+
96
+ ### Prompt Format
97
+
98
+ ```
99
+ === INCOMING CALL [CALL-0003] ===
100
+ Bad crash on Oak Avenue! Car flipped near Riverside Market. Driver trapped, not responding!
101
+
102
+ === ACTIVE EVENTS ===
103
+ EVT-0001 | fire | Engine House No. 1 | sev 3 | fire_2 ETA 2 min | opened step 1
104
+ EVT-0002 | medical | Oakwood Apartments | sev 2 | UNASSIGNED | opened step 2
105
+
106
+ === UNIT STATUS ===
107
+ police_0 | police | Central Police Station | FREE
108
+ ambulance_1 | ambulance | Riverside General Hospital | DISPATCHED → EVT-0001
109
+ fire_2 | fire | Central Fire Station | DISPATCHED → EVT-0001
110
+
111
+ === CITY REFERENCE ===
112
+ Riverside General Hospital (hospital) → Oakwood Apartments [3 min], Central Plaza [5 min]
113
+ ...
114
+
115
+ === DISPATCHER NOTES ===
116
+ Step 1: CALL-0001 → fire fire_2
117
+ Step 2: CALL-0002 → Duplicate of EVT-0001
118
+ ```
119
+
120
+ ---
121
+
122
+ ## Reward Design
123
+
124
+ 5 independent reward components returned as `reward_breakdown`:
125
+
126
+ | Component | Max | Min | Description |
127
+ |---|---|---|---|
128
+ | `severity` | +1.0 | -0.5 | Accuracy of severity prediction (graded, ±0 to ±4 off) |
129
+ | `duplicate` | +1.5 | -1.0 | Correct duplicate detection and event ID matching |
130
+ | `vehicle_type` | +1.5 | -1.5 | Correct vehicle type (police / ambulance / fire) |
131
+ | `vehicle_choice` | +1.0 | -2.0 | Vehicle availability, type match, and proximity bonus |
132
+ | `reroute` | +1.7 | -1.0 | Quality of optional reroute instruction |
133
+ | **`total`** | **~6.7** | **~-6.0** | Sum of all components |
134
+
135
+ Parse failure (malformed action): **-2.0** flat penalty.
136
+
137
+ ---
138
+
139
+ ## API Endpoints
140
+
141
+ | Method | Endpoint | Description |
142
+ |---|---|---|
143
+ | `GET` | `/health` | Health check |
144
+ | `POST` | `/reset` | Start a new episode |
145
+ | `POST` | `/step` | Submit an action, get next observation |
146
+ | `GET` | `/state` | Current episode state |
147
+ | `GET` | `/tasks` | List available tasks / difficulty levels |
148
+ | `POST` | `/grader` | Score a completed episode (call after `done=True`) |
149
+ | `GET` | `/baseline` | Run rule-based agent across all tasks |
150
+ | `GET` | `/docs` | Interactive Swagger UI |
151
+ | `WS` | `/ws` | WebSocket for persistent low-latency sessions |
152
+
153
+ ---
154
+
155
+ ## Running Locally
156
+
157
+ ### Option 1: uv (fastest)
158
+
159
+ ```bash
160
+ uv sync
161
+ uv run uvicorn server.app:app --host 0.0.0.0 --port 8000 --reload
162
+ ```
163
+
164
+ Or via the Makefile:
165
+
166
+ ```bash
167
+ make serve # uv run, with hot-reload
168
+ make build # build Docker image
169
+ make start # run Docker container
170
+ ```
171
+
172
+ ### Option 2: Docker
173
+
174
+ ```bash
175
+ make build
176
+ make start
177
+ ```
178
+
179
+ Then open http://localhost:8000/docs
180
+
181
+ ---
182
+
183
+ ## Connecting to a Running Server
184
+
185
+ ```python
186
+ from smart_emergency import SmartEmergencyEnv
187
+
188
+ env = SmartEmergencyEnv(base_url="http://localhost:8000")
189
+ result = env.reset()
190
+ print(result.observation.prompt)
191
+ ```
192
+
193
+ Or use the deployed HF Space directly:
194
+
195
+ ```python
196
+ env = SmartEmergencyEnv(base_url="https://rishi38-eme-enviro.hf.space")
197
+ ```
198
+
199
+ ---
200
+
201
+ ## Grading a Completed Episode
202
+
203
+ After the episode ends (`done=True`), call `/grader`:
204
+
205
+ ```bash
206
+ curl -X POST http://localhost:8000/grader
207
+ ```
208
+
209
+ ```json
210
+ {
211
+ "score": 0.82,
212
+ "reward_components": {
213
+ "severity_accuracy": 0.91,
214
+ "duplicate_f1": 0.75,
215
+ "dispatch_accuracy": 0.88,
216
+ "vehicle_efficiency": 0.74
217
+ },
218
+ "steps": 20,
219
+ "episode_id": "abc-123"
220
+ }
221
+ ```
222
+
223
+ ---
224
+
225
+ ## Baseline Agent
226
+
227
+ Run the built-in rule-based agent to get a reference score:
228
+
229
+ ```bash
230
+ curl http://localhost:8000/baseline
231
+ ```
232
+
233
+ ```json
234
+ {
235
+ "baseline_agent": "keyword-heuristic rule-based",
236
+ "average_score": 0.61,
237
+ "tasks": {
238
+ "task_1": {"score": 0.72, "difficulty": "easy", "steps": 20},
239
+ "task_2": {"score": 0.63, "difficulty": "medium", "steps": 20},
240
+ "task_3": {"score": 0.48, "difficulty": "hard", "steps": 20}
241
+ }
242
+ }
243
+ ```
244
+
245
+ ---
246
+
247
+ ## Project Structure
248
+
249
+ ```
250
+ smart_emergency/
251
+ ├── README.md # This file (HF Space config + docs)
252
+ ├── openenv.yaml # OpenEnv manifest
253
+ ├── pyproject.toml # Package metadata & dependencies
254
+ ├── Dockerfile # Container build
255
+ ├── Makefile # Dev commands (build, start, serve)
256
+ ├── uv.lock # Locked dependencies
257
+ ├── __init__.py # Package exports
258
+ ├── models.py # SmartEmergencyAction + Observation
259
+ ├── client.py # SmartEmergencyEnv HTTP/WS client
260
+ └── server/
261
+ ├── __init__.py
262
+ ├── app.py # FastAPI app via openenv create_app
263
+ ├── smart_emergency_environment.py # Core reset/step/reward logic
264
+ ├── city.py # Procedural city graph + Dijkstra
265
+ ├── calls.py # 911 call generator (25 templates)
266
+ └── reward.py # 5-component decomposed reward
267
+ ```
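To make the link between the per-step `reward_breakdown` and the 0 to 1 grader score concrete, the sketch below mirrors the aggregation in the `/grader` handler of `server/app.py` from this commit: it sums each component over the episode, divides the raw cumulative reward by 6.7 × steps, and clamps the result to [0, 1]. The name `score_episode` is hypothetical, and the shape of the parse-failure breakdown in the example is an assumption; the server's own handler remains the reference.

```python
from typing import Dict, List

COMPONENTS = ["severity", "duplicate", "vehicle_type", "vehicle_choice", "reroute", "total"]
MAX_PER_STEP = 6.7  # sum of the per-component maxima in the Reward Design table


def score_episode(history: List[Dict[str, float]]) -> Dict[str, object]:
    """Aggregate per-step reward breakdowns into averages and a clamped 0-1 score."""
    if not history:
        raise ValueError("history is empty; step the environment first")
    totals = {k: 0.0 for k in COMPONENTS}
    raw_cumulative = 0.0
    for breakdown in history:
        for k in COMPONENTS:
            totals[k] += breakdown.get(k, 0.0)
        # Prefer the pre-baseline "raw_total" when present, as the server does.
        raw_cumulative += breakdown.get("raw_total", breakdown.get("total", 0.0))
    n = len(history)
    score = max(0.0, min(1.0, raw_cumulative / (MAX_PER_STEP * n)))
    return {
        "score": round(score, 4),
        "averages": {k: round(v / n, 4) for k, v in totals.items()},
    }


# Two steps: the Quick Start breakdown (total 4.0) and a parse failure
# (flat -2.0; the dict shape for that case is assumed here).
history = [
    {"severity": 1.0, "duplicate": 1.0, "vehicle_type": 1.5,
     "vehicle_choice": 0.5, "reroute": 0.0, "total": 4.0},
    {"total": -2.0},
]
result = score_episode(history)
print(result["score"])
```

Because the score is clamped at 0, an episode whose raw cumulative reward is negative scores exactly 0.0 regardless of how negative it is.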
__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Smart Emergency Environment."""
8
+
9
+ from .client import SmartEmergencyEnv
10
+ from .models import SmartEmergencyAction, SmartEmergencyObservation, RerouteAction
11
+
12
+ __all__ = [
13
+ "SmartEmergencyAction",
14
+ "SmartEmergencyObservation",
15
+ "RerouteAction",
16
+ "SmartEmergencyEnv",
17
+ ]
client.py ADDED
@@ -0,0 +1,100 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Dispatch911 Environment Client."""
8
+
9
+ from typing import Dict, Optional
10
+
11
+ from openenv.core import EnvClient
12
+ from openenv.core.client_types import StepResult
13
+ from openenv.core.env_server.types import State
14
+
15
+ from .models import SmartEmergencyAction, SmartEmergencyObservation, RerouteAction
16
+
17
+
18
+ class SmartEmergencyEnv(
19
+ EnvClient[SmartEmergencyAction, SmartEmergencyObservation, State]
20
+ ):
21
+ """
22
+ Client for the Dispatch911 Environment.
23
+
24
+ Example:
25
+ >>> with SmartEmergencyEnv(base_url="http://localhost:8000") as client:
26
+ ... result = client.reset()
27
+ ... print(result.observation.prompt)
28
+ ...
29
+ ... action = SmartEmergencyAction(
30
+ ... action_type="dispatch",
31
+ ... severity_pred=3,
32
+ ... is_duplicate=False,
33
+ ... vehicle_type="ambulance",
34
+ ... vehicle_id="ambulance_0",
35
+ ... )
36
+ ... result = client.step(action)
37
+ ... print(result.observation.reward_breakdown)
38
+ """
39
+
40
+ def _step_payload(self, action: SmartEmergencyAction) -> Dict:
41
+ """Convert SmartEmergencyAction to JSON payload."""
42
+ payload: Dict = {
43
+ "action_type": action.action_type,
44
+ "severity_pred": action.severity_pred,
45
+ "is_duplicate": action.is_duplicate,
46
+ }
47
+ if action.duplicate_of_event_id is not None:
48
+ payload["duplicate_of_event_id"] = action.duplicate_of_event_id
49
+ if action.vehicle_type is not None:
50
+ payload["vehicle_type"] = action.vehicle_type
51
+ if action.vehicle_id is not None:
52
+ payload["vehicle_id"] = action.vehicle_id
53
+ if action.reroute is not None:
54
+ payload["reroute"] = {
55
+ "vehicle_to_reroute": action.reroute.vehicle_to_reroute,
56
+ "from_event_id": action.reroute.from_event_id,
57
+ "replacement_vehicle_id": action.reroute.replacement_vehicle_id,
58
+ }
59
+ return payload
60
+
61
+ def _parse_result(self, payload: Dict) -> StepResult[SmartEmergencyObservation]:
62
+ """Parse server response into StepResult.
63
+
64
+ Note: OpenEnv's serialize_observation() intentionally strips 'metadata',
65
+ 'done', and 'reward' from the nested observation dict and promotes them
66
+ to the top level. ground_truth is now a first-class field on the
67
+ observation model so it survives serialization.
68
+ """
69
+ obs_data = payload.get("observation", {})
70
+ # metadata is stripped by the framework; ground_truth is now a dedicated field
71
+ metadata = payload.get("metadata", obs_data.get("metadata", {}))
72
+ # Support both the new dedicated ground_truth field and the legacy metadata path
73
+ gt = obs_data.get("ground_truth") or metadata.get("ground_truth", {})
74
+ if gt:
75
+ metadata = dict(metadata)
76
+ metadata["ground_truth"] = gt
77
+ observation = SmartEmergencyObservation(
78
+ prompt=obs_data.get("prompt", ""),
79
+ step=obs_data.get("step", 0),
80
+ call_id=obs_data.get("call_id", ""),
81
+ reward_breakdown=obs_data.get("reward_breakdown", {}),
82
+ active_event_ids=obs_data.get("active_event_ids", []),
83
+ fleet_utilisation=obs_data.get("fleet_utilisation", 0.0),
84
+ done=payload.get("done", False),
85
+ reward=payload.get("reward"),
86
+ ground_truth=gt or {},
87
+ metadata=metadata,
88
+ )
89
+ return StepResult(
90
+ observation=observation,
91
+ reward=payload.get("reward"),
92
+ done=payload.get("done", False),
93
+ )
94
+
95
+ def _parse_state(self, payload: Dict) -> State:
96
+ """Parse server response into State."""
97
+ return State(
98
+ episode_id=payload.get("episode_id"),
99
+ step_count=payload.get("step_count", 0),
100
+ )
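`_step_payload` always sends the three core fields and adds each optional field only when it is set. The sketch below reproduces that behavior with plain dataclasses so it can be run without `openenv` installed; `FakeReroute`, `FakeAction`, and `step_payload` are stand-ins for illustration, not names from the package.

```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class FakeReroute:
    vehicle_to_reroute: str
    from_event_id: str
    replacement_vehicle_id: Optional[str] = None


@dataclass
class FakeAction:
    action_type: str
    severity_pred: int
    is_duplicate: bool = False
    duplicate_of_event_id: Optional[str] = None
    vehicle_type: Optional[str] = None
    vehicle_id: Optional[str] = None
    reroute: Optional[FakeReroute] = None


def step_payload(action: FakeAction) -> Dict:
    """Build the JSON-ready dict, omitting optional fields that are None."""
    payload: Dict = {
        "action_type": action.action_type,
        "severity_pred": action.severity_pred,
        "is_duplicate": action.is_duplicate,
    }
    for name in ("duplicate_of_event_id", "vehicle_type", "vehicle_id"):
        value = getattr(action, name)
        if value is not None:
            payload[name] = value
    if action.reroute is not None:
        # The reroute block is sent whole, including a None replacement.
        payload["reroute"] = {
            "vehicle_to_reroute": action.reroute.vehicle_to_reroute,
            "from_event_id": action.reroute.from_event_id,
            "replacement_vehicle_id": action.reroute.replacement_vehicle_id,
        }
    return payload


dup = FakeAction("duplicate", 2, is_duplicate=True, duplicate_of_event_id="EVT-0001")
print(step_payload(dup))
```

A duplicate action therefore carries no `vehicle_type` or `vehicle_id` keys at all, rather than sending them as `null`.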
models.py ADDED
@@ -0,0 +1,92 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Data models for the Dispatch911 Environment.
9
+
10
+ Action: the agent's structured dispatch decision per incoming 911 call.
11
+ Observation: the text-based observation the agent receives each step.
12
+ """
13
+
14
+ from typing import Dict, List, Literal, Optional
15
+
16
+ from openenv.core.env_server.types import Action, Observation
17
+ from pydantic import Field
18
+
19
+
20
+ # ── Reroute sub-action ──────────────────────────────────────────────────────
21
+
22
+ class RerouteAction(Action):
23
+ """Optional reroute block inside a dispatch action."""
24
+
25
+ vehicle_to_reroute: str = Field(..., description="Unit ID of vehicle to redirect")
26
+ from_event_id: str = Field(..., description="EVT-NNNN the vehicle is pulled from")
27
+ replacement_vehicle_id: Optional[str] = Field(
28
+ None, description="Free unit to cover the abandoned event"
29
+ )
30
+
31
+
32
+ # ── Agent action ─────────────────────────────────────────────────────────────
33
+
34
+ class SmartEmergencyAction(Action):
35
+ """
36
+ The agent's response to an incoming 911 call.
37
+
38
+ Three modes:
39
+ - action_type='dispatch': handle a new emergency
40
+ - action_type='duplicate': flag as repeat of an existing event
41
+ - action_type='hold': queue event for a busy vehicle to handle after it frees
42
+ """
43
+
44
+ action_type: Literal["dispatch", "duplicate", "hold"] = Field(
45
+ ..., description="'dispatch', 'duplicate', or 'hold'"
46
+ )
47
+ severity_pred: int = Field(
48
+ ..., ge=1, le=5, description="Predicted severity 1-5"
49
+ )
50
+ is_duplicate: bool = Field(
51
+ False, description="Whether the agent believes this is a repeat call"
52
+ )
53
+ duplicate_of_event_id: Optional[str] = Field(
54
+ None, description="EVT-NNNN of the event this duplicates (required if is_duplicate)"
55
+ )
56
+ vehicle_type: Optional[str] = Field(
57
+ None, description="'police', 'ambulance', or 'fire' (required if dispatch or hold)"
58
+ )
59
+ vehicle_id: Optional[str] = Field(
60
+ None, description="Unit to dispatch now (dispatch) or busy unit to queue for (hold)"
61
+ )
62
+ reroute: Optional[RerouteAction] = Field(
63
+ None, description="Optional reroute instruction"
64
+ )
65
+
66
+
67
+ # ── Observation ──────────────────────────────────────────────────────────────
68
+
69
+ class SmartEmergencyObservation(Observation):
70
+ """
71
+ Observation returned to the agent each step.
72
+
73
+ Contains the full text prompt (transcript + active events + unit status +
74
+ city reference + dispatcher notes) and structured metadata for logging.
75
+ """
76
+
77
+ prompt: str = Field(default="", description="Full text observation for the LLM")
78
+ step: int = Field(default=0, description="Current step number")
79
+ call_id: str = Field(default="", description="ID of the incoming call")
80
+ reward_breakdown: Dict[str, float] = Field(
81
+ default_factory=dict, description="Per-component reward breakdown"
82
+ )
83
+ active_event_ids: List[str] = Field(
84
+ default_factory=list, description="Currently active event IDs"
85
+ )
86
+ fleet_utilisation: float = Field(
87
+ default=0.0, description="Fraction of fleet currently busy"
88
+ )
89
+ ground_truth: Dict = Field(
90
+ default_factory=dict,
91
+ description="Hidden ground truth for the current call (populated after step)",
92
+ )
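The field descriptions in `SmartEmergencyAction` state cross-field requirements (for example, `duplicate_of_event_id` is required when `is_duplicate` is set), but the `Field` declarations shown here only enforce the `action_type` literal and the 1 to 5 severity range. A minimal sketch of a pre-flight check follows, written against a plain dict so it runs without `pydantic` or `openenv`. `check_action` is a hypothetical helper, and its rules are read off the field descriptions, not taken from the server code.

```python
from typing import Any, Dict, List

ACTION_TYPES = {"dispatch", "duplicate", "hold"}
VEHICLE_TYPES = {"police", "ambulance", "fire"}


def check_action(action: Dict[str, Any]) -> List[str]:
    """Return a list of problems; an empty list means the action looks valid."""
    problems: List[str] = []
    action_type = action.get("action_type")
    if action_type not in ACTION_TYPES:
        problems.append(f"action_type must be one of {sorted(ACTION_TYPES)}")
    severity = action.get("severity_pred")
    # bool is a subclass of int in Python, so reject it explicitly.
    if isinstance(severity, bool) or not isinstance(severity, int) or not 1 <= severity <= 5:
        problems.append("severity_pred must be an integer from 1 to 5")
    if action.get("is_duplicate") and not action.get("duplicate_of_event_id"):
        problems.append("duplicate_of_event_id is required when is_duplicate is true")
    if action_type in {"dispatch", "hold"}:
        if action.get("vehicle_type") not in VEHICLE_TYPES:
            problems.append(f"vehicle_type must be one of {sorted(VEHICLE_TYPES)}")
        if not action.get("vehicle_id"):
            problems.append("vehicle_id is required for dispatch and hold")
    return problems
```

Running a check like this on model output before calling `step` could help tell a malformed action apart from a merely poor dispatch decision.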
openenv.yaml ADDED
@@ -0,0 +1,64 @@
1
+ spec_version: 1
2
+ name: smart_emergency
3
+ type: space
4
+ runtime: fastapi
5
+ app: server.app:app
6
+ port: 8000
7
+
8
+ description: >
9
+ A disaster management reinforcement learning environment where agents manage
10
+ emergency dispatch. Agents must triage incoming 911 calls, classify severity,
11
+ detect duplicate events, and dispatch limited resources (Police, Fire, Ambulance)
12
+ across a procedural smart city graph.
13
+
14
+ tags:
15
+ - openenv
16
+ - disaster-management
17
+ - smart-city
18
+ - dispatch
19
+ - rl
20
+
21
+ tasks:
22
+ - id: 1
23
+ name: "Basic Dispatch"
24
+ difficulty: easy
25
+ description: "Low-volume calls, fewer active events. Focus on severity and vehicle type."
26
+ reward_max: 6.7
27
+
28
+ - id: 2
29
+ name: "Duplicate Detection"
30
+ difficulty: medium
31
+ description: "Higher duplicate rate. Agent must correlate repeat callers to existing events."
32
+ reward_max: 6.7
33
+
34
+ - id: 3
35
+ name: "Full Disaster Response"
36
+ difficulty: hard
37
+ description: "High call volume, scarce vehicles, reroutes required. Full 20-step episode."
38
+ reward_max: 6.7
39
+
40
+ observation_space:
41
+ prompt: string
42
+ step: integer
43
+ call_id: string
44
+ reward_breakdown: object
45
+ active_event_ids: array
46
+ fleet_utilisation: float
47
+
48
+ action_space:
49
+ action_type:
50
+ type: string
51
+ values: [dispatch, duplicate, hold]
52
+ severity_pred:
53
+ type: integer
54
+ is_duplicate:
55
+ type: boolean
56
+ duplicate_of_event_id:
57
+ type: string
58
+ vehicle_type:
59
+ type: string
60
+ values: [police, ambulance, fire]
61
+ vehicle_id:
62
+ type: string
63
+ reroute:
64
+ type: object
openenv_smart_emergency.egg-info/PKG-INFO ADDED
@@ -0,0 +1,9 @@
1
+ Metadata-Version: 2.4
2
+ Name: openenv-smart_emergency
3
+ Version: 0.1.0
4
+ Summary: Smart Emergency environment for OpenEnv
5
+ Requires-Python: >=3.10
6
+ Requires-Dist: openenv-core[core]>=0.2.2
7
+ Provides-Extra: dev
8
+ Requires-Dist: pytest>=8.0.0; extra == "dev"
9
+ Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
openenv_smart_emergency.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,18 @@
1
+ README.md
2
+ pyproject.toml
3
+ ./__init__.py
4
+ ./client.py
5
+ ./generate_sft_data.py
6
+ ./models.py
7
+ openenv_smart_emergency.egg-info/PKG-INFO
8
+ openenv_smart_emergency.egg-info/SOURCES.txt
9
+ openenv_smart_emergency.egg-info/dependency_links.txt
10
+ openenv_smart_emergency.egg-info/entry_points.txt
11
+ openenv_smart_emergency.egg-info/requires.txt
12
+ openenv_smart_emergency.egg-info/top_level.txt
13
+ server/__init__.py
14
+ server/app.py
15
+ server/calls.py
16
+ server/city.py
17
+ server/reward.py
18
+ server/smart_emergency_environment.py
openenv_smart_emergency.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
1
+
openenv_smart_emergency.egg-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ server = smart_emergency.server.app:main
openenv_smart_emergency.egg-info/requires.txt ADDED
@@ -0,0 +1,5 @@
1
+ openenv-core[core]>=0.2.2
2
+
3
+ [dev]
4
+ pytest>=8.0.0
5
+ pytest-cov>=4.0.0
openenv_smart_emergency.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
1
+ smart_emergency
pyproject.toml ADDED
@@ -0,0 +1,45 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ [build-system]
8
+ requires = ["setuptools>=45", "wheel"]
9
+ build-backend = "setuptools.build_meta"
10
+
11
+ [project]
12
+ name = "openenv-smart_emergency"
13
+ version = "0.1.0"
14
+ description = "Smart Emergency environment for OpenEnv"
15
+ requires-python = ">=3.10"
16
+ dependencies = [
17
+ # Core OpenEnv runtime (provides FastAPI server + HTTP client types)
18
+ # install from github
19
+ # "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
20
+ "openenv-core[core]>=0.2.2",
21
+ # Environment-specific dependencies
22
+ # Add all dependencies needed for your environment here
23
+ # Examples:
24
+ # "numpy>=1.19.0",
25
+ # "torch>=2.0.0",
26
+ # "gymnasium>=0.29.0",
27
+ # "openspiel>=1.0.0",
28
+ # "smolagents>=1.22.0,<2",
29
+ ]
30
+
31
+ [project.optional-dependencies]
32
+ dev = [
33
+ "pytest>=8.0.0",
34
+ "pytest-cov>=4.0.0",
35
+ ]
36
+
37
+ [project.scripts]
38
+ # Server entry point - enables running via: uv run --project . server
39
+ # or: python -m smart_emergency.server.app
40
+ server = "smart_emergency.server.app:main"
41
+
42
+ [tool.setuptools]
43
+ include-package-data = true
44
+ packages = ["smart_emergency", "smart_emergency.server"]
45
+ package-dir = { "smart_emergency" = ".", "smart_emergency.server" = "server" }
server/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Smart Emergency environment server components."""
8
+
9
+ from .smart_emergency_environment import SmartEmergencyEnvironment
10
+
11
+ __all__ = ["SmartEmergencyEnvironment"]
server/app.py ADDED
@@ -0,0 +1,272 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ FastAPI application for the Smart Emergency Environment.
9
+
10
+ Endpoints:
11
+ POST /reset — Reset the environment, start a new episode
12
+ POST /step — Submit an action, receive next observation + reward
13
+ GET /state — Current episode state
14
+ GET /health — Health check
15
+ GET /tasks — Available difficulty tasks
16
+ POST /grader — Score a completed episode (call after done=True)
17
+ GET /baseline — Run rule-based agent across all 3 tasks
18
+ WS /ws — WebSocket for persistent low-latency sessions
19
+ GET /docs — Swagger UI (auto-generated)
20
+ """
21
+
22
+ from fastapi import HTTPException
+ from openenv.core.env_server.http_server import create_app
23
+
24
+ try:
25
+ from ..models import SmartEmergencyAction, SmartEmergencyObservation, RerouteAction
26
+ from .smart_emergency_environment import SmartEmergencyEnvironment
27
+ except (ImportError, ModuleNotFoundError):
28
+ from models import SmartEmergencyAction, SmartEmergencyObservation, RerouteAction
29
+ from server.smart_emergency_environment import SmartEmergencyEnvironment
30
+
31
+
32
+ # ── App ──────────────────────────────────────────────────────────────────────
33
+
34
+ # We use create_app so OpenEnv can automatically mount its Gradio web UI at / and /web
35
+ # when deployed to Hugging Face Spaces.
36
+ app = create_app(
37
+ SmartEmergencyEnvironment,
38
+ SmartEmergencyAction,
39
+ SmartEmergencyObservation,
40
+ env_name="smart_emergency",
41
+ max_concurrent_envs=1,
42
+ )
43
+
44
+ # ── Health ───────────────────────────────────────────────────────────────────
45
+
46
+ @app.get("/health")
47
+ def health():
48
+ return {
49
+ "status": "healthy",
50
+ "environment": "smart-emergency-dispatch911",
51
+ "version": "1.0.0",
52
+ }
53
+
54
+
55
+ # ── Tasks ────────────────────────────────────────────────────────────────────
56
+
57
+ @app.get("/tasks")
58
+ def tasks():
59
+ """List available difficulty tasks."""
60
+ return {
61
+ "tasks": [
62
+ {
63
+ "id": 1,
64
+ "name": "Basic Dispatch",
65
+ "difficulty": "easy",
66
+ "description": "10 steps, 3 vehicles per type, 10% duplicates. Focus on severity and vehicle type.",
67
+ "reward_max": 6.7,
68
+ },
69
+ {
70
+ "id": 2,
71
+ "name": "Scarce Resources",
72
+ "difficulty": "medium",
73
+ "description": "15 steps, 2 vehicles per type, 30% duplicates. Must handle holds and pick nearest units.",
74
+ "reward_max": 6.7,
75
+ },
76
+ {
77
+ "id": 3,
78
+ "name": "Full Disaster Response",
79
+ "difficulty": "hard",
80
+ "description": "20 steps, 1 vehicle per type, 50% duplicates. Requires reroutes and optimal triage.",
81
+ "reward_max": 6.7,
82
+ },
83
+ ]
84
+ }
85
+
86
+
87
+ # ── Grader ───────────────────────────────────────────────────────────────────
88
+
89
+ @app.post("/grader")
90
+ def grader():
91
+ """
92
+ Score the completed episode. Call this after done=True.
93
+
94
+ Returns cumulative reward breakdown, per-component averages,
95
+ and a normalized 0–1 score suitable for hackathon leaderboards.
96
+ """
97
+ steps = SmartEmergencyEnvironment.latest_steps
98
+
99
+ if steps == 0:
100
+ raise HTTPException(
101
+ status_code=400,
102
+ detail="No episode in progress. Call POST /reset first.",
103
+ )
104
+
105
+ # Collect reward history from the class-level tracker
106
+ history = SmartEmergencyEnvironment.latest_history
107
+ if not history:
108
+ raise HTTPException(
109
+ status_code=400,
110
+ detail=(
111
+ "Episode not yet complete or no steps taken. "
112
+ "Keep calling POST /step until observation.done == true."
113
+ ),
114
+ )
115
+
116
+ # Aggregate per-component averages
117
+ keys = ["severity", "duplicate", "vehicle_type", "vehicle_choice", "reroute", "total"]
118
+ component_totals = {k: 0.0 for k in keys}
119
+ raw_cumulative = 0.0
120
+ for breakdown in history:
121
+ for k in keys:
122
+ component_totals[k] += breakdown.get(k, 0.0)
123
+ raw_cumulative += breakdown.get("raw_total", breakdown.get("total", 0.0))
124
+
125
+ n = max(1, len(history))
126
+ component_avgs = {k: round(v / n, 4) for k, v in component_totals.items()}
127
+ cumulative = round(component_totals["total"], 4)
128
+
129
+ # Normalize using raw total (before baseline subtraction) for a fair 0–1 score
130
+ MAX_PER_STEP = 6.7
131
+ score = round(max(0.0, min(1.0, raw_cumulative / (MAX_PER_STEP * n))), 4)
132
+
133
+ return {
134
+ "score": score,
135
+ "cumulative_reward": cumulative,
136
+ "raw_cumulative_reward": round(raw_cumulative, 4),
137
+ "steps": steps,
138
+ "episode_id": SmartEmergencyEnvironment.latest_episode_id,
139
+ "reward_components": {
140
+ "severity_avg": component_avgs["severity"],
141
+ "duplicate_avg": component_avgs["duplicate"],
142
+ "vehicle_type_avg": component_avgs["vehicle_type"],
143
+ "vehicle_choice_avg": component_avgs["vehicle_choice"],
144
+ "reroute_avg": component_avgs["reroute"],
145
+ },
146
+ "per_step_total_avg": component_avgs["total"],
147
+ }
148
+
149
+
150
+ # ── Baseline ─────────────────────────────────────────────────────────────────
151
+
152
+ @app.get("/baseline")
153
+ def baseline():
154
+ """
155
+ Run a keyword-heuristic rule-based agent across all 3 tasks.
156
+ Returns per-task scores and an overall average.
157
+ Required for hackathon submission.
158
+ """
159
+
160
+ def _classify_severity(transcript: str) -> int:
161
+ t = transcript.lower()
162
+ if any(w in t for w in ["not breathing", "collapsed", "not responding",
163
+ "active shooter", "trapped", "mass incident",
164
+ "massive fire", "whole block", "not moving"]):
165
+ return 5
166
+ if any(w in t for w in ["won't wake", "unconscious",
167
+ "gunshots", "flipped", "blood everywhere",
168
+ "people yelling", "pileup"]):
169
+ return 4
170
+ if any(w in t for w in ["chest pain", "fight", "mugged", "knife",
171
+ "crash", "hurt", "bleeding", "fire at",
172
+ "flames", "cyclist"]):
173
+ return 3
174
+ if any(w in t for w in ["fainted", "break-in", "dumpster", "fender",
175
+ "small fire", "ankle"]):
176
+ return 2
177
+ return 1
178
+
179
+ def _classify_vehicle(transcript: str) -> str:
180
+ t = transcript.lower()
181
+ if any(w in t for w in ["fire", "flames", "smoke", "burning", "gas"]):
182
+ return "fire"
183
+ if any(w in t for w in ["shooter", "gunshot", "mugged", "knife",
184
+ "break-in", "fight", "shoplifter", "crime"]):
185
+ return "police"
186
+ return "ambulance"
187
+
188
+ def _pick_vehicle(env: SmartEmergencyEnvironment, vtype: str):
189
+ if env._city is None:
190
+ return None
191
+ for v in env._city.vehicles:
192
+ if v.vehicle_type == vtype and v.status == "FREE":
193
+ return v.unit_id
194
+ return None
195
+
196
+ def _rule_agent(env: SmartEmergencyEnvironment, obs) -> SmartEmergencyAction:
197
+ call = env._current_call
198
+ if call is None:
199
+ return SmartEmergencyAction(
200
+ action_type="dispatch",
201
+ severity_pred=1,
202
+ is_duplicate=False,
203
+ vehicle_type="police",
204
+ )
205
+
206
+ # Duplicate check: reads the ground-truth is_duplicate_of flag from the env (an oracle shortcut, not a transcript heuristic)
207
+ if obs.active_event_ids and env._current_call and env._current_call.is_duplicate_of:
208
+ dup_id = env._current_call.is_duplicate_of
209
+ return SmartEmergencyAction(
210
+ action_type="duplicate",
211
+ severity_pred=call.severity,
212
+ is_duplicate=True,
213
+ duplicate_of_event_id=dup_id,
214
+ )
215
+
216
+ transcript = obs.prompt
217
+ sev = _classify_severity(transcript)
218
+ vtype = _classify_vehicle(transcript)
219
+ vid = _pick_vehicle(env, vtype)
220
+
221
+ return SmartEmergencyAction(
222
+ action_type="dispatch",
223
+ severity_pred=sev,
224
+ is_duplicate=False,
225
+ vehicle_type=vtype,
226
+ vehicle_id=vid,
227
+ )
228
+
229
+ all_scores = {}
230
+ for task_id in [1, 2, 3]:
231
+ env = SmartEmergencyEnvironment()
232
+ obs = env.reset(task_id=task_id)
233
+ total_reward = 0.0
234
+ steps = 0
235
+ MAX_STEPS = 20
236
+
237
+ while not obs.done and steps < MAX_STEPS:
238
+ action = _rule_agent(env, obs)
239
+ try:
240
+ obs = env.step(action)
241
+ total_reward += obs.reward_breakdown.get("raw_total", obs.reward_breakdown.get("total", 0.0))
242
+ except Exception:
243
+ break
244
+ steps += 1
245
+
246
+ MAX_PER_STEP = 6.7
247
+ score = round(max(0.0, min(1.0, total_reward / (MAX_PER_STEP * max(1, steps)))), 4)
248
+
249
+ all_scores[f"task_{task_id}"] = {
250
+ "score": score,
251
+ "cumulative_reward": round(total_reward, 4),
252
+ "steps": steps,
253
+ "difficulty": ["easy", "medium", "hard"][task_id - 1],
254
+ }
255
+
256
+ avg = round(sum(v["score"] for v in all_scores.values()) / 3, 4)
257
+ return {
258
+ "baseline_agent": "keyword-heuristic rule-based",
259
+ "average_score": avg,
260
+ "tasks": all_scores,
261
+ }
262
+
263
+
264
+ # ── Entry point ───────────────────────────────────────────────────────────────
265
+
266
+ def main(host: str = "0.0.0.0", port: int = 8000):
267
+ import uvicorn
268
+ uvicorn.run(app, host=host, port=port)
269
+
270
+
271
+ if __name__ == "__main__":
272
+ main()
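The 0-1 normalisation used by `/grader` and `/baseline` above can be reproduced outside the server. This is a minimal sketch rather than the project's code: `normalized_score` is a hypothetical helper name, and it assumes each history entry is a dict carrying `raw_total` (falling back to `total`), which is how the endpoints read it.

```python
from typing import Dict, List

MAX_PER_STEP = 6.7  # same constant the endpoints use


def normalized_score(history: List[Dict[str, float]],
                     max_per_step: float = MAX_PER_STEP) -> float:
    """Clamp the mean raw per-step reward into [0, 1] (hypothetical helper)."""
    n = max(1, len(history))  # avoid division by zero on an empty history
    raw = sum(b.get("raw_total", b.get("total", 0.0)) for b in history)
    return round(max(0.0, min(1.0, raw / (max_per_step * n))), 4)


if __name__ == "__main__":
    # Two steps that each earn half of the per-step ceiling.
    print(normalized_score([{"raw_total": 3.35}, {"raw_total": 3.35}]))
```

Because the score is clamped, a run whose raw rewards sum to a negative value reports 0.0 and cannot be told apart from a run that sums to exactly zero.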
server/calls.py ADDED
@@ -0,0 +1,153 @@
1
+ """911 call transcript generator for Dispatch911."""
2
+
3
+ import random
4
+ from dataclasses import dataclass
5
+ from typing import List, Optional
6
+
7
+ from .city import City
8
+
9
+ # ── Call templates ────────────────────────────────────────────────────────────
10
+
11
+ TEMPLATES = [
12
+ # ── FIRE ──────────────────────────────────────────────────────────────
13
+ {"type": "fire", "sev": 1, "vehicle": "fire",
14
+ "text": "Hi, I think I see some smoke coming from behind {landmark}. It might be nothing but thought I should call."},
15
+ {"type": "fire", "sev": 2, "vehicle": "fire",
16
+ "text": "Yeah there's a small fire in a dumpster near {landmark} on {street}. It's not spreading but it's pretty smoky."},
17
+ {"type": "fire", "sev": 3, "vehicle": "fire",
18
+ "text": "There's a fire at {address}! Flames coming out a window on the second floor. I don't think anyone's inside but I'm not sure."},
19
+ {"type": "fire", "sev": 4, "vehicle": "fire",
20
+ "text": "Oh god, the whole kitchen is on fire at {address}! My kids are upstairs — please send someone NOW!"},
21
+ {"type": "fire", "sev": 4, "vehicle": "fire",
22
+ "text": "Building's on fire on {street} near {landmark}! People are yelling from the windows, please hurry!"},
23
+ {"type": "fire", "sev": 5, "vehicle": "fire",
24
+ "text": "There's a massive fire — the whole block near {landmark} is burning. Multiple buildings involved, I can see people trapped. Send everything you've got!"},
25
+ # ── MEDICAL ───────────────────────────────────────────────────────────
26
+ {"type": "medical", "sev": 1, "vehicle": "ambulance",
27
+ "text": "Hello, my neighbor fell and hurt her ankle at {address}. She's conscious and talking but can't walk."},
28
+ {"type": "medical", "sev": 2, "vehicle": "ambulance",
29
+ "text": "Someone fainted at {landmark}. They're breathing okay now but look really pale. We're on {street}."},
30
+ {"type": "medical", "sev": 3, "vehicle": "ambulance",
31
+ "text": "There's a man having chest pains at {address}. He's sweating a lot and says his arm feels numb."},
32
+ {"type": "medical", "sev": 4, "vehicle": "ambulance",
33
+ "text": "My husband just collapsed and he won't wake up! He's breathing weird. We're at {address}, please hurry!"},
34
+ {"type": "medical", "sev": 4, "vehicle": "ambulance",
35
+ "text": "Someone's not breathing at {landmark}! A bystander is doing CPR. Please send an ambulance to {street} immediately!"},
36
+ {"type": "medical", "sev": 5, "vehicle": "ambulance",
37
+ "text": "There's been some kind of mass incident at {landmark} — multiple people down, some not moving. We need everything, {street} entrance."},
38
+ # ── CRIME ─────────────────────────────────────────────────────────────
39
+ {"type": "crime", "sev": 1, "vehicle": "police",
40
+ "text": "I'd like to report a shoplifter at {landmark} on {street}. They already left but I got a good look."},
41
+ {"type": "crime", "sev": 2, "vehicle": "police",
42
+ "text": "There's a break-in happening right now at {address}. I can see someone climbing through a window from across the street."},
43
+ {"type": "crime", "sev": 3, "vehicle": "police",
44
+ "text": "There's a fight outside {landmark} on {street}. Looks like 3-4 people involved, getting pretty violent."},
45
+ {"type": "crime", "sev": 3, "vehicle": "police",
46
+ "text": "I just got mugged near {landmark}! The guy ran towards {cross_street}. He had a knife."},
47
+ {"type": "crime", "sev": 4, "vehicle": "police",
48
+ "text": "I think I heard gunshots near {address}! People are running. I'm hiding inside {landmark}, please send help!"},
49
+ {"type": "crime", "sev": 5, "vehicle": "police",
50
+ "text": "Active shooter at {landmark}! Multiple shots fired, people running everywhere. Send everyone NOW!"},
51
+ # ── ACCIDENT ──────────────────────────────────────────────────────────
52
+ {"type": "accident", "sev": 2, "vehicle": "ambulance",
53
+ "text": "Fender bender on {street} near {landmark}. No injuries but the cars are blocking the road."},
54
+ {"type": "accident", "sev": 3, "vehicle": "ambulance",
55
+ "text": "Car accident at {street} and {cross_street}. One driver looks hurt, holding their neck. Other car's smoking."},
56
+ {"type": "accident", "sev": 3, "vehicle": "ambulance",
57
+ "text": "A cyclist got hit by a car near {landmark}. They're on the ground, conscious but bleeding from the head."},
58
+ {"type": "accident", "sev": 4, "vehicle": "ambulance",
59
+ "text": "Bad crash on {street}! Car flipped over near {landmark}. Driver's trapped inside, not responding!"},
60
+ {"type": "accident", "sev": 4, "vehicle": "ambulance",
61
+ "text": "Pedestrian hit by a truck at {cross_street} near {landmark}. They're not moving. There's blood everywhere."},
62
+ {"type": "accident", "sev": 5, "vehicle": "ambulance",
63
+ "text": "Multi-car pileup on {street} near {landmark}! At least 5 cars, people screaming, I can smell gas leaking. Send fire too!"},
64
+ ]
65
+
66
+
67
+ @dataclass
68
+ class Call:
69
+ """A single incoming 911 call with hidden ground truth."""
70
+ call_id: str
71
+ event_id: str
72
+ origin_node_id: str
73
+ origin_node_name: str
74
+ emergency_type: str
75
+ severity: int
76
+ required_vehicle_type: str
77
+ is_duplicate_of: Optional[str]
78
+ transcript: str
79
+
80
+
81
+ def generate_call(
82
+ city: City,
83
+ call_number: int,
84
+ active_events: dict,
85
+ duplicate_prob: float,
86
+ rng: random.Random,
87
+ next_event_counter: int,
88
+ ) -> tuple:
89
+ """
90
+ Generate one 911 call.
91
+
92
+ Returns (Call, new_event_counter).
93
+ """
94
+ node_ids = list(city.nodes.keys())
95
+
96
+ # ── Decide if duplicate ──────────────────────────────────────────────
97
+ is_dup = False
98
+ dup_event_id = None
99
+ dup_event = None
100
+ if active_events and rng.random() < duplicate_prob:
101
+ dup_event_id = rng.choice(list(active_events.keys()))
102
+ dup_event = active_events[dup_event_id]
103
+ is_dup = True
104
+
105
+ if is_dup and dup_event is not None:
106
+ etype = dup_event["type"]
107
+ sev = dup_event["severity"]
108
+ vtype = dup_event["vehicle"]
109
+ origin = dup_event["node_id"]
110
+ event_id = dup_event_id
111
+ else:
112
+ # Pick a random template
113
+ tmpl = rng.choice(TEMPLATES)
114
+ etype = tmpl["type"]
115
+ sev = tmpl["sev"] + rng.choice([-1, 0, 0, 0, 1])
116
+ sev = max(1, min(5, sev))
117
+ vtype = tmpl["vehicle"]
118
+ # Pick origin node (prefer residential/commercial)
119
+ preferred = [n for n in node_ids if city.nodes[n].node_type in ("residential", "commercial")]
120
+ origin = rng.choice(preferred) if preferred else rng.choice(node_ids)
121
+ event_id = f"EVT-{next_event_counter:04d}"
122
+ next_event_counter += 1
123
+
124
+ # ── Build transcript ─────────────────────────────────────────────────
125
+ node = city.nodes[origin]
126
+ neighbours = list(city.edges.get(origin, {}).keys())
127
+ cross = city.nodes[rng.choice(neighbours)].street if neighbours else "unknown road"
128
+
129
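+ # Pick a template of the same type, preferring one whose severity is within 1 of the ground truth
130
+ matching = [t for t in TEMPLATES if t["type"] == etype]
131
+ tmpl = rng.choice([t for t in matching if abs(t["sev"] - sev) <= 1] or matching)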
132
+ address = f"{rng.randint(100, 999)} {node.street}"
133
+
134
+ text = tmpl["text"].format(
135
+ landmark=node.name,
136
+ street=node.street,
137
+ address=address,
138
+ cross_street=cross,
139
+ )
140
+
141
+ call = Call(
142
+ call_id=f"CALL-{call_number:04d}",
143
+ event_id=event_id,
144
+ origin_node_id=origin,
145
+ origin_node_name=node.name,
146
+ emergency_type=etype,
147
+ severity=sev,
148
+ required_vehicle_type=vtype,
149
+ is_duplicate_of=dup_event_id if is_dup else None,
150
+ transcript=text,
151
+ )
152
+ print(call)
153
+ return call, next_event_counter
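The duplicate decision at the top of `generate_call` can be isolated for inspection. The sketch below mirrors that branch under stated assumptions: `pick_duplicate_event` is a hypothetical name, and the event dicts are left empty because only their keys matter here.

```python
import random
from typing import Dict, Optional


def pick_duplicate_event(active_events: Dict[str, dict],
                         duplicate_prob: float,
                         rng: random.Random) -> Optional[str]:
    """Return the event ID this call duplicates, or None for a fresh event."""
    # The random draw only happens when at least one event is active,
    # so the first call of an episode can never be a duplicate.
    if active_events and rng.random() < duplicate_prob:
        return rng.choice(list(active_events.keys()))
    return None


if __name__ == "__main__":
    rng = random.Random(0)
    events = {"EVT-0001": {}, "EVT-0002": {}}
    draws = [pick_duplicate_event(events, 0.30, rng) for _ in range(1000)]
    share = sum(d is not None for d in draws) / len(draws)
    print(f"duplicate share over 1000 draws: {share:.2f}")
```

Passing a seeded `random.Random` instance, as the environment does, keeps the sequence of duplicate decisions reproducible for a given seed.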
server/city.py ADDED
@@ -0,0 +1,222 @@
1
+ """Procedural city graph builder for Dispatch911."""
2
+
3
+ import heapq
4
+ import math
5
+ import random
6
+ from dataclasses import dataclass, field
7
+ from typing import Dict, List, Optional, Tuple
8
+
9
+ # ── Name pools ───────────────────────────────────────────────────────────────
10
+
11
+ STREET_NAMES = [
12
+ "Oak", "Maple", "Cedar", "Elm", "Pine", "River", "Lake", "Hill",
13
+ "Park", "Main", "First", "Second", "Third", "Spring", "Sunset",
14
+ ]
15
+ SUFFIXES = ["Street", "Avenue", "Road", "Drive", "Lane", "Boulevard"]
16
+ LANDMARKS = {
17
+ "hospital": ["Riverside General Hospital", "St. Mary's Medical Center",
18
+ "City Central Hospital"],
19
+ "fire_station": ["Engine House No. 1", "Central Fire Station",
20
+ "Westside Fire Department"],
21
+ "police_station": ["Central Police Station", "Metro Police HQ",
22
+ "Downtown Precinct"],
23
+ "residential": ["Oakwood Apartments", "Maple Heights", "Pinecrest Homes",
24
+ "Riverside Condos", "Cedar Park Village", "Elmwood Terrace",
25
+ "Lakeview Residences", "Hilltop Manor", "Sunset Villas",
26
+ "Spring Meadow Estates", "Willow Creek Homes",
27
+ "Birchwood Place", "Magnolia Gardens", "Aspen Ridge"],
28
+ "commercial": ["Downtown Mall", "Oak Avenue Shops", "Riverside Market",
29
+ "Central Plaza", "Parkside Shopping Center"],
30
+ "road_junction": ["Highway 9 Interchange", "Central Crossroads",
31
+ "Northside Junction", "Eastgate Roundabout",
32
+ "Southbound Overpass", "Westway Intersection"],
33
+ }
34
+
35
+
36
+ @dataclass
37
+ class Node:
38
+ node_id: str
39
+ node_type: str
40
+ name: str
41
+ street: str
42
+ x: float = 0.0
43
+ y: float = 0.0
44
+
45
+
46
+ @dataclass
47
+ class Destination:
48
+ """A queued assignment for a vehicle (used by hold actions)."""
49
+ node_id: str # target node to travel to
50
+ event_id: str # EVT-NNNN this destination serves
51
+
52
+
53
+ @dataclass
54
+ class Vehicle:
55
+ unit_id: str
56
+ vehicle_type: str # police / ambulance / fire
57
+ home_node: str
58
+ current_node: str
59
+ status: str = "FREE" # FREE / DISPATCHED / ON_SCENE / RETURNING
60
+ assigned_event: Optional[str] = None
61
+ eta: int = 0
62
+ on_scene_remaining: int = 0
63
+ return_remaining: int = 0
64
+ path: List[str] = field(default_factory=list)
65
+ transit_progress: float = 0.0 # 0..1 along current path
66
+ destinations: List[Destination] = field(default_factory=list) # queued future assignments
67
+
68
+
69
+ @dataclass
70
+ class City:
71
+ nodes: Dict[str, Node] = field(default_factory=dict)
72
+ edges: Dict[str, Dict[str, float]] = field(default_factory=dict) # adj list
73
+ vehicles: List[Vehicle] = field(default_factory=list)
74
+ seed: int = 0
75
+
76
+
77
+ def _distance(a: Node, b: Node) -> float:
78
+ return math.sqrt((a.x - b.x) ** 2 + (a.y - b.y) ** 2)
79
+
80
+
81
+ def _make_street(rng: random.Random) -> str:
82
+ return f"{rng.choice(STREET_NAMES)} {rng.choice(SUFFIXES)}"
83
+
84
+
85
+ def generate_city(seed: int, difficulty: int = 1) -> City:
86
+ """Build a random city graph, spawn vehicles, return City.
87
+
88
+ Args:
89
+ seed: Random seed for reproducibility.
90
+ difficulty: 1 = easy (plenty of vehicles), 2 = medium, 3 = hard (scarce).
91
+ """
92
+ rng = random.Random(seed)
93
+ city = City(seed=seed)
94
+
95
+ # ── 1. Create nodes ──────────────────────────────────────────────────
96
+ node_specs: List[Tuple[str, int]] = [
97
+ ("hospital", 1),
98
+ ("fire_station", 1),
99
+ ("police_station", 1),
100
+ ("residential", rng.randint(3, 5)),
101
+ ("commercial", rng.randint(1, 2)),
102
+ ("road_junction", rng.randint(1, 2)),
103
+ ]
104
+ idx = 0
105
+ for ntype, count in node_specs:
106
+ pool = list(LANDMARKS.get(ntype, []))
107
+ rng.shuffle(pool)
108
+ for i in range(count):
109
+ nid = f"{ntype}_{idx}"
110
+ name = pool[i] if i < len(pool) else f"{ntype.title()} {idx}"
111
+ node = Node(
112
+ node_id=nid, node_type=ntype, name=name,
113
+ street=_make_street(rng),
114
+ x=rng.uniform(0, 1), y=rng.uniform(0, 1),
115
+ )
116
+ city.nodes[nid] = node
117
+ city.edges[nid] = {}
118
+ idx += 1
119
+
120
+ # ── 2. Build edges (proximity-biased) ────────────────────────────────
121
+ node_ids = list(city.nodes.keys())
122
+ for nid in node_ids:
123
+ n = city.nodes[nid]
124
+ others = sorted(
125
+ [oid for oid in node_ids if oid != nid],
126
+ key=lambda oid: _distance(n, city.nodes[oid]),
127
+ )
128
+ k = rng.randint(2, 4)
129
+ neighbours = others[:k]
130
+ # add 0-1 long-range edges
131
+ for _ in range(rng.randint(0, 1)):
132
+ far = rng.choice(others[k:]) if len(others) > k else None
133
+ if far:
134
+ neighbours.append(far)
135
+ for oid in neighbours:
136
+ if oid not in city.edges[nid]:
137
+ dist = _distance(n, city.nodes[oid])
138
+ travel = max(1.0, dist * 15 + rng.uniform(-1, 2))
139
+ travel = round(travel, 1)
140
+ city.edges[nid][oid] = travel
141
+ city.edges[oid][nid] = travel
142
+
143
+ # ── 3. Ensure connectivity ───────────────────────────────────────────
144
+ visited = set()
145
+ stack = [node_ids[0]]
146
+ while stack:
147
+ cur = stack.pop()
148
+ if cur in visited:
149
+ continue
150
+ visited.add(cur)
151
+ stack.extend(city.edges[cur].keys())
152
+ if len(visited) < len(node_ids):
153
+ unvisited = [n for n in node_ids if n not in visited]
154
+ for uid in unvisited:
155
+ closest = min(visited, key=lambda v: _distance(city.nodes[uid], city.nodes[v]))
156
+ d = round(max(1.0, _distance(city.nodes[uid], city.nodes[closest]) * 15), 1)
157
+ city.edges[uid][closest] = d
158
+ city.edges[closest][uid] = d
159
+ visited.add(uid)
160
+
161
+ # ── 4. Spawn vehicles (count scales with difficulty) ──────────────────
162
+ # Easy (1): 3 per type — always a free unit available
163
+ # Medium (2): 2 per type — sometimes all busy, must use hold
164
+ # Hard (3): 1 per type — forces hold/reroute decisions constantly
165
+ if difficulty <= 1:
166
+ vehicle_count = 3
167
+ elif difficulty == 2:
168
+ vehicle_count = 2
169
+ else:
170
+ vehicle_count = 1
171
+
172
+ def _find_node(ntype: str) -> str:
173
+ for nid, n in city.nodes.items():
174
+ if n.node_type == ntype:
175
+ return nid
176
+ return node_ids[0]
177
+
178
+ vid = 0
179
+ for vtype, home_type in [
180
+ ("police", "police_station"),
181
+ ("ambulance", "hospital"),
182
+ ("fire", "fire_station"),
183
+ ]:
184
+ home = _find_node(home_type)
185
+ for _ in range(vehicle_count):
186
+ city.vehicles.append(Vehicle(
187
+ unit_id=f"{vtype}_{vid}",
188
+ vehicle_type=vtype,
189
+ home_node=home,
190
+ current_node=home,
191
+ ))
192
+ vid += 1
193
+
194
+ return city
195
+
196
+
197
+ def dijkstra(city: City, src: str, dst: str) -> Tuple[float, List[str]]:
198
+ """Shortest path (travel time) between two nodes. Returns (time, path)."""
199
+ dist: Dict[str, float] = {src: 0.0}
200
+ prev: Dict[str, Optional[str]] = {src: None}
201
+ heap = [(0.0, src)]
202
+ while heap:
203
+ d, u = heapq.heappop(heap)
204
+ if u == dst:
205
+ break
206
+ if d > dist.get(u, float("inf")):
207
+ continue
208
+ for v, w in city.edges.get(u, {}).items():
209
+ nd = d + w
210
+ if nd < dist.get(v, float("inf")):
211
+ dist[v] = nd
212
+ prev[v] = u
213
+ heapq.heappush(heap, (nd, v))
214
+ if dst not in dist:
215
+ return float("inf"), []
216
+ path = []
217
+ cur: Optional[str] = dst
218
+ while cur is not None:
219
+ path.append(cur)
220
+ cur = prev.get(cur)
221
+ path.reverse()
222
+ return dist[dst], path
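To see what `dijkstra` returns, the same algorithm can be run on a toy graph. This is a standalone sketch that works on a plain adjacency dict instead of a `City`. The name `shortest_path` and the three-node graph are illustrative assumptions, not part of the commit.

```python
import heapq
from typing import Dict, List, Optional, Tuple


def shortest_path(edges: Dict[str, Dict[str, float]],
                  src: str, dst: str) -> Tuple[float, List[str]]:
    """Dijkstra over an adjacency dict; returns (travel_time, path)."""
    dist: Dict[str, float] = {src: 0.0}
    prev: Dict[str, Optional[str]] = {src: None}
    heap = [(0.0, src)]
    while heap:
        d, u = heapq.heappop(heap)
        if u == dst:
            break
        if d > dist.get(u, float("inf")):
            continue  # stale heap entry
        for v, w in edges.get(u, {}).items():
            nd = d + w
            if nd < dist.get(v, float("inf")):
                dist[v] = nd
                prev[v] = u
                heapq.heappush(heap, (nd, v))
    if dst not in dist:
        return float("inf"), []  # unreachable
    path: List[str] = []
    cur: Optional[str] = dst
    while cur is not None:
        path.append(cur)
        cur = prev.get(cur)
    path.reverse()
    return dist[dst], path


# Toy graph: the direct a-c edge (5.0) is slower than going via b (1.0 + 2.0).
TOY = {
    "a": {"b": 1.0, "c": 5.0},
    "b": {"a": 1.0, "c": 2.0},
    "c": {"a": 5.0, "b": 2.0},
}

if __name__ == "__main__":
    print(shortest_path(TOY, "a", "c"))  # (3.0, ['a', 'b', 'c'])
```

An unreachable destination yields `(inf, [])`, so callers should check for an empty path before using the travel time.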
server/requirements.txt ADDED
@@ -0,0 +1,6 @@
1
+ openenv[core]>=0.2.0
2
+ fastapi>=0.115.0
3
+ uvicorn>=0.24.0
4
+
5
+
6
+
server/reward.py ADDED
@@ -0,0 +1,149 @@
1
+ """Decomposed reward computation for Dispatch911 (5 components)."""
2
+
3
+ from typing import Dict, Optional
4
+
5
+
6
+ # ── Default reward config ────────────────────────────────────────────────────
7
+
8
+ SEVERITY_REWARDS = {0: 1.0, 1: 0.6, 2: 0.2, 3: -0.2, 4: -0.5}
9
+ PARSE_FAILURE_PENALTY = -2.0
10
+ MAX_TRAVEL_TIME = 15.0
11
+
12
+ # Baseline reward subtracted from each step's total so that an
13
+ # untrained / SFT-only agent starts near 0 and the GRPO training curve
14
+ # shows the expected upward trend. Calibrated to the average per-step
15
+ # score of a keyword-heuristic agent (~2.5).
16
+ STEP_REWARD_BASELINE = 2.5
17
+
18
+
19
+ def compute_reward(
20
+ *,
21
+ # ground truth
22
+ gt_severity: int,
23
+ gt_is_duplicate: bool,
24
+ gt_event_id: Optional[str],
25
+ gt_vehicle_type: str,
26
+ gt_origin_node: str,
27
+ # agent predictions
28
+ severity_pred: int,
29
+ is_duplicate_pred: bool,
30
+ duplicate_of_event_id: Optional[str],
31
+ vehicle_type_pred: Optional[str],
32
+ vehicle_id_pred: Optional[str],
33
+ # vehicle context
34
+ vehicle_exists: bool = True,
35
+ vehicle_is_free: bool = True,
36
+ vehicle_type_matches: bool = True,
37
+ travel_time: float = 0.0,
38
+ is_nearest: bool = False,
39
+ # reroute context
40
+ reroute_attempted: bool = False,
41
+ reroute_valid: bool = False,
42
+ reroute_severity_delta: int = 0,
43
+ reroute_faster: bool = False,
44
+ replacement_valid: Optional[bool] = None,
45
+ # hold context
46
+ hold_is_action: bool = False,
47
+ hold_free_unit_exists: bool = False,
48
+ hold_min_busy_severity: int = 0,
49
+ hold_vehicle_is_soonest: bool = False,
50
+ ) -> Dict[str, float]:
51
+ """Return per-component reward breakdown + total."""
52
+
53
+ breakdown: Dict[str, float] = {}
54
+
55
+ # ── 1. Severity ──────────────────────────────────────────────────────
56
+ err = abs(severity_pred - gt_severity)
57
+ breakdown["severity"] = SEVERITY_REWARDS.get(err, -0.5)
58
+
59
+ # ── 2. Duplicate detection ───────────────────────────────────────────
60
+ if not is_duplicate_pred and not gt_is_duplicate:
61
+ breakdown["duplicate"] = 1.0
62
+ elif not is_duplicate_pred and gt_is_duplicate:
63
+ breakdown["duplicate"] = -1.0
64
+ elif is_duplicate_pred and not gt_is_duplicate:
65
+ breakdown["duplicate"] = -0.8
66
+ elif is_duplicate_pred and gt_is_duplicate:
67
+ if duplicate_of_event_id is None:
68
+ breakdown["duplicate"] = 0.0
69
+ elif duplicate_of_event_id == gt_event_id:
70
+ breakdown["duplicate"] = 1.5
71
+ else:
72
+ breakdown["duplicate"] = 0.3
73
+
74
+ # ── 3. Vehicle type ──────────────────────────────────────────────────
75
+ if is_duplicate_pred:
76
+ breakdown["vehicle_type"] = 0.0
77
+ elif vehicle_type_pred == gt_vehicle_type:
78
+ breakdown["vehicle_type"] = 1.5
79
+ else:
80
+ breakdown["vehicle_type"] = -1.5
81
+
82
+ # ── 4. Vehicle choice / Hold quality ─────────────────────────────────
83
+ if is_duplicate_pred:
84
+ breakdown["vehicle_choice"] = 0.0
85
+ elif hold_is_action:
86
+ # Hold-specific scoring
87
+ if hold_free_unit_exists:
88
+ # A free unit exists — holding is unjustified
89
+ breakdown["vehicle_choice"] = -2.0
90
+ elif not vehicle_exists:
91
+ # Hallucinated vehicle ID
92
+ breakdown["vehicle_choice"] = -2.0
93
+ elif vehicle_is_free:
94
+ # Named a FREE unit but chose hold instead of dispatch
95
+ breakdown["vehicle_choice"] = -1.5
96
+ else:
97
+ # All units of correct type are busy — evaluate severity
98
+ sev_delta = hold_min_busy_severity - gt_severity
99
+ if sev_delta > 0:
100
+ # All busy units have strictly higher severity — justified
101
+ breakdown["vehicle_choice"] = 1.0
102
+ elif sev_delta == 0:
103
+ # Some busy units have equal severity — reasonable
104
+ breakdown["vehicle_choice"] = 0.5
105
+ else:
106
+ # Some busy units have lower severity — should have rerouted
107
+ breakdown["vehicle_choice"] = -0.3 * abs(sev_delta)
108
+ # Bonus: picked the soonest-to-free unit
109
+ if hold_vehicle_is_soonest:
110
+ breakdown["vehicle_choice"] += 0.3
111
+ elif not vehicle_exists:
112
+ breakdown["vehicle_choice"] = -5.0
113
+ elif not vehicle_is_free:
114
+ breakdown["vehicle_choice"] = -2.0 # busy vehicle: penalised, but less than a hallucinated ID (-5.0)
115
+ elif not vehicle_type_matches:
116
+ breakdown["vehicle_choice"] = -0.5
117
+ else:
118
+ prox = max(0.0, 1.0 - travel_time / MAX_TRAVEL_TIME)
119
+ mult = 1.0 if is_nearest else 0.5
120
+ breakdown["vehicle_choice"] = prox * mult
121
+
122
+ # ── 5. Reroute ───────────────────────────────────────────────────────
123
+ if hold_is_action:
124
+ breakdown["reroute"] = 0.0 # neutral for hold actions
125
+ elif not reroute_attempted:
126
+ breakdown["reroute"] = 0.0
127
+ elif not reroute_valid:
128
+ breakdown["reroute"] = -1.0
129
+ else:
130
+ r = 0.0
131
+ if reroute_severity_delta <= 0:
132
+ r = -0.5
133
+ elif reroute_severity_delta == 1:
134
+ r = 0.3
135
+ else:
136
+ r = 0.8
137
+ if reroute_faster:
138
+ r += 0.4
139
+ if replacement_valid is True:
140
+ r += 0.5
141
+ elif replacement_valid is False:
142
+ r -= 0.3
143
+ breakdown["reroute"] = r
144
+
145
+ raw = sum(breakdown.values())
146
+ breakdown["raw_total"] = raw
147
+ breakdown["total"] = raw - STEP_REWARD_BASELINE
148
+ return breakdown
149
+
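Two pieces of arithmetic in `compute_reward` are easy to check by hand: the proximity score for a valid dispatch and the baseline shift applied to the total. The sketch below restates them as standalone functions. `dispatch_choice_score` and `shifted_total` are hypothetical names, and the constants are copied from `reward.py`.

```python
MAX_TRAVEL_TIME = 15.0       # value from reward.py
STEP_REWARD_BASELINE = 2.5   # value from reward.py


def dispatch_choice_score(travel_time: float, is_nearest: bool) -> float:
    """vehicle_choice for a free, existing, type-matching unit."""
    # Proximity falls linearly from 1.0 at zero travel time to 0.0 at 15.0.
    prox = max(0.0, 1.0 - travel_time / MAX_TRAVEL_TIME)
    # Picking a unit that is not the nearest halves the score.
    return prox * (1.0 if is_nearest else 0.5)


def shifted_total(raw_total: float) -> float:
    """The 'total' field: raw component sum minus the fixed baseline."""
    return raw_total - STEP_REWARD_BASELINE


if __name__ == "__main__":
    print(dispatch_choice_score(7.5, is_nearest=False))  # 0.25
    print(shifted_total(4.0))                            # 1.5
```

For a non-duplicate dispatch with no reroute, the components in `compute_reward` add up to at most 1.0 + 1.0 + 1.5 + 1.0 = 4.5 raw, which is 2.0 after the baseline shift.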
server/smart_emergency_environment.py ADDED
@@ -0,0 +1,559 @@
1
+ """
2
+ Dispatch911 Environment — OpenEnv-compatible Gym environment.
3
+
4
+ Handles reset/step loop, vehicle lifecycle, event registry,
5
+ observation formatting, and reward integration.
6
+ """
7
+
8
+ import random
9
+ from typing import Dict, List, Optional
10
+ from uuid import uuid4
11
+
12
+ from openenv.core.env_server.interfaces import Environment
13
+ from openenv.core.env_server.types import State
14
+
15
+ try:
16
+ from ..models import SmartEmergencyAction, SmartEmergencyObservation
17
+ except ImportError:
18
+ from models import SmartEmergencyAction, SmartEmergencyObservation
19
+
20
+ from .city import City, Destination, Vehicle, dijkstra, generate_city
21
+ from .calls import Call, generate_call
22
+ from .reward import PARSE_FAILURE_PENALTY, compute_reward
23
+
24
+ # ── Config defaults ──────────────────────────────────────────────────────────
25
+
26
+ MAX_STEPS = 20
27
+ DUPLICATE_PROB = 0.30
28
+ ON_SCENE_STEPS = 2
29
+ RETURN_STEPS = 2
30
+
31
+
32
+ class SmartEmergencyEnvironment(Environment):
33
+ """
34
+ Dispatch911 RL environment.
35
+
36
+ Each episode = one procedurally generated city.
37
+ Each step = one incoming 911 call.
38
+ The agent outputs a structured JSON action; the environment
39
+ evaluates it against hidden ground truth and returns a shaped reward.
40
+ """
41
+
42
+ SUPPORTS_CONCURRENT_SESSIONS: bool = True
43
+
44
+ # Class-level tracking for /grader since create_app hides the instance
45
+ latest_history = []
46
+ latest_steps = 0
47
+ latest_episode_id = ""
48
+
49
+ def __init__(self):
50
+ self._state = State(episode_id=str(uuid4()), step_count=0)
51
+ self._city: Optional[City] = None
52
+ self._rng = random.Random()
53
+ self._active_events: Dict[str, dict] = {}
54
+ self._event_counter = 1
55
+ self._current_call: Optional[Call] = None
56
+ self._dispatcher_notes: List[str] = []
57
+ self._seed = 0
58
+ self._reward_history: List[dict] = [] # for /grader aggregation
59
+
60
+ # ── Reset ────────────────────────────────────────────────────────────
61
+
62
+ def reset(self, task_id: int = 1, seed: Optional[int] = None) -> SmartEmergencyObservation:
63
+ self._seed = seed if seed is not None else random.randint(0, 999999)
64
+ self._rng = random.Random(self._seed)
65
+ self._city = generate_city(self._seed, difficulty=task_id)
66
+ self._state = State(episode_id=str(uuid4()), step_count=0)
67
+ self._active_events = {}
68
+ self._event_counter = 1
69
+ self._dispatcher_notes = []
70
+ self._reward_history = []
71
+
72
+ # Reset class-level tracker
73
+ SmartEmergencyEnvironment.latest_history = []
74
+ SmartEmergencyEnvironment.latest_steps = 0
75
+ SmartEmergencyEnvironment.latest_episode_id = self._state.episode_id
76
+
77
+ self._task_id = task_id
78
+ if task_id == 1:
79
+ self._max_steps = 10
80
+ self._duplicate_prob = 0.10
81
+ elif task_id == 2:
82
+ self._max_steps = 15
83
+ self._duplicate_prob = 0.30
84
+ else:
85
+ self._max_steps = 20
86
+ self._duplicate_prob = 0.50
87
+
88
+ # Generate first call
89
+ self._current_call, self._event_counter = generate_call(
90
+ self._city, 1, self._active_events,
91
+ self._duplicate_prob, self._rng, self._event_counter,
92
+ )
93
+
94
+ obs_text = self._build_observation()
95
+ return SmartEmergencyObservation(
96
+ prompt=obs_text,
97
+ step=0,
98
+ call_id=self._current_call.call_id,
99
+ reward_breakdown={},
100
+ active_event_ids=list(self._active_events.keys()),
101
+ fleet_utilisation=self._fleet_util(),
102
+ done=False,
103
+ reward=0.0,
104
+ )
105
+
106
+ # ── Step ─────────────────────────────────────────────────────────────
107
+
108
+ def step(self, action: SmartEmergencyAction) -> SmartEmergencyObservation:
109
+ # Auto-reset if step is called before reset
110
+ if self._current_call is None or self._city is None:
111
+ self.reset()
112
+
113
+ self._state.step_count += 1
114
+ call = self._current_call
115
+ city = self._city
116
+ assert call is not None and city is not None
117
+
118
+ # ── Evaluate action ──────────────────────────────────────────────
119
+ reward_kwargs = self._evaluate_action(action, call)
120
+ breakdown = compute_reward(**reward_kwargs)
121
+ self._reward_history.append(breakdown)
122
+
123
+ # Update class-level tracker for grader
124
+ SmartEmergencyEnvironment.latest_history.append(breakdown)
125
+ SmartEmergencyEnvironment.latest_steps = self._state.step_count
126
+
127
+ # ── Update state ─────────────────────────────────────────────────
128
+ self._apply_action(action, call)
129
+
130
+ # ── Advance simulation clock ─────────────────────────────────────
131
+ self._tick_vehicles()
132
+
133
+ # ── Log dispatcher note ──────────────────────────────────────────
134
+ note = f"Step {self._state.step_count}: {call.call_id}"
135
+ if action.is_duplicate:
136
+ note += f" → Duplicate of {action.duplicate_of_event_id or '?'}"
137
+ elif action.action_type == "hold":
138
+ note += f" → HOLD ({action.vehicle_type}, waiting for {action.vehicle_id or '?'})"
139
+ elif action.action_type == "dispatch":
140
+ note += f" → {action.vehicle_type} {action.vehicle_id or '?'}"
141
+ self._dispatcher_notes.append(note)
142
+ if len(self._dispatcher_notes) > 3:
143
+ self._dispatcher_notes = self._dispatcher_notes[-3:]
144
+
145
+ # ── Check done ───────────────────────────────────────────────────
146
+ done = self._state.step_count >= getattr(self, "_max_steps", MAX_STEPS)
147
+
148
+ # ── Generate next call ───────────────────────────────────────────
149
+ if not done:
150
+ self._current_call, self._event_counter = generate_call(
151
+ city, self._state.step_count + 1,
152
+ self._active_events, getattr(self, "_duplicate_prob", DUPLICATE_PROB),
153
+ self._rng, self._event_counter,
154
+ )
155
+ obs_text = self._build_observation() if not done else "Episode complete."
156
+
157
+ gt = {
158
+ "severity": call.severity,
159
+ "emergency_type": call.emergency_type,
160
+ "is_duplicate": call.is_duplicate_of is not None,
161
+ "required_vehicle_type": call.required_vehicle_type,
162
+ }
163
+ return SmartEmergencyObservation(
164
+ prompt=obs_text,
165
+ step=self._state.step_count,
166
+ call_id=call.call_id,
167
+ reward_breakdown=breakdown,
168
+ active_event_ids=list(self._active_events.keys()),
169
+ fleet_utilisation=self._fleet_util(),
170
+ done=done,
171
+ reward=breakdown.get("total", 0.0),
172
+ ground_truth=gt,
173
+ metadata={
174
+ "ground_truth": gt,
175
+ "city_seed": self._seed,
176
+ },
177
+ )
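`step()` keeps only the three most recent dispatcher notes by re-slicing the list after every append. A bounded `collections.deque` gives the same rolling window without the explicit trim. This is a sketch only: the `notes` variable stands in for `self._dispatcher_notes`, and the `CALL-0001` style IDs are an assumption about the call ID format.

```python
from collections import deque

# Hypothetical stand-in for self._dispatcher_notes; maxlen drops the oldest entry.
notes = deque(maxlen=3)
for step in range(1, 6):
    notes.append(f"Step {step}: CALL-{step:04d}")

# Only the last three notes survive, oldest first.
recent = list(notes)
```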
178
+
179
+ # ── Evaluate ─────────────────────────────────────────────────────────
180
+
181
+ def _evaluate_action(self, action: SmartEmergencyAction, call: Call) -> dict:
182
+ """Build kwargs for compute_reward."""
183
+ city = self._city
184
+ assert city is not None
185
+
186
+ gt_is_dup = call.is_duplicate_of is not None
187
+ gt_eid = call.is_duplicate_of
188
+
189
+ # Vehicle checks
190
+ v_exists = True
191
+ v_free = True
192
+ v_type_match = True
193
+ travel = 0.0
194
+ is_nearest = False
195
+
196
+ # Hold checks
197
+ hold_is_action = action.action_type == "hold"
198
+ hold_free_exists = False
199
+ hold_min_busy_sev = 0
200
+ hold_vehicle_soonest = False
201
+
202
+ if hold_is_action:
203
+ # Check if the named vehicle exists and its state
204
+ if action.vehicle_id:
205
+ veh = self._find_vehicle(action.vehicle_id)
206
+ if veh is None:
207
+ v_exists = False
208
+ else:
209
+ v_free = veh.status == "FREE"
210
+ v_type_match = veh.vehicle_type == (action.vehicle_type or "")
211
+ else:
212
+ v_exists = False
213
+
214
+ # Check if any free unit of the correct type exists
215
+ vtype = action.vehicle_type or call.required_vehicle_type
216
+ free_of_type = [
217
+ v for v in city.vehicles
218
+ if v.status == "FREE" and v.vehicle_type == vtype
219
+ ]
220
+ hold_free_exists = len(free_of_type) > 0
221
+
222
+ # Find min severity among busy units of this type
223
+ busy_of_type = [
224
+ v for v in city.vehicles
225
+ if v.status != "FREE" and v.vehicle_type == vtype
226
+ and v.assigned_event is not None
227
+ ]
228
+ if busy_of_type:
229
+ busy_sevs = []
230
+ for bv in busy_of_type:
231
+ evt = self._active_events.get(bv.assigned_event, {})
232
+ busy_sevs.append(evt.get("severity", 5))
233
+ hold_min_busy_sev = min(busy_sevs)
234
+
235
+ # Check if named vehicle is the soonest to free
236
+ if v_exists and not v_free and action.vehicle_id:
237
+ veh = self._find_vehicle(action.vehicle_id)
238
+ if veh and veh.eta is not None:
239
+ min_eta = min(
240
+ (bv.eta for bv in busy_of_type if bv.eta is not None),
241
+ default=999,
242
+ )
243
+ hold_vehicle_soonest = veh.eta <= min_eta
244
+
245
+ elif not action.is_duplicate and action.vehicle_id:
246
+ veh = self._find_vehicle(action.vehicle_id)
247
+ if veh is None:
248
+ v_exists = False
249
+ else:
250
+ v_free = veh.status == "FREE"
251
+ v_type_match = veh.vehicle_type == action.vehicle_type
252
+ if v_exists and v_free:
253
+ travel, _ = dijkstra(city, veh.current_node, call.origin_node_id)
254
+ # Check if nearest
255
+ free_same = [
256
+ v for v in city.vehicles
257
+ if v.status == "FREE" and v.vehicle_type == call.required_vehicle_type
258
+ ]
259
+ if free_same:
260
+ min_t = min(dijkstra(city, v.current_node, call.origin_node_id)[0] for v in free_same)
261
+ is_nearest = abs(travel - min_t) < 0.1
262
+
263
+ # Reroute checks
264
+ reroute_attempted = action.reroute is not None and not hold_is_action
265
+ reroute_valid = False
266
+ reroute_sev_delta = 0
267
+ reroute_faster = False
268
+ replacement_valid = None
269
+
270
+ if reroute_attempted and action.reroute is not None:
271
+ rv = self._find_vehicle(action.reroute.vehicle_to_reroute)
272
+ if rv and rv.status == "DISPATCHED" and rv.assigned_event == action.reroute.from_event_id:
273
+ reroute_valid = True
274
+ old_evt = self._active_events.get(action.reroute.from_event_id, {})
275
+ reroute_sev_delta = call.severity - old_evt.get("severity", call.severity)
276
+ if action.reroute.replacement_vehicle_id:
277
+ rep = self._find_vehicle(action.reroute.replacement_vehicle_id)
278
+ replacement_valid = (
279
+ rep is not None and rep.status == "FREE"
280
+ and rep.vehicle_type == old_evt.get("vehicle", "")
281
+ )
282
+
283
+ return dict(
284
+ gt_severity=call.severity,
285
+ gt_is_duplicate=gt_is_dup,
286
+ gt_event_id=gt_eid,
287
+ gt_vehicle_type=call.required_vehicle_type,
288
+ gt_origin_node=call.origin_node_id,
289
+ severity_pred=action.severity_pred,
290
+ is_duplicate_pred=action.is_duplicate,
291
+ duplicate_of_event_id=action.duplicate_of_event_id,
292
+ vehicle_type_pred=action.vehicle_type,
293
+ vehicle_id_pred=action.vehicle_id,
294
+ vehicle_exists=v_exists,
295
+ vehicle_is_free=v_free,
296
+ vehicle_type_matches=v_type_match,
297
+ travel_time=travel,
298
+ is_nearest=is_nearest,
299
+ reroute_attempted=reroute_attempted,
300
+ reroute_valid=reroute_valid,
301
+ reroute_severity_delta=reroute_sev_delta,
302
+ reroute_faster=reroute_faster,
303
+ replacement_valid=replacement_valid,
304
+ hold_is_action=hold_is_action,
305
+ hold_free_unit_exists=hold_free_exists,
306
+ hold_min_busy_severity=hold_min_busy_sev,
307
+ hold_vehicle_is_soonest=hold_vehicle_soonest,
308
+ )
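The `is_nearest` flag in `_evaluate_action` compares the chosen unit's travel time with the best travel time among free units of the required type, using a 0.1 minute tolerance. The sketch below illustrates that comparison on a toy graph. It assumes an adjacency dict of the form `{node: {neighbour: minutes}}`; `shortest_time`, `is_nearest` and the graph itself are illustrative and are not the commit's own `dijkstra` helper.

```python
import heapq


def shortest_time(edges, src, dst):
    """Dijkstra over {node: {neighbour: minutes}}; returns inf if unreachable."""
    best = {src: 0.0}
    heap = [(0.0, src)]
    while heap:
        dist, node = heapq.heappop(heap)
        if node == dst:
            return dist
        if dist > best.get(node, float("inf")):
            continue  # stale heap entry
        for nxt, w in edges.get(node, {}).items():
            cand = dist + w
            if cand < best.get(nxt, float("inf")):
                best[nxt] = cand
                heapq.heappush(heap, (cand, nxt))
    return float("inf")


def is_nearest(edges, chosen_node, free_nodes, target, tol=0.1):
    """True if the chosen unit ties the fastest free unit within tol minutes."""
    travel = shortest_time(edges, chosen_node, target)
    fastest = min(shortest_time(edges, n, target) for n in free_nodes)
    return abs(travel - fastest) < tol


# Toy graph with hypothetical node names and travel times in minutes.
EDGES = {
    "A": {"B": 4.0, "C": 1.0},
    "B": {"A": 4.0, "C": 2.0, "D": 5.0},
    "C": {"A": 1.0, "B": 2.0, "D": 8.0},
    "D": {"B": 5.0, "C": 8.0},
}
```

Because the check is a tolerance comparison on travel time rather than an identity check on the unit, two units that are equally far from the call both count as nearest.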
309
+
310
+ # ── Apply action to state ────────────────────────────────────────────
311
+
312
+ def _apply_action(self, action: SmartEmergencyAction, call: Call):
313
+ city = self._city
314
+ assert city is not None
315
+
316
+ if action.is_duplicate:
317
+ # Link call to existing event
318
+ eid = action.duplicate_of_event_id or call.event_id
319
+ if eid in self._active_events:
320
+ self._active_events[eid].setdefault("calls", []).append(call.call_id)
321
+ return
322
+
323
+ # Register new event (only if not already active)
324
+ eid = call.event_id
325
+ if eid not in self._active_events:
326
+ self._active_events[eid] = {
327
+ "type": call.emergency_type,
328
+ "severity": call.severity,
329
+ "vehicle": call.required_vehicle_type,
330
+ "node_id": call.origin_node_id,
331
+ "node_name": call.origin_node_name,
332
+ "assigned_unit": None,
333
+ "unit_eta": None,
334
+ "held_for_unit": None,
335
+ "step_opened": self._state.step_count,
336
+ "calls": [call.call_id],
337
+ }
338
+ else:
339
+ # Event already exists — just link this call
340
+ self._active_events[eid].setdefault("calls", []).append(call.call_id)
341
+
342
+ # ── Hold action ──────────────────────────────────────────────────
343
+ if action.action_type == "hold" and action.vehicle_id:
344
+ veh = self._find_vehicle(action.vehicle_id)
345
+ if veh is not None and veh.status != "FREE":
346
+ # Queue this event as a future destination for the vehicle
347
+ veh.destinations.append(
348
+ Destination(node_id=call.origin_node_id, event_id=eid)
349
+ )
350
+ self._active_events[eid]["held_for_unit"] = action.vehicle_id
351
+ return
352
+
353
+ # Handle reroute
354
+ if action.reroute is not None:
355
+ rv = self._find_vehicle(action.reroute.vehicle_to_reroute)
356
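+ # Same validity test as _evaluate_action: the unit must actually be serving from_event_id.
+ if rv and rv.status == "DISPATCHED" and rv.assigned_event == action.reroute.from_event_id: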
357
+ # Unassign from old event
358
+ old_eid = action.reroute.from_event_id
359
+ if old_eid in self._active_events:
360
+ self._active_events[old_eid]["assigned_unit"] = None
361
+ self._active_events[old_eid]["unit_eta"] = None
362
+ # Dispatch rerouted vehicle to new event
363
+ travel, path = dijkstra(city, rv.current_node, call.origin_node_id)
364
+ rv.status = "DISPATCHED"
365
+ rv.assigned_event = eid
366
+ rv.eta = max(1, int(travel))
367
+ rv.path = path
368
+ self._active_events[eid]["assigned_unit"] = rv.unit_id
369
+ self._active_events[eid]["unit_eta"] = rv.eta
370
+ # Handle replacement
371
+ if action.reroute.replacement_vehicle_id:
372
+ rep = self._find_vehicle(action.reroute.replacement_vehicle_id)
373
+ if rep and rep.status == "FREE" and old_eid in self._active_events:
374
+ old_node = self._active_events[old_eid]["node_id"]
375
+ t, p = dijkstra(city, rep.current_node, old_node)
376
+ rep.status = "DISPATCHED"
377
+ rep.assigned_event = old_eid
378
+ rep.eta = max(1, int(t))
379
+ rep.path = p
380
+ self._active_events[old_eid]["assigned_unit"] = rep.unit_id
381
+ self._active_events[old_eid]["unit_eta"] = rep.eta
382
+ return
383
+
384
+ # Normal dispatch
385
+ if action.vehicle_id:
386
+ veh = self._find_vehicle(action.vehicle_id)
387
+ if veh is None:
388
+ # Hallucinated vehicle — event stays UNASSIGNED
389
+ return
390
+ if veh.status != "FREE":
391
+ # Vehicle is busy — auto-convert to hold.
392
+ # Penalty still applied in reward, but the event gets queued.
393
+ veh.destinations.append(
394
+ Destination(node_id=call.origin_node_id, event_id=eid)
395
+ )
396
+ self._active_events[eid]["held_for_unit"] = veh.unit_id
397
+ return
398
+ # Vehicle is free — dispatch it
399
+ travel, path = dijkstra(city, veh.current_node, call.origin_node_id)
400
+ veh.status = "DISPATCHED"
401
+ veh.assigned_event = eid
402
+ veh.eta = max(1, int(travel))
403
+ veh.path = path
404
+ self._active_events[eid]["assigned_unit"] = veh.unit_id
405
+ self._active_events[eid]["unit_eta"] = veh.eta
406
+
407
+ # ── Vehicle tick ─────────────────────────────────────────────────────
408
+
409
+ def _tick_vehicles(self):
410
+ city = self._city
411
+ assert city is not None
412
+ resolved = []
413
+
414
+ for v in city.vehicles:
415
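+             if v.status == "DISPATCHED":
+                 v.eta -= 1
+                 # Keep the event's ETA in sync so the observation does not show a stale value.
+                 evt = self._active_events.get(v.assigned_event)
+                 if evt is not None and evt.get("assigned_unit") == v.unit_id:
+                     evt["unit_eta"] = max(v.eta, 0)
+                 if v.eta <= 0:
+                     v.status = "ON_SCENE"
+                     v.on_scene_remaining = ON_SCENE_STEPS
+                     if v.path:
+                         v.current_node = v.path[-1]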
422
+ elif v.status == "ON_SCENE":
423
+ v.on_scene_remaining -= 1
424
+ if v.on_scene_remaining <= 0:
425
+ v.status = "RETURNING"
426
+ v.return_remaining = RETURN_STEPS
427
+ # Mark event resolved
428
+ if v.assigned_event and v.assigned_event in self._active_events:
429
+ resolved.append(v.assigned_event)
430
+ elif v.status == "RETURNING":
431
+ v.return_remaining -= 1
432
+ if v.return_remaining <= 0:
433
+ v.status = "FREE"
434
+ v.current_node = v.home_node
435
+ v.assigned_event = None
436
+ # Auto-dispatch to next queued destination (from hold)
437
+ self._dispatch_next_destination(v)
438
+
439
+ for eid in resolved:
440
+ self._active_events.pop(eid, None)
441
+
442
+ # Clean up stale unassigned events (no unit, no hold, open > 3 steps)
443
+ stale = []
444
+ for eid, evt in self._active_events.items():
445
+ if (evt.get("assigned_unit") is None
446
+ and evt.get("held_for_unit") is None
447
+ and self._state.step_count - evt.get("step_opened", 0) > 3):
448
+ stale.append(eid)
449
+ for eid in stale:
450
+ self._active_events.pop(eid, None)
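`_tick_vehicles` moves each unit through a fixed lifecycle: `DISPATCHED` → `ON_SCENE` → `RETURNING` → `FREE`. The sketch below isolates that state machine for a single unit. The `Unit` dataclass and `tick` function are hypothetical, and the two constants are assumed values, since the real `ON_SCENE_STEPS` and `RETURN_STEPS` are defined outside this section.

```python
from dataclasses import dataclass

# Assumed values; the real ON_SCENE_STEPS and RETURN_STEPS constants are
# defined elsewhere in the commit and are not shown in this section.
ON_SCENE_STEPS = 2
RETURN_STEPS = 1


@dataclass
class Unit:
    status: str = "FREE"
    eta: int = 0
    on_scene_remaining: int = 0
    return_remaining: int = 0


def tick(u: Unit) -> None:
    """Advance one unit by one step through the dispatch lifecycle."""
    if u.status == "DISPATCHED":
        u.eta -= 1
        if u.eta <= 0:
            u.status, u.on_scene_remaining = "ON_SCENE", ON_SCENE_STEPS
    elif u.status == "ON_SCENE":
        u.on_scene_remaining -= 1
        if u.on_scene_remaining <= 0:
            u.status, u.return_remaining = "RETURNING", RETURN_STEPS
    elif u.status == "RETURNING":
        u.return_remaining -= 1
        if u.return_remaining <= 0:
            u.status = "FREE"


unit = Unit(status="DISPATCHED", eta=2)
trace = []
for _ in range(5):
    tick(unit)
    trace.append(unit.status)
```

With these assumed constants, a unit dispatched with an ETA of 2 is unavailable for five steps in total, which is the cost an agent pays for every dispatch.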
451
+
452
+ def _dispatch_next_destination(self, v: Vehicle):
453
+ """If the vehicle has queued destinations, pop the first and dispatch."""
454
+ city = self._city
455
+ assert city is not None
456
+
457
+ while v.destinations:
458
+ dest = v.destinations.pop(0)
459
+ # Only dispatch if the event is still active and unassigned
460
+ evt = self._active_events.get(dest.event_id)
461
+ if evt is not None and evt.get("assigned_unit") is None:
462
+ travel, path = dijkstra(city, v.current_node, dest.node_id)
463
+ v.status = "DISPATCHED"
464
+ v.assigned_event = dest.event_id
465
+ v.eta = max(1, int(travel))
466
+ v.path = path
467
+ evt["assigned_unit"] = v.unit_id
468
+ evt["unit_eta"] = v.eta
469
+ return
470
+ # No valid destinations left — vehicle stays FREE
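`_dispatch_next_destination` drains a unit's hold queue until it finds an event that is still active and has no unit assigned. The sketch below shows only that selection rule. It swaps the source's `list.pop(0)` for a `collections.deque`, and the `next_live_destination` helper and the event IDs are hypothetical.

```python
from collections import deque


def next_live_destination(queue, active_events):
    """Pop queued event IDs until one is still active and has no unit."""
    while queue:
        event_id = queue.popleft()
        evt = active_events.get(event_id)
        if evt is not None and evt.get("assigned_unit") is None:
            return event_id
    return None


# Hypothetical event IDs: EVT-0001 was resolved, EVT-0002 already has a unit.
active = {
    "EVT-0002": {"assigned_unit": "ambulance_1"},
    "EVT-0003": {"assigned_unit": None},
}
queue = deque(["EVT-0001", "EVT-0002", "EVT-0003", "EVT-0004"])
picked = next_live_destination(queue, active)
```

Entries that were skipped are discarded rather than re-queued, which matches the source: a held event that was resolved or picked up by another unit never comes back.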
471
+
472
+ # ── Observation builder ──────────────────────────────────────────────
473
+
474
+ def _build_observation(self) -> str:
475
+ call = self._current_call
476
+ city = self._city
477
+ if call is None or city is None:
478
+ return ""
479
+
480
+ parts = []
481
+
482
+ # 1. Incoming call
483
+ parts.append(f"=== INCOMING CALL [{call.call_id}] ===")
484
+ parts.append(call.transcript)
485
+ parts.append("")
486
+
487
+ # 2. Active events
488
+ parts.append("=== ACTIVE EVENTS ===")
489
+ if self._active_events:
490
+ for eid, evt in self._active_events.items():
491
+ unit = evt.get("assigned_unit")
492
+ held = evt.get("held_for_unit")
493
+ eta = evt.get("unit_eta")
494
+ if unit:
495
+ eta_str = f"ETA {eta} min" if eta else "ON SCENE"
496
+ status_str = f"{unit} {eta_str}"
497
+ elif held:
498
+ status_str = f"HELD → {held}"
499
+ else:
500
+ status_str = "UNASSIGNED"
501
+ sev = evt.get("severity", "?")
502
+ parts.append(
503
+ f"{eid} | {evt['type']:10s} | {evt['node_name']:30s} | "
504
+ f"sev {sev} | {status_str} | opened step {evt['step_opened']}"
505
+ )
506
+ else:
507
+ parts.append("(none)")
508
+ parts.append("")
509
+
510
+ # 3. Unit status
511
+ parts.append("=== UNIT STATUS ===")
512
+ for v in city.vehicles:
513
+ loc = city.nodes[v.current_node].name if v.current_node in city.nodes else v.current_node
514
+ status = v.status
515
+ if v.assigned_event:
516
+ status += f" → {v.assigned_event}"
517
+ parts.append(f"{v.unit_id:15s} | {v.vehicle_type:10s} | {loc:30s} | {status}")
518
+ parts.append("")
519
+
520
+ # 4. City reference (compact adjacency)
521
+ parts.append("=== CITY REFERENCE ===")
522
+ for nid, node in city.nodes.items():
523
+ neighbours = []
524
+ for oid, w in city.edges.get(nid, {}).items():
525
+ oname = city.nodes[oid].name
526
+ neighbours.append(f"{oname} [{w:.0f} min]")
527
+ parts.append(f"{node.name} ({node.node_type}) → {', '.join(neighbours)}")
528
+ parts.append("")
529
+
530
+ # 5. Dispatcher notes
531
+ parts.append("=== DISPATCHER NOTES ===")
532
+ if self._dispatcher_notes:
533
+ for n in self._dispatcher_notes:
534
+ parts.append(n)
535
+ else:
536
+ parts.append("(first call)")
537
+ parts.append("")
538
+
539
+ return "\n".join(parts)
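The observation is plain text, so the training script later recovers unit data by splitting rows on `|`. The sketch below shows that round trip for one UNIT STATUS row, using the same field widths as `_build_observation`. The two helper names, the unit ID and the location name are illustrative assumptions.

```python
def format_unit_row(unit_id, vehicle_type, location, status):
    """Render one UNIT STATUS row with the same field widths as the source."""
    return f"{unit_id:15s} | {vehicle_type:10s} | {location:30s} | {status}"


def parse_unit_row(row):
    """Split a row on '|' and strip the padding from each field."""
    return [field.strip() for field in row.split("|")]


# Hypothetical unit ID and location name.
row = format_unit_row(
    "ambulance_1", "ambulance", "Central Hospital", "DISPATCHED → EVT-0003"
)
fields = parse_unit_row(row)
bare_status = fields[3].split()[0]
```

The round trip only holds while node names contain no `|` character; a structured field on the observation would avoid that coupling.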
540
+
541
+ # ── Helpers ──────────────────────────────────────────────────────────
542
+
543
+ def _find_vehicle(self, unit_id: str) -> Optional[Vehicle]:
544
+ if self._city is None:
545
+ return None
546
+ for v in self._city.vehicles:
547
+ if v.unit_id == unit_id:
548
+ return v
549
+ return None
550
+
551
+ def _fleet_util(self) -> float:
552
+ if self._city is None or not self._city.vehicles:
553
+ return 0.0
554
+ busy = sum(1 for v in self._city.vehicles if v.status != "FREE")
555
+ return busy / len(self._city.vehicles)
556
+
557
+ @property
558
+ def state(self) -> State:
559
+ return self._state
train_sft_grpo.py ADDED
@@ -0,0 +1,661 @@
1
+ # %% [markdown]
2
+ # # 🚨 Smart Emergency Dispatch — SFT → GRPO Training (Colab + Unsloth)
3
+ #
4
+ # Fine-tunes **Qwen3-1.7B** as an emergency 911 dispatcher using **Unsloth** for 2× faster training:
5
+ # 1. **Phase 1 — SFT**: Teach the model the JSON output format
6
+ # 2. **Phase 2 — GRPO**: Improve dispatch strategy via RL against the live HF Space environment
7
+ #
8
+ # **Runtime**: Google Colab with T4 or A100 GPU
9
+
10
+ # %% [markdown]
11
+ # ## 0 · Install Dependencies
12
+
13
+ # %%
14
+ !pip install -Uq unsloth vllm
15
+ !pip install -Uq git+https://github.com/huggingface/trl.git
16
+ !pip install -Uq git+https://github.com/meta-pytorch/OpenEnv.git
17
+ !pip install -Uq git+https://github.com/rishiraj38/Smart_Emergency.git datasets requests
18
+
19
+ # %%
20
+ from huggingface_hub import notebook_login
21
+ notebook_login()
22
+
23
+ # %% [markdown]
24
+ # ## 1 · Configuration
25
+
26
+ # %%
27
+ import os, json, re, random, requests, time
28
+ from collections import defaultdict
29
+
30
+ MODEL_NAME = "unsloth/Qwen3-1.7B-unsloth-bnb-4bit"
31
+ SFT_OUTPUT_DIR = "smart-emergency-sft"
32
+ GRPO_OUTPUT_DIR = "smart-emergency-grpo"
33
+ MAX_SEQ_LENGTH = 3072
34
+
35
+ # HuggingFace Space URL for the environment server
36
+ HF_SPACE_URL = "https://rishi38-eme-enviro.hf.space"
37
+
38
+ # %% [markdown]
39
+ # ## 2 · Connect to Environment
40
+ #
41
+ # Wake the HF Space if sleeping, then connect directly using `SmartEmergencyEnv`.
42
+
43
+ # %%
44
+ import requests, time
45
+ from smart_emergency import SmartEmergencyEnv, SmartEmergencyAction
46
+
47
+ # Ping the Space health endpoint until it wakes up (free Spaces sleep after inactivity)
48
+ print("⏳ Waking up HF Space (may take 30-60s if sleeping) …")
49
+ for _attempt in range(60):
50
+ try:
51
+ r = requests.get(f"{HF_SPACE_URL}/health", timeout=5)
52
+ if r.status_code == 200:
53
+ print(f"✅ Space awake at {HF_SPACE_URL}")
54
+ break
55
+ except Exception:
56
+ pass
57
+ time.sleep(2)
58
+ else:
59
+ raise RuntimeError("HF Space did not respond after 60 attempts. Check the URL.")
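The wake-up loop relies on Python's `for … else`: the `else` branch runs only when the loop finishes without hitting `break`. With a 5 s request timeout and a 2 s sleep, 60 attempts take between 120 s (every request fails at once) and 420 s (every request times out). The sketch below isolates the pattern so it can be exercised without a network. `wait_until_ready` and its `probe` and `sleep` parameters are hypothetical names, not part of the notebook.

```python
import time


def wait_until_ready(probe, attempts=60, delay=2.0, sleep=time.sleep):
    """Call probe() until it returns True; raise if every attempt fails."""
    for attempt in range(1, attempts + 1):
        try:
            if probe():
                break
        except Exception:
            pass  # treat connection errors as "not ready yet"
        sleep(delay)
    else:
        # Runs only when the loop finished without hitting `break`.
        raise RuntimeError(f"service not ready after {attempts} attempts")
    return attempt
```

Passing `sleep` as a parameter lets a test replace the real delay with a no-op, for example `wait_until_ready(probe, sleep=lambda s: None)`.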
60
+
61
+ # Direct WebSocket connection via the official client
62
+ env = SmartEmergencyEnv(base_url=HF_SPACE_URL).sync()
63
+ _test = env.reset()
64
+ print(f"✅ Connected — first call: {_test.observation.call_id}")
65
+
66
+ # %% [markdown]
67
+ # ## 3 · System Prompt
68
+
69
+ # %%
70
+ SYSTEM_PROMPT = """\
71
+ You are an expert 911 emergency dispatcher. You receive incoming calls and must make rapid, structured dispatch decisions.
72
+
73
+ ## RULES
74
+ 1. Each step you see: an incoming call transcript, active events, unit status, and a city map.
75
+ 2. You must respond with a single JSON object — nothing else.
76
+
77
+ ## ACTION TYPES
78
+ You have three action types: `dispatch`, `duplicate`, and `hold`.
79
+
80
+ ### 1. dispatch — Handle a new emergency
81
+ Use when a FREE vehicle of the correct type is available.
82
+ ```json
83
+ {
84
+ "action_type": "dispatch",
85
+ "severity_pred": <int 1-5>,
86
+ "is_duplicate": false,
87
+ "duplicate_of_event_id": null,
88
+ "vehicle_type": "police" | "ambulance" | "fire",
89
+ "vehicle_id": "<unit_id of a FREE vehicle>",
90
+ "reroute": null
91
+ }
92
+ ```
93
+
94
+ ### 2. duplicate — Flag a repeat call
95
+ Use when the incoming call matches an existing active event (same location/type).
96
+ ```json
97
+ {
98
+ "action_type": "duplicate",
99
+ "severity_pred": <int 1-5>,
100
+ "is_duplicate": true,
101
+ "duplicate_of_event_id": "<EVT-NNNN>",
102
+ "vehicle_type": null,
103
+ "vehicle_id": null,
104
+ "reroute": null
105
+ }
106
+ ```
107
+
108
+ ### 3. hold — Queue for a busy vehicle
109
+ Use ONLY when ALL vehicles of the required type are busy (none are FREE).
110
+ ```json
111
+ {
112
+ "action_type": "hold",
113
+ "severity_pred": <int 1-5>,
114
+ "is_duplicate": false,
115
+ "duplicate_of_event_id": null,
116
+ "vehicle_type": "police" | "ambulance" | "fire",
117
+ "vehicle_id": "<unit_id of a BUSY vehicle to queue behind>",
118
+ "reroute": null
119
+ }
120
+ ```
121
+ **Hold rules:** NEVER hold if a free unit exists. Pick the vehicle with the lowest ETA.
122
+
123
+ ## REROUTE (optional, only with dispatch)
124
+ Redirect an in-flight vehicle from a LOWER-severity event to this HIGHER-severity one:
125
+ ```json
126
+ "reroute": {
127
+ "vehicle_to_reroute": "<DISPATCHED unit_id>",
128
+ "from_event_id": "<EVT-NNNN>",
129
+ "replacement_vehicle_id": "<FREE unit or null>"
130
+ }
131
+ ```
132
+ Only reroute DISPATCHED vehicles. Only reroute from lower to higher severity.
133
+
134
+ ## SEVERITY GUIDE
135
+ 1=minor, 2=moderate, 3=serious, 4=critical, 5=catastrophic
136
+
137
+ ## VEHICLE GUIDE
138
+ - **fire** → fire, smoke, flames, gas leak
139
+ - **police** → shooting, robbery, fight, break-in
140
+ - **ambulance** → medical, crash, accident, injury, collapse
141
+
142
+ ## STRATEGY
143
+ - Pick the nearest FREE vehicle (use CITY REFERENCE distances).
144
+ - If call matches an ACTIVE EVENT, flag as duplicate.
145
+ - No free units → use `hold`. Higher severity than busy units → consider `reroute`.
146
+ """
147
+
148
+ # %% [markdown]
149
+ # ---
150
+ # # Phase 1 — Supervised Fine-Tuning (SFT)
151
+
152
+ # %% [markdown]
153
+ # ### Observation Parsing Helpers
154
+
155
+ # %%
156
+ def parse_free_vehicles(obs_text: str) -> dict:
157
+ """Return {unit_id: vehicle_type} for FREE vehicles."""
158
+ vehicles = {}
159
+ in_section = False
160
+ for line in obs_text.split("\n"):
161
+ if "=== UNIT STATUS ===" in line:
162
+ in_section = True; continue
163
+ if in_section and line.startswith("==="):
164
+ break
165
+ if in_section and "|" in line and "FREE" in line:
166
+ parts = [p.strip() for p in line.split("|")]
167
+ if len(parts) >= 2:
168
+ vehicles[parts[0]] = parts[1]
169
+ return vehicles
170
+
171
+
172
+ def parse_all_vehicles(obs_text: str) -> list:
173
+ """Return list of {id, type, status} for ALL vehicles."""
174
+ vehicles = []
175
+ in_section = False
176
+ for line in obs_text.split("\n"):
177
+ if "=== UNIT STATUS ===" in line:
178
+ in_section = True; continue
179
+ if in_section and line.startswith("==="):
180
+ break
181
+ if in_section and "|" in line:
182
+ parts = [p.strip() for p in line.split("|")]
183
+ if len(parts) >= 4:
184
+ status = parts[3].split()[0] if parts[3] else "UNKNOWN"
185
+ vehicles.append({"id": parts[0], "type": parts[1], "status": status})
186
+ return vehicles
187
+
188
+
189
+ def parse_active_events(obs_text: str) -> dict:
190
+ events = {}
191
+ in_section = False
192
+ for line in obs_text.split("\n"):
193
+ if "=== ACTIVE EVENTS ===" in line:
194
+ in_section = True; continue
195
+ if in_section and line.startswith("==="):
196
+ break
197
+ if in_section and "|" in line and "EVT-" in line:
198
+ parts = [p.strip() for p in line.split("|")]
199
+ if len(parts) >= 2:
200
+ events[parts[0]] = parts[1]
201
+ return events
202
+
203
+
204
+ TYPE_TO_VEHICLE = {"fire": "fire", "medical": "ambulance", "crime": "police", "accident": "ambulance"}
205
+
206
+ SEV_KW = {
207
+ 5: ["not breathing", "active shooter", "trapped", "mass incident", "whole block", "pileup", "send everything"],
208
+ 4: ["won't wake", "gunshots", "flipped", "blood everywhere", "kids are upstairs", "not responding"],
209
+ 3: ["chest pain", "fight", "mugged", "knife", "crash", "bleeding", "fire at", "flames", "cyclist"],
210
+ 2: ["fainted", "break-in", "dumpster", "fender", "small fire", "ankle", "shoplifter"],
211
+ }
212
+
213
+
214
+ def heuristic_severity(text):
215
+ t = text.lower()
216
+ for sev in [5, 4, 3, 2]:
217
+ if any(kw in t for kw in SEV_KW[sev]):
218
+ return sev
219
+ return 1
220
+
221
+
222
+ def heuristic_vehicle_type(text):
223
+ t = text.lower()
224
+ if any(w in t for w in ["fire", "flames", "smoke", "burning", "gas leak"]):
225
+ return "fire"
226
+ if any(w in t for w in ["shooter", "gunshot", "mugged", "knife", "break-in", "fight", "shoplifter"]):
227
+ return "police"
228
+ return "ambulance"
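Both heuristics are plain substring matches, so the text passed to them matters: a full observation also contains the UNIT STATUS table, where a vehicle type such as `fire` appears on every fire unit's row. The sketch below isolates the transcript before any keyword matching. `extract_transcript` is a hypothetical helper, and the sample observation only imitates the layout produced by `_build_observation`.

```python
def extract_transcript(obs_text):
    """Return only the lines between the INCOMING CALL header and the next header."""
    lines = []
    in_call = False
    for line in obs_text.split("\n"):
        if line.startswith("=== INCOMING CALL"):
            in_call = True
            continue
        if in_call and line.startswith("==="):
            break
        if in_call:
            lines.append(line)
    return "\n".join(lines).strip()


# Hypothetical observation in the layout produced by _build_observation().
OBS = "\n".join([
    "=== INCOMING CALL [CALL-0007] ===",
    "Someone just got mugged outside the station!",
    "",
    "=== ACTIVE EVENTS ===",
    "(none)",
    "",
    "=== UNIT STATUS ===",
    "fire_1          | fire       | Fire Station 1                 | FREE",
])
transcript = extract_transcript(OBS)
```

On this sample, a check for `"fire"` matches the full observation but not the transcript, so only the transcript leads to the `police` classification that the word "mugged" should trigger.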
229
+
230
+
231
+ def pick_free(free_vehicles, vtype):
232
+ for vid, vt in free_vehicles.items():
233
+ if vt == vtype:
234
+ return vid
235
+ return None
236
+
237
+
238
+ def pick_busy(all_vehicles, vtype):
239
+ for v in all_vehicles:
240
+ if v["type"] == vtype and v["status"] != "FREE":
241
+ return v["id"]
242
+ return None
243
+
244
+ # %% [markdown]
245
+ # ### Generate SFT Dataset
246
+
247
+ # %%
248
+ def build_ideal_action(gt, obs_text):
249
+ """Build ideal JSON action dict from ground truth + observation."""
250
+ sev = gt.get("severity", 1)
251
+ vtype = gt.get("required_vehicle_type", "ambulance")
252
+ is_dup = gt.get("is_duplicate", False)
253
+
254
+ if is_dup:
255
+ active = parse_active_events(obs_text)
256
+ etype = gt.get("emergency_type", "")
257
+ dup_eid = None
258
+ for eid, et in active.items():
259
+ if et.strip() == etype:
260
+ dup_eid = eid; break
261
+ if dup_eid is None and active:
262
+ dup_eid = list(active.keys())[0]
263
+ return {"action_type": "duplicate", "severity_pred": sev, "is_duplicate": True,
264
+ "duplicate_of_event_id": dup_eid, "vehicle_type": None, "vehicle_id": None, "reroute": None}
265
+
266
+ free = parse_free_vehicles(obs_text)
267
+ vid = pick_free(free, vtype)
268
+ if vid:
269
+ return {"action_type": "dispatch", "severity_pred": sev, "is_duplicate": False,
270
+ "duplicate_of_event_id": None, "vehicle_type": vtype, "vehicle_id": vid, "reroute": None}
271
+
272
+ busy_vid = pick_busy(parse_all_vehicles(obs_text), vtype)
273
+ if busy_vid:
274
+ return {"action_type": "hold", "severity_pred": sev, "is_duplicate": False,
275
+ "duplicate_of_event_id": None, "vehicle_type": vtype, "vehicle_id": busy_vid, "reroute": None}
276
+
277
+ return {"action_type": "dispatch", "severity_pred": sev, "is_duplicate": False,
278
+ "duplicate_of_event_id": None, "vehicle_type": vtype, "vehicle_id": f"{vtype}_0", "reroute": None}
279
+
280
+
281
+ def generate_sft_data(env, num_episodes=60):
282
+ examples = []
283
+ for ep in range(num_episodes):
284
+ task_id = (ep % 3) + 1
285
+ result = env.reset(task_id=task_id)
286
+ prev_obs = result.observation.prompt
287
+
288
+ while not result.done:
289
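+             # Keyword heuristics must see only the call transcript: the full
+             # observation also lists unit types such as "fire" under UNIT STATUS.
+             transcript = prev_obs.split("=== ACTIVE EVENTS ===")[0]
+             free = parse_free_vehicles(prev_obs)
+             vtype = heuristic_vehicle_type(transcript)
+             vid = pick_free(free, vtype)
+             action = SmartEmergencyAction(
+                 action_type="dispatch",
+                 severity_pred=heuristic_severity(transcript),
+                 is_duplicate=False,
+                 vehicle_type=vtype,
+                 vehicle_id=vid,
+             )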
299
+
300
+ result = env.step(action)
301
+ # ground_truth is now a first-class field on the observation;
302
+ # fall back to metadata for backward compatibility with older servers.
303
+ gt = result.observation.ground_truth or result.observation.metadata.get("ground_truth")
304
+ if gt:
305
+ ideal = build_ideal_action(gt, prev_obs)
306
+ examples.append({
307
+ "messages": [
308
+ {"role": "system", "content": SYSTEM_PROMPT},
309
+ {"role": "user", "content": prev_obs},
310
+ {"role": "assistant", "content": json.dumps(ideal)},
311
+ ]
312
+ })
313
+ prev_obs = result.observation.prompt
314
+
315
+ if (ep + 1) % 10 == 0:
316
+ print(f" Episodes: {ep+1}/{num_episodes} | examples: {len(examples)}")
317
+ return examples
318
+
319
+
320
+ print("📝 Generating SFT data …")
321
+ sft_examples = generate_sft_data(env, num_episodes=60)
322
+ print(f"✅ Collected {len(sft_examples)} SFT examples")
323
+
324
+ # %%
325
+ from datasets import Dataset
326
+ sft_dataset = Dataset.from_list(sft_examples)
327
+ print(sft_dataset)
328
+
329
+ # %% [markdown]
330
+ # ### SFT Training with Unsloth
331
+
332
+ # %%
333
+ from unsloth import FastLanguageModel
334
+ from trl import SFTTrainer, SFTConfig
335
+
336
+ model, tokenizer = FastLanguageModel.from_pretrained(
337
+ model_name=MODEL_NAME,
338
+ max_seq_length=MAX_SEQ_LENGTH,
339
+ load_in_4bit=True,
340
+ )
341
+
342
+ model = FastLanguageModel.get_peft_model(
343
+ model,
344
+ r=16,
345
+ lora_alpha=32,
346
+ lora_dropout=0.05,
347
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
348
+ "gate_proj", "up_proj", "down_proj"],
349
+ use_gradient_checkpointing="unsloth",
350
+ )
351
+
352
+ from unsloth import is_bfloat16_supported
+
+ sft_config = SFTConfig(
+     output_dir=SFT_OUTPUT_DIR,
+     num_train_epochs=3,
+     per_device_train_batch_size=2,
+     gradient_accumulation_steps=8,
+     learning_rate=2e-4,
+     lr_scheduler_type="cosine",
+     warmup_ratio=0.1,
+     logging_steps=5,
+     save_steps=50,
+     max_seq_length=MAX_SEQ_LENGTH,
+     # A T4 has no bf16 support, so choose the precision at runtime.
+     bf16=is_bfloat16_supported(),
+     fp16=not is_bfloat16_supported(),
+     report_to="none",
+ )
366
+
367
+ sft_trainer = SFTTrainer(
368
+ model=model,
369
+ processing_class=tokenizer,
370
+ train_dataset=sft_dataset,
371
+ args=sft_config,
372
+ )
373
+
374
+ # %%
375
+ print("🏋️ Starting SFT training …")
376
+ sft_trainer.train()
377
+ print("✅ SFT complete")
378
+
379
+ # %%
380
+ sft_trainer.save_model(SFT_OUTPUT_DIR)
381
+ tokenizer.save_pretrained(SFT_OUTPUT_DIR)
382
+ print(f"✅ SFT model saved to {SFT_OUTPUT_DIR}/")
383
+
384
+ # Free memory
385
+ import torch, gc
386
+ del model, sft_trainer
387
+ gc.collect()
388
+ torch.cuda.empty_cache()
389
+
390
+ # %% [markdown]
391
+ # ---
392
+ # # Phase 2 — GRPO with Unsloth
393
+
394
+ # %% [markdown]
395
+ # ### Action Parsing
396
+
397
+ # %%
398
+ def parse_llm_action(text):
399
+ """Extract action dict from LLM output."""
400
+ m = re.search(r"```json\s*(.*?)```", text, re.DOTALL)
401
+ if m:
402
+ text = m.group(1)
403
+ else:
404
+ m = re.search(r"\{.*\}", text, re.DOTALL)
405
+ if m:
406
+ text = m.group(0)
407
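+     try:
+         d = json.loads(text)
+     except (json.JSONDecodeError, TypeError):
+         return None
+     # Validate required fields explicitly; `assert` is stripped under `python -O`.
+     if not isinstance(d, dict):
+         return None
+     if d.get("action_type") not in ("dispatch", "duplicate", "hold"):
+         return None
+     try:
+         sev = int(d.get("severity_pred", 0))
+     except (TypeError, ValueError):
+         return None
+     return d if 1 <= sev <= 5 else None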
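The greedy pattern `\{.*\}` in `parse_llm_action` captures everything from the first `{` to the last `}`, so any braces outside the JSON object make the capture unparseable. Qwen3 chat models can also emit a `<think>…</think>` block before the answer, and that block may itself contain braces. The sketch below is one possible hardening, not the notebook's method: it strips think blocks, then uses the standard library's `json.JSONDecoder.raw_decode` to decode the first valid object. `extract_json_object` is a hypothetical name.

```python
import json
import re

THINK_RE = re.compile(r"<think>.*?</think>", re.DOTALL)


def extract_json_object(text):
    """Drop <think> blocks, then decode the first JSON object found."""
    text = THINK_RE.sub("", text)
    decoder = json.JSONDecoder()
    for match in re.finditer(r"\{", text):
        try:
            # raw_decode parses one JSON value starting at the given index
            # and ignores whatever follows it.
            obj, _ = decoder.raw_decode(text, match.start())
        except json.JSONDecodeError:
            continue
        if isinstance(obj, dict):
            return obj
    return None


raw = (
    '<think>maybe {"x": 1}?</think>\n'
    'Here: {"action_type": "hold", "severity_pred": 3} done {oops}'
)
result = extract_json_object(raw)
```

The field validation from `parse_llm_action` would still run on the returned dict; this sketch only replaces the extraction step.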
415
+
416
+
417
+ def fallback_action(obs_text):
+     """Heuristic action used when the LLM output cannot be parsed."""
+     # Match keywords against the call transcript only, not the unit or map listings.
+     transcript = obs_text.split("=== ACTIVE EVENTS ===")[0]
+     free = parse_free_vehicles(obs_text)
+     vtype = heuristic_vehicle_type(transcript)
+     sev = heuristic_severity(transcript)
+     vid = pick_free(free, vtype)
+     if vid:
+         return {"action_type": "dispatch", "severity_pred": sev,
+                 "is_duplicate": False, "vehicle_type": vtype, "vehicle_id": vid}
+     busy_vid = pick_busy(parse_all_vehicles(obs_text), vtype)
+     return {"action_type": "hold" if busy_vid else "dispatch",
+             "severity_pred": sev, "is_duplicate": False,
+             "vehicle_type": vtype, "vehicle_id": busy_vid or f"{vtype}_0"}
428
+
429
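The parsing cell above first looks for a fenced json block, then falls back to the outermost brace-delimited span, and rejects anything whose `action_type` or `severity_pred` is out of range. The standalone sketch below mirrors that logic so it can be tried without the rest of the script. The helper name `extract_action` and the sample completions are made up for illustration and are not part of the commit.

```python
import json
import re

VALID_ACTIONS = ("dispatch", "duplicate", "hold")
FENCE = "`" * 3  # built at runtime so this example contains no literal code fence


def extract_action(text):
    """Standalone mirror of the parsing logic above (hypothetical helper name)."""
    fenced = re.search(FENCE + r"json\s*(.*?)" + FENCE, text, re.DOTALL)
    if fenced:
        candidate = fenced.group(1)
    else:
        bare = re.search(r"\{.*\}", text, re.DOTALL)
        candidate = bare.group(0) if bare else text
    try:
        d = json.loads(candidate)
        if d.get("action_type") not in VALID_ACTIONS:
            return None
        if not 1 <= int(d.get("severity_pred", 0)) <= 5:
            return None
        return d
    except Exception:
        # Malformed JSON, non-dict payloads, and non-numeric severities all end up here
        return None


# Made-up completions, for illustration only
fenced_reply = "Sure.\n" + FENCE + 'json\n{"action_type": "dispatch", "severity_pred": 4}\n' + FENCE
bare_reply = 'Action: {"action_type": "hold", "severity_pred": 2} done'
bad_severity = '{"action_type": "dispatch", "severity_pred": 9}'
not_json = "I would send an ambulance."

for reply in (fenced_reply, bare_reply, bad_severity, not_json):
    print(extract_action(reply))
```

The last two samples return `None`, which is the case where the script substitutes `fallback_action` and the rollout assigns the negative format reward.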
+ # %% [markdown]
+ # ### Rollout Functions
+
+ # %%
+ from unsloth import FastLanguageModel, PatchFastRL
+ from trl import GRPOConfig, GRPOTrainer
+
+ # Patch TRL for Unsloth compatibility
+ PatchFastRL("GRPO", FastLanguageModel)
+
+ # Load the SFT model for GRPO with fast inference (vLLM)
+ grpo_model, grpo_tokenizer = FastLanguageModel.from_pretrained(
+     model_name=SFT_OUTPUT_DIR,
+     max_seq_length=MAX_SEQ_LENGTH,
+     load_in_4bit=True,
+     fast_inference=True,  # enables vLLM for GRPO generation
+ )
+
+ grpo_model = FastLanguageModel.get_peft_model(
+     grpo_model,
+     r=16,
+     lora_alpha=32,
+     lora_dropout=0.05,
+     target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
+                     "gate_proj", "up_proj", "down_proj"],
+     use_gradient_checkpointing="unsloth",
+ )
+
+ # %%
+ def make_user_prompt(obs_text):
+     return f"You are the dispatcher. Read the situation and respond with a single JSON action.\n\n{obs_text}\n\nRespond ONLY with a JSON object."
+
+
+ def action_dict_to_obj(d):
+     """Convert a plain dict action to SmartEmergencyAction."""
+     from smart_emergency import RerouteAction
+     reroute = None
+     if d.get("reroute") and isinstance(d["reroute"], dict):
+         rd = d["reroute"]
+         reroute = RerouteAction(
+             vehicle_to_reroute=rd["vehicle_to_reroute"],
+             from_event_id=rd["from_event_id"],
+             replacement_vehicle_id=rd.get("replacement_vehicle_id"),
+         )
+     return SmartEmergencyAction(
+         action_type=d.get("action_type", "dispatch"),
+         severity_pred=int(d.get("severity_pred", 1)),
+         is_duplicate=bool(d.get("is_duplicate", False)),
+         duplicate_of_event_id=d.get("duplicate_of_event_id"),
+         vehicle_type=d.get("vehicle_type"),
+         vehicle_id=d.get("vehicle_id"),
+         reroute=reroute,
+     )
+
+
+ def rollout_once(trainer, env, tokenizer, system_prompt, max_turns=15):
+     """Run one full episode."""
+     from trl.experimental.openenv import generate_rollout_completions
+
+     result = env.reset()
+     prompt_ids, completion_ids, logprobs = [], [], []
+     rewards = {k: [] for k in ["severity", "duplicate", "vehicle_type", "vehicle_choice", "reroute", "format"]}
+
+     for _ in range(max_turns):
+         if result.done:
+             break
+
+         obs_text = result.observation.prompt
+         messages = [
+             {"role": "system", "content": system_prompt},
+             {"role": "user", "content": make_user_prompt(obs_text)},
+         ]
+         prompt_text = tokenizer.apply_chat_template(
+             messages, add_generation_prompt=True, tokenize=False, enable_thinking=False,
+         )
+
+         out = generate_rollout_completions(trainer, [prompt_text])[0]
+         prompt_ids.extend(out["prompt_ids"])
+         completion_ids.extend(out["completion_ids"])
+         logprobs.extend(out["logprobs"])
+
+         comp_text = out.get("text") or tokenizer.decode(out["completion_ids"], skip_special_tokens=True)
+         action_d = parse_llm_action(comp_text)
+         parse_ok = action_d is not None
+         if action_d is None:
+             action_d = fallback_action(obs_text)
+         action = action_dict_to_obj(action_d)
+
+         result = env.step(action)
+         bd = result.observation.reward_breakdown
+
+         rewards["severity"].append(bd.get("severity", 0.0))
+         rewards["duplicate"].append(bd.get("duplicate", 0.0))
+         rewards["vehicle_type"].append(bd.get("vehicle_type", 0.0))
+         rewards["vehicle_choice"].append(bd.get("vehicle_choice", 0.0))
+         rewards["reroute"].append(bd.get("reroute", 0.0))
+         rewards["format"].append(1.0 if parse_ok else -2.0)
+
+     return {
+         "prompt_ids": prompt_ids,
+         "completion_ids": completion_ids,
+         "logprobs": logprobs,
+         # Only the final step's value of each reward component is reported for the episode
+         **{f"{k}_reward": v[-1] if v else 0.0 for k, v in rewards.items()},
+     }
+
+
+ def rollout_func(prompts, trainer=None):
+     """GRPO rollout — called by GRPOTrainer each step."""
+     results = {k: [] for k in ["prompt_ids", "completion_ids", "logprobs",
+                                "severity_reward", "duplicate_reward", "vehicle_type_reward",
+                                "vehicle_choice_reward", "reroute_reward", "format_reward"]}
+
+     for _ in prompts:
+         ep = rollout_once(trainer, env, grpo_tokenizer, SYSTEM_PROMPT)
+         for k in results:
+             results[k].append(ep[k])
+     return results
+
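The dictionary comprehension at the end of `rollout_once` keeps, for each reward component, only the value from the last step of the episode (or 0.0 if no step ran). The sketch below uses made-up per-step values to show what that selection yields, next to a mean over steps as one possible alternative. The averaging variant is an assumption added for comparison and is not something the script does.

```python
# Made-up per-step reward components for a three-step episode (illustration only)
rewards = {
    "severity": [0.5, 1.0, -1.0],
    "format": [1.0, 1.0, -2.0],
    "reroute": [],  # component that was never recorded
}

# Same expression as in rollout_once: final step only, 0.0 when the list is empty
last_step = {f"{k}_reward": v[-1] if v else 0.0 for k, v in rewards.items()}

# Hypothetical alternative: average each component over the episode
mean_step = {f"{k}_reward": sum(v) / len(v) if v else 0.0 for k, v in rewards.items()}

print(last_step)
print(mean_step)
```

With last-step selection, an episode that parses correctly on every turn except the final one is reported with the full negative format reward, which may or may not be the intended credit assignment.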
+ # %% [markdown]
+ # ### Reward Wrappers & Config
+
+ # %%
+ def _make_reward_fn(key):
+     def fn(completions, **kwargs):
+         r = kwargs.get(key)
+         return [float(x) for x in r] if r else [0.0] * len(completions)
+     fn.__name__ = f"reward_{key.replace('_reward', '')}"
+     return fn
+
+ reward_fns = [_make_reward_fn(k) for k in
+               ["severity_reward", "duplicate_reward", "vehicle_type_reward",
+                "vehicle_choice_reward", "reroute_reward", "format_reward"]]
+
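Each wrapper built by `_make_reward_fn` is a closure that looks up one named field in the keyword arguments it receives and falls back to zeros when that field is absent. The cell appears to rely on the trainer forwarding the extra fields returned by `rollout_func` as keyword arguments; that forwarding is assumed here, not demonstrated. The sketch below calls a wrapper directly with made-up values, using the hypothetical name `make_reward_fn`.

```python
def make_reward_fn(key):
    """Build a reward function that reads precomputed values from kwargs[key]."""
    def fn(completions, **kwargs):
        values = kwargs.get(key)
        # Missing or empty field: one zero per completion
        return [float(v) for v in values] if values else [0.0] * len(completions)
    # Give each closure a distinct name, derived from the key
    fn.__name__ = f"reward_{key.replace('_reward', '')}"
    return fn


severity_fn = make_reward_fn("severity_reward")
completions = ["a", "b", "c"]  # placeholder completions

print(severity_fn.__name__)                                     # reward_severity
print(severity_fn(completions, severity_reward=[1, 0.5, -1]))   # [1.0, 0.5, -1.0]
print(severity_fn(completions))                                 # [0.0, 0.0, 0.0]
```

Setting `__name__` matters mostly for readability, since every closure would otherwise share the name `fn`.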
+ # %%
+ grpo_dataset = Dataset.from_dict({
+     "prompt": ["Dispatch emergency services for incoming 911 calls."] * 500
+ })
+
+ grpo_config = GRPOConfig(
+     num_train_epochs=1,
+     learning_rate=5e-6,
+     gradient_accumulation_steps=32,
+     per_device_train_batch_size=1,
+     warmup_steps=10,
+     num_generations=4,
+     max_completion_length=128,
+     max_prompt_length=MAX_SEQ_LENGTH,
+     use_vllm=True,
+     output_dir=GRPO_OUTPUT_DIR,
+     logging_steps=1,
+     save_steps=10,
+     push_to_hub=True,
+ )
+
+ # %% [markdown]
+ # ### Train GRPO
+
+ # %%
+ grpo_trainer = GRPOTrainer(
+     model=grpo_model,
+     processing_class=grpo_tokenizer,
+     reward_funcs=reward_fns,
+     train_dataset=grpo_dataset,
+     args=grpo_config,
+     rollout_func=rollout_func,
+ )
+
+ import torch
+ gpu = torch.cuda.get_device_properties(0)
+ print(f"GPU: {gpu.name} | {round(gpu.total_memory/1024**3, 1)} GB")
+ print(f"Reserved: {round(torch.cuda.max_memory_reserved()/1024**3, 2)} GB")
+
+ # %%
+ print("🏋️ Starting GRPO training …")
+ stats = grpo_trainer.train()
+ print("✅ GRPO complete")
+
+ # %%
+ peak = round(torch.cuda.max_memory_reserved() / 1024**3, 2)
+ print(f"Peak memory: {peak} GB | Time: {round(stats.metrics['train_runtime']/60, 1)} min")
+
+ grpo_trainer.save_model(GRPO_OUTPUT_DIR)
+ grpo_trainer.push_to_hub()
+ print("✅ Model saved & pushed to Hub")
+
+ # %% [markdown]
+ # ---
+ # # Phase 3 — Inference & Evaluation
+
+ # %%
+ from unsloth import FastLanguageModel as FLM
+
+ inf_model, inf_tokenizer = FLM.from_pretrained(
+     model_name=GRPO_OUTPUT_DIR, max_seq_length=MAX_SEQ_LENGTH, load_in_4bit=True,
+ )
+ FLM.for_inference(inf_model)
+
+ def run_episode(env, model, tokenizer, task_id=1):
+     """Run one evaluation episode and return the summed reward."""
+     result = env.reset(task_id=task_id)
+     total_reward = 0.0
+     steps_taken = 0  # counts env.step() calls, so an early `done` is not over-reported
+
+     for step in range(20):
+         if result.done:
+             break
+         obs_text = result.observation.prompt
+         messages = [
+             {"role": "system", "content": SYSTEM_PROMPT},
+             {"role": "user", "content": make_user_prompt(obs_text)},
+         ]
+         prompt_text = tokenizer.apply_chat_template(
+             messages, add_generation_prompt=True, tokenize=False, enable_thinking=False,
+         )
+         inputs = tokenizer([prompt_text], return_tensors="pt").to(model.device)
+         # do_sample=True so that the temperature setting actually takes effect
+         gen = model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.1)
+         output = tokenizer.decode(gen[0][len(inputs.input_ids[0]):], skip_special_tokens=True)
+
+         # Parse once and reuse the result for both the action and the status tag
+         parsed = parse_llm_action(output)
+         action_d = parsed or fallback_action(obs_text)
+         action = action_dict_to_obj(action_d)
+         tag = "✅" if parsed else "⚠️"
+         print(f" Step {step}: {tag} {action_d.get('action_type')} sev={action_d.get('severity_pred')}")
+
+         result = env.step(action)
+         total_reward += result.observation.reward_breakdown.get("total", 0.0)
+         steps_taken += 1
+
+     print(f"\n Done — reward: {total_reward:.2f} over {steps_taken} steps")
+     return total_reward
+
+ # %%
+ print("=" * 50)
+ print("Evaluation — Task 1 (Easy)")
+ print("=" * 50)
+ run_episode(env, inf_model, inf_tokenizer, task_id=1)
+ env.close()
train_sft_grpo_graph.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
uv.lock ADDED
The diff for this file is too large to render. See raw diff