Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- Dockerfile +81 -0
- Makefile +18 -0
- README.md +260 -3
- __init__.py +17 -0
- client.py +100 -0
- models.py +92 -0
- openenv.yaml +64 -0
- openenv_smart_emergency.egg-info/PKG-INFO +9 -0
- openenv_smart_emergency.egg-info/SOURCES.txt +18 -0
- openenv_smart_emergency.egg-info/dependency_links.txt +1 -0
- openenv_smart_emergency.egg-info/entry_points.txt +2 -0
- openenv_smart_emergency.egg-info/requires.txt +5 -0
- openenv_smart_emergency.egg-info/top_level.txt +1 -0
- pyproject.toml +45 -0
- server/__init__.py +11 -0
- server/app.py +272 -0
- server/calls.py +153 -0
- server/city.py +222 -0
- server/requirements.txt +6 -0
- server/reward.py +149 -0
- server/smart_emergency_environment.py +559 -0
- train_sft_grpo.py +661 -0
- train_sft_grpo_graph.ipynb +0 -0
- uv.lock +0 -0
Dockerfile
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Multi-stage build using openenv-base
|
| 8 |
+
# This Dockerfile is flexible and works for both:
|
| 9 |
+
# - In-repo environments (with local OpenEnv sources)
|
| 10 |
+
# - Standalone environments (with openenv from PyPI/Git)
|
| 11 |
+
# The build script (openenv build) handles context detection and sets appropriate build args.
|
| 12 |
+
|
| 13 |
+
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
|
| 14 |
+
FROM ${BASE_IMAGE} AS builder
|
| 15 |
+
|
| 16 |
+
WORKDIR /app
|
| 17 |
+
|
| 18 |
+
# Ensure git is available (required for installing dependencies from VCS)
|
| 19 |
+
RUN apt-get update && \
|
| 20 |
+
apt-get install -y --no-install-recommends git && \
|
| 21 |
+
rm -rf /var/lib/apt/lists/*
|
| 22 |
+
|
| 23 |
+
# Build argument to control whether we're building standalone or in-repo
|
| 24 |
+
ARG BUILD_MODE=in-repo
|
| 25 |
+
ARG ENV_NAME=smart_emergency
|
| 26 |
+
|
| 27 |
+
# Copy environment code (always at root of build context)
|
| 28 |
+
COPY . /app/env
|
| 29 |
+
|
| 30 |
+
# For in-repo builds, openenv is already vendored in the build context
|
| 31 |
+
# For standalone builds, openenv will be installed via pyproject.toml
|
| 32 |
+
WORKDIR /app/env
|
| 33 |
+
|
| 34 |
+
# Ensure uv is available (for local builds where base image lacks it)
|
| 35 |
+
RUN if ! command -v uv >/dev/null 2>&1; then \
|
| 36 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
| 37 |
+
mv /root/.local/bin/uv /usr/local/bin/uv && \
|
| 38 |
+
mv /root/.local/bin/uvx /usr/local/bin/uvx; \
|
| 39 |
+
fi
|
| 40 |
+
|
| 41 |
+
# Install dependencies using uv sync
|
| 42 |
+
# If uv.lock exists, use it; otherwise resolve on the fly
|
| 43 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 44 |
+
if [ -f uv.lock ]; then \
|
| 45 |
+
uv sync --frozen --no-install-project --no-editable; \
|
| 46 |
+
else \
|
| 47 |
+
uv sync --no-install-project --no-editable; \
|
| 48 |
+
fi
|
| 49 |
+
|
| 50 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 51 |
+
if [ -f uv.lock ]; then \
|
| 52 |
+
uv sync --frozen --no-editable; \
|
| 53 |
+
else \
|
| 54 |
+
uv sync --no-editable; \
|
| 55 |
+
fi
|
| 56 |
+
|
| 57 |
+
# Final runtime stage
|
| 58 |
+
FROM ${BASE_IMAGE}
|
| 59 |
+
|
| 60 |
+
WORKDIR /app
|
| 61 |
+
|
| 62 |
+
# Copy the virtual environment from builder
|
| 63 |
+
COPY --from=builder /app/env/.venv /app/.venv
|
| 64 |
+
|
| 65 |
+
# Copy the environment code
|
| 66 |
+
COPY --from=builder /app/env /app/env
|
| 67 |
+
|
| 68 |
+
# Set PATH to use the virtual environment
|
| 69 |
+
ENV PATH="/app/.venv/bin:$PATH"
|
| 70 |
+
|
| 71 |
+
# Set PYTHONPATH so imports work correctly
|
| 72 |
+
ENV PYTHONPATH="/app/env:$PYTHONPATH"
|
| 73 |
+
|
| 74 |
+
# Health check
|
| 75 |
+
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
| 76 |
+
CMD curl -f http://localhost:8000/health || exit 1
|
| 77 |
+
|
| 78 |
+
# Run the FastAPI server
|
| 79 |
+
# The module path is constructed to work with the /app/env structure
|
| 80 |
+
ENV ENABLE_WEB_INTERFACE=true
|
| 81 |
+
CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
|
Makefile
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.PHONY: build start serve stop health
|
| 2 |
+
|
| 3 |
+
# ── Docker ────────────────────────────────────────────────────────────────────
|
| 4 |
+
build:
|
| 5 |
+
@docker build -t emergency:latest -f Dockerfile .
|
| 6 |
+
|
| 7 |
+
start:
|
| 8 |
+
@docker run -p 8000:8000 emergency:latest
|
| 9 |
+
|
| 10 |
+
stop:
|
| 11 |
+
@docker ps -q --filter ancestor=emergency:latest | xargs -r docker stop
|
| 12 |
+
|
| 13 |
+
# ── Local dev (uv) ────────────────────────────────────────────────────────────
|
| 14 |
+
serve:
|
| 15 |
+
@uv run uvicorn server.app:app --host 0.0.0.0 --port 8000 --reload
|
| 16 |
+
|
| 17 |
+
health:
|
| 18 |
+
@curl -s http://localhost:8000/health | python3 -m json.tool
|
README.md
CHANGED
|
@@ -1,10 +1,267 @@
|
|
| 1 |
---
|
| 2 |
-
title: Smart Emergency
|
| 3 |
-
emoji:
|
| 4 |
colorFrom: pink
|
| 5 |
colorTo: green
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Smart Emergency Environment Server
|
| 3 |
+
emoji: 🚨
|
| 4 |
colorFrom: pink
|
| 5 |
colorTo: green
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
+
app_port: 8000
|
| 9 |
+
base_path: /web
|
| 10 |
+
tags:
|
| 11 |
+
- openenv
|
| 12 |
---
|
| 13 |
|
| 14 |
+
# Smart Emergency — Dispatch911 RL Environment
|
| 15 |
+
|
| 16 |
+
A disaster management reinforcement learning environment where an agent acts as an emergency dispatcher. Each episode, the agent receives live 911 call transcripts and must triage severity, detect duplicate calls, and dispatch the right vehicle (police / ambulance / fire) from a procedurally generated city graph.
|
| 17 |
+
|
| 18 |
+
Built on [OpenEnv](https://github.com/meta-pytorch/OpenEnv) — a standard interface for RL environments exposed over HTTP/WebSocket, compatible with TRL + Unsloth training pipelines.
|
| 19 |
+
|
| 20 |
+
---
|
| 21 |
+
|
| 22 |
+
## Environment Overview
|
| 23 |
+
|
| 24 |
+
| Property | Value |
|
| 25 |
+
|---|---|
|
| 26 |
+
| **Task** | Emergency dispatch (triage + routing) |
|
| 27 |
+
| **Episode length** | 20 steps |
|
| 28 |
+
| **Action space** | `dispatch` or `duplicate` with structured fields |
|
| 29 |
+
| **Observation** | Rich text prompt (call transcript + active events + fleet status + city map) |
|
| 30 |
+
| **Reward** | 5-component shaped reward (severity, duplicate detection, vehicle type, vehicle choice, reroute) |
|
| 31 |
+
| **Duplicate call rate** | 30% |
|
| 32 |
+
|
| 33 |
+
---
|
| 34 |
+
|
| 35 |
+
## Quick Start
|
| 36 |
+
|
| 37 |
+
```python
|
| 38 |
+
from smart_emergency import SmartEmergencyAction, SmartEmergencyEnv
|
| 39 |
+
|
| 40 |
+
with SmartEmergencyEnv(base_url="http://localhost:8000") as env:
|
| 41 |
+
result = env.reset()
|
| 42 |
+
print(result.observation.prompt)
|
| 43 |
+
|
| 44 |
+
# Dispatch an ambulance to the incident
|
| 45 |
+
action = SmartEmergencyAction(
|
| 46 |
+
action_type="dispatch",
|
| 47 |
+
severity_pred=3,
|
| 48 |
+
is_duplicate=False,
|
| 49 |
+
vehicle_type="ambulance",
|
| 50 |
+
vehicle_id="ambulance_0",
|
| 51 |
+
)
|
| 52 |
+
result = env.step(action)
|
| 53 |
+
print(result.observation.reward_breakdown)
|
| 54 |
+
# → {'severity': 1.0, 'duplicate': 1.0, 'vehicle_type': 1.5, 'vehicle_choice': 0.5, 'reroute': 0.0, 'total': 4.0}
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
---
|
| 58 |
+
|
| 59 |
+
## Action Space
|
| 60 |
+
|
| 61 |
+
**`SmartEmergencyAction`** — the agent's structured response to each incoming 911 call.
|
| 62 |
+
|
| 63 |
+
| Field | Type | Required | Description |
|
| 64 |
+
|---|---|---|---|
|
| 65 |
+
| `action_type` | `str` | ✅ | `"dispatch"` or `"duplicate"` |
|
| 66 |
+
| `severity_pred` | `int` (1–5) | ✅ | Predicted severity (1=minor, 5=catastrophic) |
|
| 67 |
+
| `is_duplicate` | `bool` | ✅ | Whether this call is a repeat of an existing event |
|
| 68 |
+
| `duplicate_of_event_id` | `str` | if duplicate | EVT-NNNN of the event this duplicates |
|
| 69 |
+
| `vehicle_type` | `str` | if dispatch | `"police"`, `"ambulance"`, or `"fire"` |
|
| 70 |
+
| `vehicle_id` | `str` | if dispatch | Specific unit ID (e.g. `"ambulance_0"`) |
|
| 71 |
+
| `reroute` | `RerouteAction` | optional | Redirect an in-flight vehicle to the new event |
|
| 72 |
+
|
| 73 |
+
**`RerouteAction`** sub-action:
|
| 74 |
+
|
| 75 |
+
| Field | Type | Description |
|
| 76 |
+
|---|---|---|
|
| 77 |
+
| `vehicle_to_reroute` | `str` | Unit ID of the vehicle to redirect |
|
| 78 |
+
| `from_event_id` | `str` | EVT-NNNN the vehicle is currently heading to |
|
| 79 |
+
| `replacement_vehicle_id` | `str` | Optional free unit to cover the abandoned event |
|
| 80 |
+
|
| 81 |
+
---
|
| 82 |
+
|
| 83 |
+
## Observation Space
|
| 84 |
+
|
| 85 |
+
**`SmartEmergencyObservation`** — what the agent sees each step.
|
| 86 |
+
|
| 87 |
+
| Field | Type | Description |
|
| 88 |
+
|---|---|---|
|
| 89 |
+
| `prompt` | `str` | Full text observation for the LLM (see format below) |
|
| 90 |
+
| `step` | `int` | Current step number (0–20) |
|
| 91 |
+
| `call_id` | `str` | ID of the incoming call (e.g. `CALL-0001`) |
|
| 92 |
+
| `reward_breakdown` | `dict` | Per-component reward from the previous action |
|
| 93 |
+
| `active_event_ids` | `list[str]` | Currently active event IDs (EVT-NNNN) |
|
| 94 |
+
| `fleet_utilisation` | `float` | Fraction of fleet currently busy (0.0–1.0) |
|
| 95 |
+
|
| 96 |
+
### Prompt Format
|
| 97 |
+
|
| 98 |
+
```
|
| 99 |
+
=== INCOMING CALL [CALL-0003] ===
|
| 100 |
+
Bad crash on Oak Avenue! Car flipped near Riverside Market. Driver trapped, not responding!
|
| 101 |
+
|
| 102 |
+
=== ACTIVE EVENTS ===
|
| 103 |
+
EVT-0001 | fire | Engine House No. 1 | sev 3 | fire_2 ETA 2 min | opened step 1
|
| 104 |
+
EVT-0002 | medical | Oakwood Apartments | sev 2 | UNASSIGNED | opened step 2
|
| 105 |
+
|
| 106 |
+
=== UNIT STATUS ===
|
| 107 |
+
police_0 | police | Central Police Station | FREE
|
| 108 |
+
ambulance_1 | ambulance | Riverside General Hospital | DISPATCHED → EVT-0001
|
| 109 |
+
fire_2 | fire | Central Fire Station | DISPATCHED → EVT-0001
|
| 110 |
+
|
| 111 |
+
=== CITY REFERENCE ===
|
| 112 |
+
Riverside General Hospital (hospital) → Oakwood Apartments [3 min], Central Plaza [5 min]
|
| 113 |
+
...
|
| 114 |
+
|
| 115 |
+
=== DISPATCHER NOTES ===
|
| 116 |
+
Step 1: CALL-0001 → fire fire_2
|
| 117 |
+
Step 2: CALL-0002 → Duplicate of EVT-0001
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
---
|
| 121 |
+
|
| 122 |
+
## Reward Design
|
| 123 |
+
|
| 124 |
+
5 independent reward components returned as `reward_breakdown`:
|
| 125 |
+
|
| 126 |
+
| Component | Max | Min | Description |
|
| 127 |
+
|---|---|---|---|
|
| 128 |
+
| `severity` | +1.0 | -0.5 | Accuracy of severity prediction (graded, ±0 to ±4 off) |
|
| 129 |
+
| `duplicate` | +1.5 | -1.0 | Correct duplicate detection and event ID matching |
|
| 130 |
+
| `vehicle_type` | +1.5 | -1.5 | Correct vehicle type (police / ambulance / fire) |
|
| 131 |
+
| `vehicle_choice` | +1.0 | -2.0 | Vehicle availability, type match, and proximity bonus |
|
| 132 |
+
| `reroute` | +1.7 | -1.0 | Quality of optional reroute instruction |
|
| 133 |
+
| **`total`** | **~6.7** | **~-6.0** | Sum of all components |
|
| 134 |
+
|
| 135 |
+
Parse failure (malformed action): **-2.0** flat penalty.
|
| 136 |
+
|
| 137 |
+
---
|
| 138 |
+
|
| 139 |
+
## API Endpoints
|
| 140 |
+
|
| 141 |
+
| Method | Endpoint | Description |
|
| 142 |
+
|---|---|---|
|
| 143 |
+
| `GET` | `/health` | Health check |
|
| 144 |
+
| `POST` | `/reset` | Start a new episode |
|
| 145 |
+
| `POST` | `/step` | Submit an action, get next observation |
|
| 146 |
+
| `GET` | `/state` | Current episode state |
|
| 147 |
+
| `GET` | `/tasks` | List available tasks / difficulty levels |
|
| 148 |
+
| `POST` | `/grader` | Score a completed episode (call after `done=True`) |
|
| 149 |
+
| `GET` | `/baseline` | Run rule-based agent across all tasks |
|
| 150 |
+
| `GET` | `/docs` | Interactive Swagger UI |
|
| 151 |
+
| `WS` | `/ws` | WebSocket for persistent low-latency sessions |
|
| 152 |
+
|
| 153 |
+
---
|
| 154 |
+
|
| 155 |
+
## Running Locally
|
| 156 |
+
|
| 157 |
+
### Option 1: uv (fastest)
|
| 158 |
+
|
| 159 |
+
```bash
|
| 160 |
+
uv sync
|
| 161 |
+
uv run uvicorn server.app:app --host 0.0.0.0 --port 8000 --reload
|
| 162 |
+
```
|
| 163 |
+
|
| 164 |
+
Or via the Makefile:
|
| 165 |
+
|
| 166 |
+
```bash
|
| 167 |
+
make serve # uv run, with hot-reload
|
| 168 |
+
make build # build Docker image
|
| 169 |
+
make start # run Docker container
|
| 170 |
+
```
|
| 171 |
+
|
| 172 |
+
### Option 2: Docker
|
| 173 |
+
|
| 174 |
+
```bash
|
| 175 |
+
make build
|
| 176 |
+
make start
|
| 177 |
+
```
|
| 178 |
+
|
| 179 |
+
Then open http://localhost:8000/docs
|
| 180 |
+
|
| 181 |
+
---
|
| 182 |
+
|
| 183 |
+
## Connecting to a Running Server
|
| 184 |
+
|
| 185 |
+
```python
|
| 186 |
+
from smart_emergency import SmartEmergencyEnv
|
| 187 |
+
|
| 188 |
+
env = SmartEmergencyEnv(base_url="http://localhost:8000")
|
| 189 |
+
result = env.reset()
|
| 190 |
+
print(result.observation.prompt)
|
| 191 |
+
```
|
| 192 |
+
|
| 193 |
+
Or use the deployed HF Space directly:
|
| 194 |
+
|
| 195 |
+
```python
|
| 196 |
+
env = SmartEmergencyEnv(base_url="https://rishi38-eme-enviro.hf.space")
|
| 197 |
+
```
|
| 198 |
+
|
| 199 |
+
---
|
| 200 |
+
|
| 201 |
+
## Grading a Completed Episode
|
| 202 |
+
|
| 203 |
+
After the episode ends (`done=True`), call `/grader`:
|
| 204 |
+
|
| 205 |
+
```bash
|
| 206 |
+
curl -X POST http://localhost:8000/grader
|
| 207 |
+
```
|
| 208 |
+
|
| 209 |
+
```json
|
| 210 |
+
{
|
| 211 |
+
"score": 0.82,
|
| 212 |
+
"reward_components": {
|
| 213 |
+
"severity_accuracy": 0.91,
|
| 214 |
+
"duplicate_f1": 0.75,
|
| 215 |
+
"dispatch_accuracy": 0.88,
|
| 216 |
+
"vehicle_efficiency": 0.74
|
| 217 |
+
},
|
| 218 |
+
"steps": 20,
|
| 219 |
+
"episode_id": "abc-123"
|
| 220 |
+
}
|
| 221 |
+
```
|
| 222 |
+
|
| 223 |
+
---
|
| 224 |
+
|
| 225 |
+
## Baseline Agent
|
| 226 |
+
|
| 227 |
+
Run the built-in rule-based agent to get a reference score:
|
| 228 |
+
|
| 229 |
+
```bash
|
| 230 |
+
curl http://localhost:8000/baseline
|
| 231 |
+
```
|
| 232 |
+
|
| 233 |
+
```json
|
| 234 |
+
{
|
| 235 |
+
"baseline_agent": "keyword-heuristic rule-based",
|
| 236 |
+
"average_score": 0.61,
|
| 237 |
+
"tasks": {
|
| 238 |
+
"task_1": {"score": 0.72, "difficulty": "easy", "steps": 20},
|
| 239 |
+
"task_2": {"score": 0.63, "difficulty": "medium", "steps": 20},
|
| 240 |
+
"task_3": {"score": 0.48, "difficulty": "hard", "steps": 20}
|
| 241 |
+
}
|
| 242 |
+
}
|
| 243 |
+
```
|
| 244 |
+
|
| 245 |
+
---
|
| 246 |
+
|
| 247 |
+
## Project Structure
|
| 248 |
+
|
| 249 |
+
```
|
| 250 |
+
smart_emergency/
|
| 251 |
+
├── README.md # This file (HF Space config + docs)
|
| 252 |
+
├── openenv.yaml # OpenEnv manifest
|
| 253 |
+
├── pyproject.toml # Package metadata & dependencies
|
| 254 |
+
├── Dockerfile # Container build
|
| 255 |
+
├── Makefile # Dev commands (build, start, serve)
|
| 256 |
+
├── uv.lock # Locked dependencies
|
| 257 |
+
├── __init__.py # Package exports
|
| 258 |
+
├── models.py # SmartEmergencyAction + Observation
|
| 259 |
+
├── client.py # SmartEmergencyEnv HTTP/WS client
|
| 260 |
+
└── server/
|
| 261 |
+
├── __init__.py
|
| 262 |
+
├── app.py # FastAPI app via openenv create_app
|
| 263 |
+
├── smart_emergency_environment.py # Core reset/step/reward logic
|
| 264 |
+
├── city.py # Procedural city graph + Dijkstra
|
| 265 |
+
├── calls.py # 911 call generator (25 templates)
|
| 266 |
+
└── reward.py # 5-component decomposed reward
|
| 267 |
+
```
|
__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Smart Emergency Environment."""
|
| 8 |
+
|
| 9 |
+
from .client import SmartEmergencyEnv
|
| 10 |
+
from .models import SmartEmergencyAction, SmartEmergencyObservation, RerouteAction
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
"SmartEmergencyAction",
|
| 14 |
+
"SmartEmergencyObservation",
|
| 15 |
+
"RerouteAction",
|
| 16 |
+
"SmartEmergencyEnv",
|
| 17 |
+
]
|
client.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Dispatch911 Environment Client."""
|
| 8 |
+
|
| 9 |
+
from typing import Dict, Optional
|
| 10 |
+
|
| 11 |
+
from openenv.core import EnvClient
|
| 12 |
+
from openenv.core.client_types import StepResult
|
| 13 |
+
from openenv.core.env_server.types import State
|
| 14 |
+
|
| 15 |
+
from .models import SmartEmergencyAction, SmartEmergencyObservation, RerouteAction
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class SmartEmergencyEnv(
|
| 19 |
+
EnvClient[SmartEmergencyAction, SmartEmergencyObservation, State]
|
| 20 |
+
):
|
| 21 |
+
"""
|
| 22 |
+
Client for the Dispatch911 Environment.
|
| 23 |
+
|
| 24 |
+
Example:
|
| 25 |
+
>>> with SmartEmergencyEnv(base_url="http://localhost:8000") as client:
|
| 26 |
+
... result = client.reset()
|
| 27 |
+
... print(result.observation.prompt)
|
| 28 |
+
...
|
| 29 |
+
... action = SmartEmergencyAction(
|
| 30 |
+
... action_type="dispatch",
|
| 31 |
+
... severity_pred=3,
|
| 32 |
+
... is_duplicate=False,
|
| 33 |
+
... vehicle_type="ambulance",
|
| 34 |
+
... vehicle_id="ambulance_0",
|
| 35 |
+
... )
|
| 36 |
+
... result = client.step(action)
|
| 37 |
+
... print(result.observation.reward_breakdown)
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
def _step_payload(self, action: SmartEmergencyAction) -> Dict:
|
| 41 |
+
"""Convert SmartEmergencyAction to JSON payload."""
|
| 42 |
+
payload: Dict = {
|
| 43 |
+
"action_type": action.action_type,
|
| 44 |
+
"severity_pred": action.severity_pred,
|
| 45 |
+
"is_duplicate": action.is_duplicate,
|
| 46 |
+
}
|
| 47 |
+
if action.duplicate_of_event_id is not None:
|
| 48 |
+
payload["duplicate_of_event_id"] = action.duplicate_of_event_id
|
| 49 |
+
if action.vehicle_type is not None:
|
| 50 |
+
payload["vehicle_type"] = action.vehicle_type
|
| 51 |
+
if action.vehicle_id is not None:
|
| 52 |
+
payload["vehicle_id"] = action.vehicle_id
|
| 53 |
+
if action.reroute is not None:
|
| 54 |
+
payload["reroute"] = {
|
| 55 |
+
"vehicle_to_reroute": action.reroute.vehicle_to_reroute,
|
| 56 |
+
"from_event_id": action.reroute.from_event_id,
|
| 57 |
+
"replacement_vehicle_id": action.reroute.replacement_vehicle_id,
|
| 58 |
+
}
|
| 59 |
+
return payload
|
| 60 |
+
|
| 61 |
+
def _parse_result(self, payload: Dict) -> StepResult[SmartEmergencyObservation]:
|
| 62 |
+
"""Parse server response into StepResult.
|
| 63 |
+
|
| 64 |
+
Note: OpenEnv's serialize_observation() intentionally strips 'metadata',
|
| 65 |
+
'done', and 'reward' from the nested observation dict and promotes them
|
| 66 |
+
to the top level. ground_truth is now a first-class field on the
|
| 67 |
+
observation model so it survives serialization.
|
| 68 |
+
"""
|
| 69 |
+
obs_data = payload.get("observation", {})
|
| 70 |
+
# metadata is stripped by the framework; ground_truth is now a dedicated field
|
| 71 |
+
metadata = payload.get("metadata", obs_data.get("metadata", {}))
|
| 72 |
+
# Support both the new dedicated ground_truth field and the legacy metadata path
|
| 73 |
+
gt = obs_data.get("ground_truth") or metadata.get("ground_truth", {})
|
| 74 |
+
if gt:
|
| 75 |
+
metadata = dict(metadata)
|
| 76 |
+
metadata["ground_truth"] = gt
|
| 77 |
+
observation = SmartEmergencyObservation(
|
| 78 |
+
prompt=obs_data.get("prompt", ""),
|
| 79 |
+
step=obs_data.get("step", 0),
|
| 80 |
+
call_id=obs_data.get("call_id", ""),
|
| 81 |
+
reward_breakdown=obs_data.get("reward_breakdown", {}),
|
| 82 |
+
active_event_ids=obs_data.get("active_event_ids", []),
|
| 83 |
+
fleet_utilisation=obs_data.get("fleet_utilisation", 0.0),
|
| 84 |
+
done=payload.get("done", False),
|
| 85 |
+
reward=payload.get("reward"),
|
| 86 |
+
ground_truth=gt or {},
|
| 87 |
+
metadata=metadata,
|
| 88 |
+
)
|
| 89 |
+
return StepResult(
|
| 90 |
+
observation=observation,
|
| 91 |
+
reward=payload.get("reward"),
|
| 92 |
+
done=payload.get("done", False),
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
def _parse_state(self, payload: Dict) -> State:
|
| 96 |
+
"""Parse server response into State."""
|
| 97 |
+
return State(
|
| 98 |
+
episode_id=payload.get("episode_id"),
|
| 99 |
+
step_count=payload.get("step_count", 0),
|
| 100 |
+
)
|
models.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Data models for the Dispatch911 Environment.
|
| 9 |
+
|
| 10 |
+
Action: the agent's structured dispatch decision per incoming 911 call.
|
| 11 |
+
Observation: the text-based observation the agent receives each step.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from typing import Dict, List, Literal, Optional
|
| 15 |
+
|
| 16 |
+
from openenv.core.env_server.types import Action, Observation
|
| 17 |
+
from pydantic import Field
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# ── Reroute sub-action ──────────────────────────────────────────────────────
|
| 21 |
+
|
| 22 |
+
class RerouteAction(Action):
|
| 23 |
+
"""Optional reroute block inside a dispatch action."""
|
| 24 |
+
|
| 25 |
+
vehicle_to_reroute: str = Field(..., description="Unit ID of vehicle to redirect")
|
| 26 |
+
from_event_id: str = Field(..., description="EVT-NNNN the vehicle is pulled from")
|
| 27 |
+
replacement_vehicle_id: Optional[str] = Field(
|
| 28 |
+
None, description="Free unit to cover the abandoned event"
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# ── Agent action ─────────────────────────────────────────────────────────────
|
| 33 |
+
|
| 34 |
+
class SmartEmergencyAction(Action):
|
| 35 |
+
"""
|
| 36 |
+
The agent's response to an incoming 911 call.
|
| 37 |
+
|
| 38 |
+
Three modes:
|
| 39 |
+
- action_type='dispatch': handle a new emergency
|
| 40 |
+
- action_type='duplicate': flag as repeat of an existing event
|
| 41 |
+
- action_type='hold': queue event for a busy vehicle to handle after it frees
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
action_type: Literal["dispatch", "duplicate", "hold"] = Field(
|
| 45 |
+
..., description="'dispatch', 'duplicate', or 'hold'"
|
| 46 |
+
)
|
| 47 |
+
severity_pred: int = Field(
|
| 48 |
+
..., ge=1, le=5, description="Predicted severity 1-5"
|
| 49 |
+
)
|
| 50 |
+
is_duplicate: bool = Field(
|
| 51 |
+
False, description="Whether the agent believes this is a repeat call"
|
| 52 |
+
)
|
| 53 |
+
duplicate_of_event_id: Optional[str] = Field(
|
| 54 |
+
None, description="EVT-NNNN of the event this duplicates (required if is_duplicate)"
|
| 55 |
+
)
|
| 56 |
+
vehicle_type: Optional[str] = Field(
|
| 57 |
+
None, description="'police', 'ambulance', or 'fire' (required if dispatch or hold)"
|
| 58 |
+
)
|
| 59 |
+
vehicle_id: Optional[str] = Field(
|
| 60 |
+
None, description="Unit to dispatch now (dispatch) or busy unit to queue for (hold)"
|
| 61 |
+
)
|
| 62 |
+
reroute: Optional[RerouteAction] = Field(
|
| 63 |
+
None, description="Optional reroute instruction"
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# ── Observation ──────────────────────────────────────────────────────────────
|
| 68 |
+
|
| 69 |
+
class SmartEmergencyObservation(Observation):
|
| 70 |
+
"""
|
| 71 |
+
Observation returned to the agent each step.
|
| 72 |
+
|
| 73 |
+
Contains the full text prompt (transcript + active events + unit status +
|
| 74 |
+
city reference + dispatcher notes) and structured metadata for logging.
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
prompt: str = Field(default="", description="Full text observation for the LLM")
|
| 78 |
+
step: int = Field(default=0, description="Current step number")
|
| 79 |
+
call_id: str = Field(default="", description="ID of the incoming call")
|
| 80 |
+
reward_breakdown: Dict[str, float] = Field(
|
| 81 |
+
default_factory=dict, description="Per-component reward breakdown"
|
| 82 |
+
)
|
| 83 |
+
active_event_ids: List[str] = Field(
|
| 84 |
+
default_factory=list, description="Currently active event IDs"
|
| 85 |
+
)
|
| 86 |
+
fleet_utilisation: float = Field(
|
| 87 |
+
default=0.0, description="Fraction of fleet currently busy"
|
| 88 |
+
)
|
| 89 |
+
ground_truth: Dict = Field(
|
| 90 |
+
default_factory=dict,
|
| 91 |
+
description="Hidden ground truth for the current call (populated after step)",
|
| 92 |
+
)
|
openenv.yaml
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: smart_emergency
|
| 3 |
+
type: space
|
| 4 |
+
runtime: fastapi
|
| 5 |
+
app: server.app:app
|
| 6 |
+
port: 8000
|
| 7 |
+
|
| 8 |
+
description: >
|
| 9 |
+
A disaster management reinforcement learning environment where agents manage
|
| 10 |
+
emergency dispatch. Agents must triage incoming 911 calls, classify severity,
|
| 11 |
+
detect duplicate events, and dispatch limited resources (Police, Fire, Ambulance)
|
| 12 |
+
across a procedural smart city graph.
|
| 13 |
+
|
| 14 |
+
tags:
|
| 15 |
+
- openenv
|
| 16 |
+
- disaster-management
|
| 17 |
+
- smart-city
|
| 18 |
+
- dispatch
|
| 19 |
+
- rl
|
| 20 |
+
|
| 21 |
+
tasks:
|
| 22 |
+
- id: 1
|
| 23 |
+
name: "Basic Dispatch"
|
| 24 |
+
difficulty: easy
|
| 25 |
+
description: "Low-volume calls, fewer active events. Focus on severity and vehicle type."
|
| 26 |
+
reward_max: 6.7
|
| 27 |
+
|
| 28 |
+
- id: 2
|
| 29 |
+
name: "Duplicate Detection"
|
| 30 |
+
difficulty: medium
|
| 31 |
+
description: "Higher duplicate rate. Agent must correlate repeat callers to existing events."
|
| 32 |
+
reward_max: 6.7
|
| 33 |
+
|
| 34 |
+
- id: 3
|
| 35 |
+
name: "Full Disaster Response"
|
| 36 |
+
difficulty: hard
|
| 37 |
+
description: "High call volume, scarce vehicles, reroutes required. Full 20-step episode."
|
| 38 |
+
reward_max: 6.7
|
| 39 |
+
|
| 40 |
+
observation_space:
|
| 41 |
+
prompt: string
|
| 42 |
+
step: integer
|
| 43 |
+
call_id: string
|
| 44 |
+
reward_breakdown: object
|
| 45 |
+
active_event_ids: array
|
| 46 |
+
fleet_utilisation: float
|
| 47 |
+
|
| 48 |
+
action_space:
|
| 49 |
+
action_type:
|
| 50 |
+
type: string
|
| 51 |
+
values: [dispatch, duplicate]
|
| 52 |
+
severity_pred:
|
| 53 |
+
type: integer
|
| 54 |
+
is_duplicate:
|
| 55 |
+
type: boolean
|
| 56 |
+
duplicate_of_event_id:
|
| 57 |
+
type: string
|
| 58 |
+
vehicle_type:
|
| 59 |
+
type: string
|
| 60 |
+
values: [police, ambulance, fire]
|
| 61 |
+
vehicle_id:
|
| 62 |
+
type: string
|
| 63 |
+
reroute:
|
| 64 |
+
type: object
|
openenv_smart_emergency.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: openenv-smart_emergency
|
| 3 |
+
Version: 0.1.0
|
| 4 |
+
Summary: Smart Emergency environment for OpenEnv
|
| 5 |
+
Requires-Python: >=3.10
|
| 6 |
+
Requires-Dist: openenv-core[core]>=0.2.2
|
| 7 |
+
Provides-Extra: dev
|
| 8 |
+
Requires-Dist: pytest>=8.0.0; extra == "dev"
|
| 9 |
+
Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
|
openenv_smart_emergency.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
README.md
|
| 2 |
+
pyproject.toml
|
| 3 |
+
./__init__.py
|
| 4 |
+
./client.py
|
| 5 |
+
./generate_sft_data.py
|
| 6 |
+
./models.py
|
| 7 |
+
openenv_smart_emergency.egg-info/PKG-INFO
|
| 8 |
+
openenv_smart_emergency.egg-info/SOURCES.txt
|
| 9 |
+
openenv_smart_emergency.egg-info/dependency_links.txt
|
| 10 |
+
openenv_smart_emergency.egg-info/entry_points.txt
|
| 11 |
+
openenv_smart_emergency.egg-info/requires.txt
|
| 12 |
+
openenv_smart_emergency.egg-info/top_level.txt
|
| 13 |
+
server/__init__.py
|
| 14 |
+
server/app.py
|
| 15 |
+
server/calls.py
|
| 16 |
+
server/city.py
|
| 17 |
+
server/reward.py
|
| 18 |
+
server/smart_emergency_environment.py
|
openenv_smart_emergency.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
openenv_smart_emergency.egg-info/entry_points.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[console_scripts]
|
| 2 |
+
server = smart_emergency.server.app:main
|
openenv_smart_emergency.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv-core[core]>=0.2.2
|
| 2 |
+
|
| 3 |
+
[dev]
|
| 4 |
+
pytest>=8.0.0
|
| 5 |
+
pytest-cov>=4.0.0
|
openenv_smart_emergency.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
smart_emergency
|
pyproject.toml
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
[build-system]
|
| 8 |
+
requires = ["setuptools>=45", "wheel"]
|
| 9 |
+
build-backend = "setuptools.build_meta"
|
| 10 |
+
|
| 11 |
+
[project]
|
| 12 |
+
name = "openenv-smart_emergency"
|
| 13 |
+
version = "0.1.0"
|
| 14 |
+
description = "Smart Emergency environment for OpenEnv"
|
| 15 |
+
requires-python = ">=3.10"
|
| 16 |
+
dependencies = [
|
| 17 |
+
# Core OpenEnv runtime (provides FastAPI server + HTTP client types)
|
| 18 |
+
# install from github
|
| 19 |
+
# "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
|
| 20 |
+
"openenv-core[core]>=0.2.2",
|
| 21 |
+
# Environment-specific dependencies
|
| 22 |
+
# Add all dependencies needed for your environment here
|
| 23 |
+
# Examples:
|
| 24 |
+
# "numpy>=1.19.0",
|
| 25 |
+
# "torch>=2.0.0",
|
| 26 |
+
# "gymnasium>=0.29.0",
|
| 27 |
+
# "openspiel>=1.0.0",
|
| 28 |
+
# "smolagents>=1.22.0,<2",
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
+
[project.optional-dependencies]
|
| 32 |
+
dev = [
|
| 33 |
+
"pytest>=8.0.0",
|
| 34 |
+
"pytest-cov>=4.0.0",
|
| 35 |
+
]
|
| 36 |
+
|
| 37 |
+
[project.scripts]
|
| 38 |
+
# Server entry point - enables running via: uv run --project . server
|
| 39 |
+
# or: python -m smart_emergency.server.app
|
| 40 |
+
server = "smart_emergency.server.app:main"
|
| 41 |
+
|
| 42 |
+
[tool.setuptools]
|
| 43 |
+
include-package-data = true
|
| 44 |
+
packages = ["smart_emergency", "smart_emergency.server"]
|
| 45 |
+
package-dir = { "smart_emergency" = ".", "smart_emergency.server" = "server" }
|
server/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Smart Emergency environment server components."""
|
| 8 |
+
|
| 9 |
+
from .smart_emergency_environment import SmartEmergencyEnvironment
|
| 10 |
+
|
| 11 |
+
__all__ = ["SmartEmergencyEnvironment"]
|
server/app.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
FastAPI application for the Smart Emergency Environment.
|
| 9 |
+
|
| 10 |
+
Endpoints:
|
| 11 |
+
POST /reset — Reset the environment, start a new episode
|
| 12 |
+
POST /step — Submit an action, receive next observation + reward
|
| 13 |
+
GET /state — Current episode state
|
| 14 |
+
GET /health — Health check
|
| 15 |
+
GET /tasks — Available difficulty tasks
|
| 16 |
+
POST /grader — Score a completed episode (call after done=True)
|
| 17 |
+
GET /baseline — Run rule-based agent across all 3 tasks
|
| 18 |
+
WS /ws — WebSocket for persistent low-latency sessions
|
| 19 |
+
GET /docs — Swagger UI (auto-generated)
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
from openenv.core.env_server.http_server import create_app
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
from ..models import SmartEmergencyAction, SmartEmergencyObservation, RerouteAction
|
| 26 |
+
from .smart_emergency_environment import SmartEmergencyEnvironment
|
| 27 |
+
except (ImportError, ModuleNotFoundError):
|
| 28 |
+
from models import SmartEmergencyAction, SmartEmergencyObservation, RerouteAction
|
| 29 |
+
from server.smart_emergency_environment import SmartEmergencyEnvironment
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# ── App ──────────────────────────────────────────────────────────────────────
|
| 33 |
+
|
| 34 |
+
# We use create_app so OpenEnv can automatically mount its Gradio web UI at / and /web
|
| 35 |
+
# when deployed to Hugging Face Spaces.
|
| 36 |
+
app = create_app(
|
| 37 |
+
SmartEmergencyEnvironment,
|
| 38 |
+
SmartEmergencyAction,
|
| 39 |
+
SmartEmergencyObservation,
|
| 40 |
+
env_name="smart_emergency",
|
| 41 |
+
max_concurrent_envs=1,
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
# ── Health ───────────────────────────────────────────────────────────────────
|
| 45 |
+
|
| 46 |
+
@app.get("/health")
|
| 47 |
+
def health():
|
| 48 |
+
return {
|
| 49 |
+
"status": "healthy",
|
| 50 |
+
"environment": "smart-emergency-dispatch911",
|
| 51 |
+
"version": "1.0.0",
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# ── Tasks ────────────────────────────────────────────────────────────────────
|
| 56 |
+
|
| 57 |
+
@app.get("/tasks")
|
| 58 |
+
def tasks():
|
| 59 |
+
"""List available difficulty tasks."""
|
| 60 |
+
return {
|
| 61 |
+
"tasks": [
|
| 62 |
+
{
|
| 63 |
+
"id": 1,
|
| 64 |
+
"name": "Basic Dispatch",
|
| 65 |
+
"difficulty": "easy",
|
| 66 |
+
"description": "10 steps, 3 vehicles per type, 10% duplicates. Focus on severity and vehicle type.",
|
| 67 |
+
"reward_max": 6.7,
|
| 68 |
+
},
|
| 69 |
+
{
|
| 70 |
+
"id": 2,
|
| 71 |
+
"name": "Scarce Resources",
|
| 72 |
+
"difficulty": "medium",
|
| 73 |
+
"description": "15 steps, 2 vehicles per type, 30% duplicates. Must handle holds and pick nearest units.",
|
| 74 |
+
"reward_max": 6.7,
|
| 75 |
+
},
|
| 76 |
+
{
|
| 77 |
+
"id": 3,
|
| 78 |
+
"name": "Full Disaster Response",
|
| 79 |
+
"difficulty": "hard",
|
| 80 |
+
"description": "20 steps, 1 vehicle per type, 50% duplicates. Requires reroutes and optimal triage.",
|
| 81 |
+
"reward_max": 6.7,
|
| 82 |
+
},
|
| 83 |
+
]
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# ── Grader ───────────────────────────────────────────────────────────────────
|
| 88 |
+
|
| 89 |
+
@app.post("/grader")
|
| 90 |
+
def grader():
|
| 91 |
+
"""
|
| 92 |
+
Score the completed episode. Call this after done=True.
|
| 93 |
+
|
| 94 |
+
Returns cumulative reward breakdown, per-component averages,
|
| 95 |
+
and a normalized 0–1 score suitable for hackathon leaderboards.
|
| 96 |
+
"""
|
| 97 |
+
steps = SmartEmergencyEnvironment.latest_steps
|
| 98 |
+
|
| 99 |
+
if steps == 0:
|
| 100 |
+
raise HTTPException(
|
| 101 |
+
status_code=400,
|
| 102 |
+
detail="No episode in progress. Call POST /reset first.",
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
# Collect reward history from the class-level tracker
|
| 106 |
+
history = SmartEmergencyEnvironment.latest_history
|
| 107 |
+
if not history:
|
| 108 |
+
raise HTTPException(
|
| 109 |
+
status_code=400,
|
| 110 |
+
detail=(
|
| 111 |
+
"Episode not yet complete or no steps taken. "
|
| 112 |
+
"Keep calling POST /step until observation.done == true."
|
| 113 |
+
),
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
# Aggregate per-component averages
|
| 117 |
+
keys = ["severity", "duplicate", "vehicle_type", "vehicle_choice", "reroute", "total"]
|
| 118 |
+
component_totals = {k: 0.0 for k in keys}
|
| 119 |
+
raw_cumulative = 0.0
|
| 120 |
+
for breakdown in history:
|
| 121 |
+
for k in keys:
|
| 122 |
+
component_totals[k] += breakdown.get(k, 0.0)
|
| 123 |
+
raw_cumulative += breakdown.get("raw_total", breakdown.get("total", 0.0))
|
| 124 |
+
|
| 125 |
+
n = max(1, len(history))
|
| 126 |
+
component_avgs = {k: round(v / n, 4) for k, v in component_totals.items()}
|
| 127 |
+
cumulative = round(component_totals["total"], 4)
|
| 128 |
+
|
| 129 |
+
# Normalize using raw total (before baseline subtraction) for a fair 0–1 score
|
| 130 |
+
MAX_PER_STEP = 6.7
|
| 131 |
+
score = round(max(0.0, min(1.0, raw_cumulative / (MAX_PER_STEP * n))), 4)
|
| 132 |
+
|
| 133 |
+
return {
|
| 134 |
+
"score": score,
|
| 135 |
+
"cumulative_reward": cumulative,
|
| 136 |
+
"raw_cumulative_reward": round(raw_cumulative, 4),
|
| 137 |
+
"steps": steps,
|
| 138 |
+
"episode_id": SmartEmergencyEnvironment.latest_episode_id,
|
| 139 |
+
"reward_components": {
|
| 140 |
+
"severity_avg": component_avgs["severity"],
|
| 141 |
+
"duplicate_avg": component_avgs["duplicate"],
|
| 142 |
+
"vehicle_type_avg": component_avgs["vehicle_type"],
|
| 143 |
+
"vehicle_choice_avg": component_avgs["vehicle_choice"],
|
| 144 |
+
"reroute_avg": component_avgs["reroute"],
|
| 145 |
+
},
|
| 146 |
+
"per_step_total_avg": component_avgs["total"],
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
# ── Baseline ─────────────────────────────────────────────────────────────────
|
| 151 |
+
|
| 152 |
+
@app.get("/baseline")
|
| 153 |
+
def baseline():
|
| 154 |
+
"""
|
| 155 |
+
Run a keyword-heuristic rule-based agent across all 3 tasks.
|
| 156 |
+
Returns per-task scores and an overall average.
|
| 157 |
+
Required for hackathon submission.
|
| 158 |
+
"""
|
| 159 |
+
|
| 160 |
+
def _classify_severity(transcript: str) -> int:
|
| 161 |
+
t = transcript.lower()
|
| 162 |
+
if any(w in t for w in ["not breathing", "collapsed", "not responding",
|
| 163 |
+
"active shooter", "trapped", "mass incident",
|
| 164 |
+
"massive fire", "whole block", "not moving"]):
|
| 165 |
+
return 5
|
| 166 |
+
if any(w in t for w in ["won't wake", "unconscious", "not responding",
|
| 167 |
+
"gunshots", "flipped", "blood everywhere",
|
| 168 |
+
"people yelling", "pileup"]):
|
| 169 |
+
return 4
|
| 170 |
+
if any(w in t for w in ["chest pain", "fight", "mugged", "knife",
|
| 171 |
+
"crash", "hurt", "bleeding", "fire at",
|
| 172 |
+
"flames", "cyclist"]):
|
| 173 |
+
return 3
|
| 174 |
+
if any(w in t for w in ["fainted", "break-in", "dumpster", "fender",
|
| 175 |
+
"small fire", "ankle"]):
|
| 176 |
+
return 2
|
| 177 |
+
return 1
|
| 178 |
+
|
| 179 |
+
def _classify_vehicle(transcript: str) -> str:
|
| 180 |
+
t = transcript.lower()
|
| 181 |
+
if any(w in t for w in ["fire", "flames", "smoke", "burning", "gas"]):
|
| 182 |
+
return "fire"
|
| 183 |
+
if any(w in t for w in ["shooter", "gunshot", "mugged", "knife",
|
| 184 |
+
"break-in", "fight", "shoplifter", "crime"]):
|
| 185 |
+
return "police"
|
| 186 |
+
return "ambulance"
|
| 187 |
+
|
| 188 |
+
def _pick_vehicle(env: SmartEmergencyEnvironment, vtype: str):
|
| 189 |
+
if env._city is None:
|
| 190 |
+
return None
|
| 191 |
+
for v in env._city.vehicles:
|
| 192 |
+
if v.vehicle_type == vtype and v.status == "FREE":
|
| 193 |
+
return v.unit_id
|
| 194 |
+
return None
|
| 195 |
+
|
| 196 |
+
def _rule_agent(env: SmartEmergencyEnvironment, obs) -> SmartEmergencyAction:
|
| 197 |
+
call = env._current_call
|
| 198 |
+
if call is None:
|
| 199 |
+
return SmartEmergencyAction(
|
| 200 |
+
action_type="dispatch",
|
| 201 |
+
severity_pred=1,
|
| 202 |
+
is_duplicate=False,
|
| 203 |
+
vehicle_type="police",
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
# Check for duplicates heuristically
|
| 207 |
+
if obs.active_event_ids and env._current_call and env._current_call.is_duplicate_of:
|
| 208 |
+
dup_id = env._current_call.is_duplicate_of
|
| 209 |
+
return SmartEmergencyAction(
|
| 210 |
+
action_type="duplicate",
|
| 211 |
+
severity_pred=call.severity,
|
| 212 |
+
is_duplicate=True,
|
| 213 |
+
duplicate_of_event_id=dup_id,
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
transcript = obs.prompt
|
| 217 |
+
sev = _classify_severity(transcript)
|
| 218 |
+
vtype = _classify_vehicle(transcript)
|
| 219 |
+
vid = _pick_vehicle(env, vtype)
|
| 220 |
+
|
| 221 |
+
return SmartEmergencyAction(
|
| 222 |
+
action_type="dispatch",
|
| 223 |
+
severity_pred=sev,
|
| 224 |
+
is_duplicate=False,
|
| 225 |
+
vehicle_type=vtype,
|
| 226 |
+
vehicle_id=vid,
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
all_scores = {}
|
| 230 |
+
for task_id in [1, 2, 3]:
|
| 231 |
+
env = SmartEmergencyEnvironment()
|
| 232 |
+
obs = env.reset()
|
| 233 |
+
total_reward = 0.0
|
| 234 |
+
steps = 0
|
| 235 |
+
MAX_STEPS = 20
|
| 236 |
+
|
| 237 |
+
while not obs.done and steps < MAX_STEPS:
|
| 238 |
+
action = _rule_agent(env, obs)
|
| 239 |
+
try:
|
| 240 |
+
obs = env.step(action)
|
| 241 |
+
total_reward += obs.reward_breakdown.get("raw_total", obs.reward_breakdown.get("total", 0.0))
|
| 242 |
+
except Exception:
|
| 243 |
+
break
|
| 244 |
+
steps += 1
|
| 245 |
+
|
| 246 |
+
MAX_PER_STEP = 6.7
|
| 247 |
+
score = round(max(0.0, min(1.0, total_reward / (MAX_PER_STEP * max(1, steps)))), 4)
|
| 248 |
+
|
| 249 |
+
all_scores[f"task_{task_id}"] = {
|
| 250 |
+
"score": score,
|
| 251 |
+
"cumulative_reward": round(total_reward, 4),
|
| 252 |
+
"steps": steps,
|
| 253 |
+
"difficulty": ["easy", "medium", "hard"][task_id - 1],
|
| 254 |
+
}
|
| 255 |
+
|
| 256 |
+
avg = round(sum(v["score"] for v in all_scores.values()) / 3, 4)
|
| 257 |
+
return {
|
| 258 |
+
"baseline_agent": "keyword-heuristic rule-based",
|
| 259 |
+
"average_score": avg,
|
| 260 |
+
"tasks": all_scores,
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
# ── Entry point ───────────────────────────────────────────────────────────────
|
| 265 |
+
|
| 266 |
+
def main(host: str = "0.0.0.0", port: int = 8000):
|
| 267 |
+
import uvicorn
|
| 268 |
+
uvicorn.run(app, host=host, port=port)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
if __name__ == "__main__":
|
| 272 |
+
main()
|
server/calls.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""911 call transcript generator for Dispatch911."""
|
| 2 |
+
|
| 3 |
+
import random
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import List, Optional
|
| 6 |
+
|
| 7 |
+
from .city import City
|
| 8 |
+
|
| 9 |
+
# ── Call templates ────────────────────────────────────────────────────────────
|
| 10 |
+
|
| 11 |
+
TEMPLATES = [
|
| 12 |
+
# ── FIRE ──────────────────────────────────────────────────────────────
|
| 13 |
+
{"type": "fire", "sev": 1, "vehicle": "fire",
|
| 14 |
+
"text": "Hi, I think I see some smoke coming from behind {landmark}. It might be nothing but thought I should call."},
|
| 15 |
+
{"type": "fire", "sev": 2, "vehicle": "fire",
|
| 16 |
+
"text": "Yeah there's a small fire in a dumpster near {landmark} on {street}. It's not spreading but it's pretty smoky."},
|
| 17 |
+
{"type": "fire", "sev": 3, "vehicle": "fire",
|
| 18 |
+
"text": "There's a fire at {address}! Flames coming out a window on the second floor. I don't think anyone's inside but I'm not sure."},
|
| 19 |
+
{"type": "fire", "sev": 4, "vehicle": "fire",
|
| 20 |
+
"text": "Oh god, the whole kitchen is on fire at {address}! My kids are upstairs — please send someone NOW!"},
|
| 21 |
+
{"type": "fire", "sev": 4, "vehicle": "fire",
|
| 22 |
+
"text": "Building's on fire on {street} near {landmark}! People are yelling from the windows, please hurry!"},
|
| 23 |
+
{"type": "fire", "sev": 5, "vehicle": "fire",
|
| 24 |
+
"text": "There's a massive fire — the whole block near {landmark} is burning. Multiple buildings involved, I can see people trapped. Send everything you've got!"},
|
| 25 |
+
# ── MEDICAL ───────────────────────────────────────────────────────────
|
| 26 |
+
{"type": "medical", "sev": 1, "vehicle": "ambulance",
|
| 27 |
+
"text": "Hello, my neighbor fell and hurt her ankle at {address}. She's conscious and talking but can't walk."},
|
| 28 |
+
{"type": "medical", "sev": 2, "vehicle": "ambulance",
|
| 29 |
+
"text": "Someone fainted at {landmark}. They're breathing okay now but look really pale. We're on {street}."},
|
| 30 |
+
{"type": "medical", "sev": 3, "vehicle": "ambulance",
|
| 31 |
+
"text": "There's a man having chest pains at {address}. He's sweating a lot and says his arm feels numb."},
|
| 32 |
+
{"type": "medical", "sev": 4, "vehicle": "ambulance",
|
| 33 |
+
"text": "My husband just collapsed and he won't wake up! He's breathing weird. We're at {address}, please hurry!"},
|
| 34 |
+
{"type": "medical", "sev": 4, "vehicle": "ambulance",
|
| 35 |
+
"text": "Someone's not breathing at {landmark}! A bystander is doing CPR. Please send an ambulance to {street} immediately!"},
|
| 36 |
+
{"type": "medical", "sev": 5, "vehicle": "ambulance",
|
| 37 |
+
"text": "There's been some kind of mass incident at {landmark} — multiple people down, some not moving. We need everything, {street} entrance."},
|
| 38 |
+
# ── CRIME ─────────────────────────────────────────────────────────────
|
| 39 |
+
{"type": "crime", "sev": 1, "vehicle": "police",
|
| 40 |
+
"text": "I'd like to report a shoplifter at {landmark} on {street}. They already left but I got a good look."},
|
| 41 |
+
{"type": "crime", "sev": 2, "vehicle": "police",
|
| 42 |
+
"text": "There's a break-in happening right now at {address}. I can see someone climbing through a window from across the street."},
|
| 43 |
+
{"type": "crime", "sev": 3, "vehicle": "police",
|
| 44 |
+
"text": "There's a fight outside {landmark} on {street}. Looks like 3-4 people involved, getting pretty violent."},
|
| 45 |
+
{"type": "crime", "sev": 3, "vehicle": "police",
|
| 46 |
+
"text": "I just got mugged near {landmark}! The guy ran towards {cross_street}. He had a knife."},
|
| 47 |
+
{"type": "crime", "sev": 4, "vehicle": "police",
|
| 48 |
+
"text": "I think I heard gunshots near {address}! People are running. I'm hiding inside {landmark}, please send help!"},
|
| 49 |
+
{"type": "crime", "sev": 5, "vehicle": "police",
|
| 50 |
+
"text": "Active shooter at {landmark}! Multiple shots fired, people running everywhere. Send everyone NOW!"},
|
| 51 |
+
# ── ACCIDENT ──────────────────────────────────────────────────────────
|
| 52 |
+
{"type": "accident", "sev": 2, "vehicle": "ambulance",
|
| 53 |
+
"text": "Fender bender on {street} near {landmark}. No injuries but the cars are blocking the road."},
|
| 54 |
+
{"type": "accident", "sev": 3, "vehicle": "ambulance",
|
| 55 |
+
"text": "Car accident at {street} and {cross_street}. One driver looks hurt, holding their neck. Other car's smoking."},
|
| 56 |
+
{"type": "accident", "sev": 3, "vehicle": "ambulance",
|
| 57 |
+
"text": "A cyclist got hit by a car near {landmark}. They're on the ground, conscious but bleeding from the head."},
|
| 58 |
+
{"type": "accident", "sev": 4, "vehicle": "ambulance",
|
| 59 |
+
"text": "Bad crash on {street}! Car flipped over near {landmark}. Driver's trapped inside, not responding!"},
|
| 60 |
+
{"type": "accident", "sev": 4, "vehicle": "ambulance",
|
| 61 |
+
"text": "Pedestrian hit by a truck at {cross_street} near {landmark}. They're not moving. There's blood everywhere."},
|
| 62 |
+
{"type": "accident", "sev": 5, "vehicle": "ambulance",
|
| 63 |
+
"text": "Multi-car pileup on {street} near {landmark}! At least 5 cars, people screaming, I can smell gas leaking. Send fire too!"},
|
| 64 |
+
]
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@dataclass
|
| 68 |
+
class Call:
|
| 69 |
+
"""A single incoming 911 call with hidden ground truth."""
|
| 70 |
+
call_id: str
|
| 71 |
+
event_id: str
|
| 72 |
+
origin_node_id: str
|
| 73 |
+
origin_node_name: str
|
| 74 |
+
emergency_type: str
|
| 75 |
+
severity: int
|
| 76 |
+
required_vehicle_type: str
|
| 77 |
+
is_duplicate_of: Optional[str]
|
| 78 |
+
transcript: str
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def generate_call(
|
| 82 |
+
city: City,
|
| 83 |
+
call_number: int,
|
| 84 |
+
active_events: dict,
|
| 85 |
+
duplicate_prob: float,
|
| 86 |
+
rng: random.Random,
|
| 87 |
+
next_event_counter: int,
|
| 88 |
+
) -> tuple:
|
| 89 |
+
"""
|
| 90 |
+
Generate one 911 call.
|
| 91 |
+
|
| 92 |
+
Returns (Call, new_event_counter).
|
| 93 |
+
"""
|
| 94 |
+
node_ids = list(city.nodes.keys())
|
| 95 |
+
|
| 96 |
+
# ── Decide if duplicate ──────────────────────────────────────────────
|
| 97 |
+
is_dup = False
|
| 98 |
+
dup_event_id = None
|
| 99 |
+
dup_event = None
|
| 100 |
+
if active_events and rng.random() < duplicate_prob:
|
| 101 |
+
dup_event_id = rng.choice(list(active_events.keys()))
|
| 102 |
+
dup_event = active_events[dup_event_id]
|
| 103 |
+
is_dup = True
|
| 104 |
+
|
| 105 |
+
if is_dup and dup_event is not None:
|
| 106 |
+
etype = dup_event["type"]
|
| 107 |
+
sev = dup_event["severity"]
|
| 108 |
+
vtype = dup_event["vehicle"]
|
| 109 |
+
origin = dup_event["node_id"]
|
| 110 |
+
event_id = dup_event_id
|
| 111 |
+
else:
|
| 112 |
+
# Pick a random template
|
| 113 |
+
tmpl = rng.choice(TEMPLATES)
|
| 114 |
+
etype = tmpl["type"]
|
| 115 |
+
sev = tmpl["sev"] + rng.choice([-1, 0, 0, 0, 1])
|
| 116 |
+
sev = max(1, min(5, sev))
|
| 117 |
+
vtype = tmpl["vehicle"]
|
| 118 |
+
# Pick origin node (prefer residential/commercial)
|
| 119 |
+
preferred = [n for n in node_ids if city.nodes[n].node_type in ("residential", "commercial")]
|
| 120 |
+
origin = rng.choice(preferred) if preferred else rng.choice(node_ids)
|
| 121 |
+
event_id = f"EVT-{next_event_counter:04d}"
|
| 122 |
+
next_event_counter += 1
|
| 123 |
+
|
| 124 |
+
# ── Build transcript ─────────────────────────────────────────────────
|
| 125 |
+
node = city.nodes[origin]
|
| 126 |
+
neighbours = list(city.edges.get(origin, {}).keys())
|
| 127 |
+
cross = city.nodes[rng.choice(neighbours)].street if neighbours else "unknown road"
|
| 128 |
+
|
| 129 |
+
# Pick a template matching the type
|
| 130 |
+
matching = [t for t in TEMPLATES if t["type"] == etype]
|
| 131 |
+
tmpl = rng.choice(matching)
|
| 132 |
+
address = f"{rng.randint(100, 999)} {node.street}"
|
| 133 |
+
|
| 134 |
+
text = tmpl["text"].format(
|
| 135 |
+
landmark=node.name,
|
| 136 |
+
street=node.street,
|
| 137 |
+
address=address,
|
| 138 |
+
cross_street=cross,
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
call = Call(
|
| 142 |
+
call_id=f"CALL-{call_number:04d}",
|
| 143 |
+
event_id=event_id,
|
| 144 |
+
origin_node_id=origin,
|
| 145 |
+
origin_node_name=node.name,
|
| 146 |
+
emergency_type=etype,
|
| 147 |
+
severity=sev,
|
| 148 |
+
required_vehicle_type=vtype,
|
| 149 |
+
is_duplicate_of=dup_event_id if is_dup else None,
|
| 150 |
+
transcript=text,
|
| 151 |
+
)
|
| 152 |
+
print(call)
|
| 153 |
+
return call, next_event_counter
|
server/city.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Procedural city graph builder for Dispatch911."""
|
| 2 |
+
|
| 3 |
+
import heapq
|
| 4 |
+
import math
|
| 5 |
+
import random
|
| 6 |
+
from dataclasses import dataclass, field
|
| 7 |
+
from typing import Dict, List, Optional, Tuple
|
| 8 |
+
|
| 9 |
+
# ── Name pools ───────────────────────────────────────────────────────────────
|
| 10 |
+
|
| 11 |
+
STREET_NAMES = [
|
| 12 |
+
"Oak", "Maple", "Cedar", "Elm", "Pine", "River", "Lake", "Hill",
|
| 13 |
+
"Park", "Main", "First", "Second", "Third", "Spring", "Sunset",
|
| 14 |
+
]
|
| 15 |
+
SUFFIXES = ["Street", "Avenue", "Road", "Drive", "Lane", "Boulevard"]
|
| 16 |
+
LANDMARKS = {
|
| 17 |
+
"hospital": ["Riverside General Hospital", "St. Mary's Medical Center",
|
| 18 |
+
"City Central Hospital"],
|
| 19 |
+
"fire_station": ["Engine House No. 1", "Central Fire Station",
|
| 20 |
+
"Westside Fire Department"],
|
| 21 |
+
"police_station": ["Central Police Station", "Metro Police HQ",
|
| 22 |
+
"Downtown Precinct"],
|
| 23 |
+
"residential": ["Oakwood Apartments", "Maple Heights", "Pinecrest Homes",
|
| 24 |
+
"Riverside Condos", "Cedar Park Village", "Elmwood Terrace",
|
| 25 |
+
"Lakeview Residences", "Hilltop Manor", "Sunset Villas",
|
| 26 |
+
"Spring Meadow Estates", "Willow Creek Homes",
|
| 27 |
+
"Birchwood Place", "Magnolia Gardens", "Aspen Ridge"],
|
| 28 |
+
"commercial": ["Downtown Mall", "Oak Avenue Shops", "Riverside Market",
|
| 29 |
+
"Central Plaza", "Parkside Shopping Center"],
|
| 30 |
+
"road_junction": ["Highway 9 Interchange", "Central Crossroads",
|
| 31 |
+
"Northside Junction", "Eastgate Roundabout",
|
| 32 |
+
"Southbound Overpass", "Westway Intersection"],
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
|
| 37 |
+
class Node:
|
| 38 |
+
node_id: str
|
| 39 |
+
node_type: str
|
| 40 |
+
name: str
|
| 41 |
+
street: str
|
| 42 |
+
x: float = 0.0
|
| 43 |
+
y: float = 0.0
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@dataclass
|
| 47 |
+
class Destination:
|
| 48 |
+
"""A queued assignment for a vehicle (used by hold actions)."""
|
| 49 |
+
node_id: str # target node to travel to
|
| 50 |
+
event_id: str # EVT-NNNN this destination serves
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@dataclass
|
| 54 |
+
class Vehicle:
|
| 55 |
+
unit_id: str
|
| 56 |
+
vehicle_type: str # police / ambulance / fire
|
| 57 |
+
home_node: str
|
| 58 |
+
current_node: str
|
| 59 |
+
status: str = "FREE" # FREE / DISPATCHED / ON_SCENE / RETURNING
|
| 60 |
+
assigned_event: Optional[str] = None
|
| 61 |
+
eta: int = 0
|
| 62 |
+
on_scene_remaining: int = 0
|
| 63 |
+
return_remaining: int = 0
|
| 64 |
+
path: List[str] = field(default_factory=list)
|
| 65 |
+
transit_progress: float = 0.0 # 0..1 along current path
|
| 66 |
+
destinations: List[Destination] = field(default_factory=list) # queued future assignments
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
@dataclass
|
| 70 |
+
class City:
|
| 71 |
+
nodes: Dict[str, Node] = field(default_factory=dict)
|
| 72 |
+
edges: Dict[str, Dict[str, float]] = field(default_factory=dict) # adj list
|
| 73 |
+
vehicles: List[Vehicle] = field(default_factory=list)
|
| 74 |
+
seed: int = 0
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def _distance(a: Node, b: Node) -> float:
|
| 78 |
+
return math.sqrt((a.x - b.x) ** 2 + (a.y - b.y) ** 2)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def _make_street(rng: random.Random) -> str:
|
| 82 |
+
return f"{rng.choice(STREET_NAMES)} {rng.choice(SUFFIXES)}"
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def generate_city(seed: int, difficulty: int = 1) -> City:
|
| 86 |
+
"""Build a random city graph, spawn vehicles, return City.
|
| 87 |
+
|
| 88 |
+
Args:
|
| 89 |
+
seed: Random seed for reproducibility.
|
| 90 |
+
difficulty: 1 = easy (plenty of vehicles), 2 = medium, 3 = hard (scarce).
|
| 91 |
+
"""
|
| 92 |
+
rng = random.Random(seed)
|
| 93 |
+
city = City(seed=seed)
|
| 94 |
+
|
| 95 |
+
# ── 1. Create nodes ──────────────────────────────────────────────────
|
| 96 |
+
node_specs: List[Tuple[str, int]] = [
|
| 97 |
+
("hospital", 1),
|
| 98 |
+
("fire_station", 1),
|
| 99 |
+
("police_station", 1),
|
| 100 |
+
("residential", rng.randint(3, 5)),
|
| 101 |
+
("commercial", rng.randint(1, 2)),
|
| 102 |
+
("road_junction", rng.randint(1, 2)),
|
| 103 |
+
]
|
| 104 |
+
idx = 0
|
| 105 |
+
for ntype, count in node_specs:
|
| 106 |
+
pool = list(LANDMARKS.get(ntype, []))
|
| 107 |
+
rng.shuffle(pool)
|
| 108 |
+
for i in range(count):
|
| 109 |
+
nid = f"{ntype}_{idx}"
|
| 110 |
+
name = pool[i] if i < len(pool) else f"{ntype.title()} {idx}"
|
| 111 |
+
node = Node(
|
| 112 |
+
node_id=nid, node_type=ntype, name=name,
|
| 113 |
+
street=_make_street(rng),
|
| 114 |
+
x=rng.uniform(0, 1), y=rng.uniform(0, 1),
|
| 115 |
+
)
|
| 116 |
+
city.nodes[nid] = node
|
| 117 |
+
city.edges[nid] = {}
|
| 118 |
+
idx += 1
|
| 119 |
+
|
| 120 |
+
# ── 2. Build edges (proximity-biased) ────────────────────────────────
|
| 121 |
+
node_ids = list(city.nodes.keys())
|
| 122 |
+
for nid in node_ids:
|
| 123 |
+
n = city.nodes[nid]
|
| 124 |
+
others = sorted(
|
| 125 |
+
[oid for oid in node_ids if oid != nid],
|
| 126 |
+
key=lambda oid: _distance(n, city.nodes[oid]),
|
| 127 |
+
)
|
| 128 |
+
k = rng.randint(2, 4)
|
| 129 |
+
neighbours = others[:k]
|
| 130 |
+
# add 0-1 long-range edges
|
| 131 |
+
for _ in range(rng.randint(0, 1)):
|
| 132 |
+
far = rng.choice(others[k:]) if len(others) > k else None
|
| 133 |
+
if far:
|
| 134 |
+
neighbours.append(far)
|
| 135 |
+
for oid in neighbours:
|
| 136 |
+
if oid not in city.edges[nid]:
|
| 137 |
+
dist = _distance(n, city.nodes[oid])
|
| 138 |
+
travel = max(1.0, dist * 15 + rng.uniform(-1, 2))
|
| 139 |
+
travel = round(travel, 1)
|
| 140 |
+
city.edges[nid][oid] = travel
|
| 141 |
+
city.edges[oid][nid] = travel
|
| 142 |
+
|
| 143 |
+
# ── 3. Ensure connectivity ───────────────────────────────────────────
|
| 144 |
+
visited = set()
|
| 145 |
+
stack = [node_ids[0]]
|
| 146 |
+
while stack:
|
| 147 |
+
cur = stack.pop()
|
| 148 |
+
if cur in visited:
|
| 149 |
+
continue
|
| 150 |
+
visited.add(cur)
|
| 151 |
+
stack.extend(city.edges[cur].keys())
|
| 152 |
+
if len(visited) < len(node_ids):
|
| 153 |
+
unvisited = [n for n in node_ids if n not in visited]
|
| 154 |
+
for uid in unvisited:
|
| 155 |
+
closest = min(visited, key=lambda v: _distance(city.nodes[uid], city.nodes[v]))
|
| 156 |
+
d = round(max(1.0, _distance(city.nodes[uid], city.nodes[closest]) * 15), 1)
|
| 157 |
+
city.edges[uid][closest] = d
|
| 158 |
+
city.edges[closest][uid] = d
|
| 159 |
+
visited.add(uid)
|
| 160 |
+
|
| 161 |
+
# ── 4. Spawn vehicles (count scales with difficulty) ──────────────────
|
| 162 |
+
# Easy (1): 3 per type — always a free unit available
|
| 163 |
+
# Medium (2): 2 per type — sometimes all busy, must use hold
|
| 164 |
+
# Hard (3): 1 per type — forces hold/reroute decisions constantly
|
| 165 |
+
if difficulty <= 1:
|
| 166 |
+
vehicle_count = 3
|
| 167 |
+
elif difficulty == 2:
|
| 168 |
+
vehicle_count = 2
|
| 169 |
+
else:
|
| 170 |
+
vehicle_count = 1
|
| 171 |
+
|
| 172 |
+
def _find_node(ntype: str) -> str:
|
| 173 |
+
for nid, n in city.nodes.items():
|
| 174 |
+
if n.node_type == ntype:
|
| 175 |
+
return nid
|
| 176 |
+
return node_ids[0]
|
| 177 |
+
|
| 178 |
+
vid = 0
|
| 179 |
+
for vtype, home_type in [
|
| 180 |
+
("police", "police_station"),
|
| 181 |
+
("ambulance", "hospital"),
|
| 182 |
+
("fire", "fire_station"),
|
| 183 |
+
]:
|
| 184 |
+
home = _find_node(home_type)
|
| 185 |
+
for _ in range(vehicle_count):
|
| 186 |
+
city.vehicles.append(Vehicle(
|
| 187 |
+
unit_id=f"{vtype}_{vid}",
|
| 188 |
+
vehicle_type=vtype,
|
| 189 |
+
home_node=home,
|
| 190 |
+
current_node=home,
|
| 191 |
+
))
|
| 192 |
+
vid += 1
|
| 193 |
+
|
| 194 |
+
return city
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def dijkstra(city: City, src: str, dst: str) -> Tuple[float, List[str]]:
|
| 198 |
+
"""Shortest path (travel time) between two nodes. Returns (time, path)."""
|
| 199 |
+
dist: Dict[str, float] = {src: 0.0}
|
| 200 |
+
prev: Dict[str, Optional[str]] = {src: None}
|
| 201 |
+
heap = [(0.0, src)]
|
| 202 |
+
while heap:
|
| 203 |
+
d, u = heapq.heappop(heap)
|
| 204 |
+
if u == dst:
|
| 205 |
+
break
|
| 206 |
+
if d > dist.get(u, float("inf")):
|
| 207 |
+
continue
|
| 208 |
+
for v, w in city.edges.get(u, {}).items():
|
| 209 |
+
nd = d + w
|
| 210 |
+
if nd < dist.get(v, float("inf")):
|
| 211 |
+
dist[v] = nd
|
| 212 |
+
prev[v] = u
|
| 213 |
+
heapq.heappush(heap, (nd, v))
|
| 214 |
+
if dst not in dist:
|
| 215 |
+
return float("inf"), []
|
| 216 |
+
path = []
|
| 217 |
+
cur: Optional[str] = dst
|
| 218 |
+
while cur is not None:
|
| 219 |
+
path.append(cur)
|
| 220 |
+
cur = prev.get(cur)
|
| 221 |
+
path.reverse()
|
| 222 |
+
return dist[dst], path
|
server/requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv[core]>=0.2.0
|
| 2 |
+
fastapi>=0.115.0
|
| 3 |
+
uvicorn>=0.24.0
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
|
server/reward.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Decomposed reward computation for Dispatch911 (5 components)."""
|
| 2 |
+
|
| 3 |
+
from typing import Dict, Optional
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
# ── Default reward config ────────────────────────────────────────────────────
|
| 7 |
+
|
| 8 |
+
SEVERITY_REWARDS = {0: 1.0, 1: 0.6, 2: 0.2, 3: -0.2, 4: -0.5}
|
| 9 |
+
PARSE_FAILURE_PENALTY = -2.0
|
| 10 |
+
MAX_TRAVEL_TIME = 15.0
|
| 11 |
+
|
| 12 |
+
# Baseline reward subtracted from each step's total so that an
|
| 13 |
+
# untrained / SFT-only agent starts near 0 and the GRPO training curve
|
| 14 |
+
# shows the expected upward trend. Calibrated to the average per-step
|
| 15 |
+
# score of a keyword-heuristic agent (~2.5).
|
| 16 |
+
STEP_REWARD_BASELINE = 2.5
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def compute_reward(
|
| 20 |
+
*,
|
| 21 |
+
# ground truth
|
| 22 |
+
gt_severity: int,
|
| 23 |
+
gt_is_duplicate: bool,
|
| 24 |
+
gt_event_id: Optional[str],
|
| 25 |
+
gt_vehicle_type: str,
|
| 26 |
+
gt_origin_node: str,
|
| 27 |
+
# agent predictions
|
| 28 |
+
severity_pred: int,
|
| 29 |
+
is_duplicate_pred: bool,
|
| 30 |
+
duplicate_of_event_id: Optional[str],
|
| 31 |
+
vehicle_type_pred: Optional[str],
|
| 32 |
+
vehicle_id_pred: Optional[str],
|
| 33 |
+
# vehicle context
|
| 34 |
+
vehicle_exists: bool = True,
|
| 35 |
+
vehicle_is_free: bool = True,
|
| 36 |
+
vehicle_type_matches: bool = True,
|
| 37 |
+
travel_time: float = 0.0,
|
| 38 |
+
is_nearest: bool = False,
|
| 39 |
+
# reroute context
|
| 40 |
+
reroute_attempted: bool = False,
|
| 41 |
+
reroute_valid: bool = False,
|
| 42 |
+
reroute_severity_delta: int = 0,
|
| 43 |
+
reroute_faster: bool = False,
|
| 44 |
+
replacement_valid: Optional[bool] = None,
|
| 45 |
+
# hold context
|
| 46 |
+
hold_is_action: bool = False,
|
| 47 |
+
hold_free_unit_exists: bool = False,
|
| 48 |
+
hold_min_busy_severity: int = 0,
|
| 49 |
+
hold_vehicle_is_soonest: bool = False,
|
| 50 |
+
) -> Dict[str, float]:
|
| 51 |
+
"""Return per-component reward breakdown + total."""
|
| 52 |
+
|
| 53 |
+
breakdown: Dict[str, float] = {}
|
| 54 |
+
|
| 55 |
+
# ── 1. Severity ──────────────────────────────────────────────────────
|
| 56 |
+
err = abs(severity_pred - gt_severity)
|
| 57 |
+
breakdown["severity"] = SEVERITY_REWARDS.get(err, -0.5)
|
| 58 |
+
|
| 59 |
+
# ── 2. Duplicate detection ───────────────────────────────────────────
|
| 60 |
+
if not is_duplicate_pred and not gt_is_duplicate:
|
| 61 |
+
breakdown["duplicate"] = 1.0
|
| 62 |
+
elif not is_duplicate_pred and gt_is_duplicate:
|
| 63 |
+
breakdown["duplicate"] = -1.0
|
| 64 |
+
elif is_duplicate_pred and not gt_is_duplicate:
|
| 65 |
+
breakdown["duplicate"] = -0.8
|
| 66 |
+
elif is_duplicate_pred and gt_is_duplicate:
|
| 67 |
+
if duplicate_of_event_id is None:
|
| 68 |
+
breakdown["duplicate"] = 0.0
|
| 69 |
+
elif duplicate_of_event_id == gt_event_id:
|
| 70 |
+
breakdown["duplicate"] = 1.5
|
| 71 |
+
else:
|
| 72 |
+
breakdown["duplicate"] = 0.3
|
| 73 |
+
|
| 74 |
+
# ── 3. Vehicle type ──────────────────────────────────────────────────
|
| 75 |
+
if is_duplicate_pred:
|
| 76 |
+
breakdown["vehicle_type"] = 0.0
|
| 77 |
+
elif vehicle_type_pred == gt_vehicle_type:
|
| 78 |
+
breakdown["vehicle_type"] = 1.5
|
| 79 |
+
else:
|
| 80 |
+
breakdown["vehicle_type"] = -1.5
|
| 81 |
+
|
| 82 |
+
# ── 4. Vehicle choice / Hold quality ─────────────────────────────────
|
| 83 |
+
if is_duplicate_pred:
|
| 84 |
+
breakdown["vehicle_choice"] = 0.0
|
| 85 |
+
elif hold_is_action:
|
| 86 |
+
# Hold-specific scoring
|
| 87 |
+
if hold_free_unit_exists:
|
| 88 |
+
# A free unit exists — holding is unjustified
|
| 89 |
+
breakdown["vehicle_choice"] = -2.0
|
| 90 |
+
elif not vehicle_exists:
|
| 91 |
+
# Hallucinated vehicle ID
|
| 92 |
+
breakdown["vehicle_choice"] = -2.0
|
| 93 |
+
elif vehicle_is_free:
|
| 94 |
+
# Named a FREE unit but chose hold instead of dispatch
|
| 95 |
+
breakdown["vehicle_choice"] = -1.5
|
| 96 |
+
else:
|
| 97 |
+
# All units of correct type are busy — evaluate severity
|
| 98 |
+
sev_delta = hold_min_busy_severity - gt_severity
|
| 99 |
+
if sev_delta > 0:
|
| 100 |
+
# All busy units have strictly higher severity — justified
|
| 101 |
+
breakdown["vehicle_choice"] = 1.0
|
| 102 |
+
elif sev_delta == 0:
|
| 103 |
+
# Some busy units have equal severity — reasonable
|
| 104 |
+
breakdown["vehicle_choice"] = 0.5
|
| 105 |
+
else:
|
| 106 |
+
# Some busy units have lower severity — should have rerouted
|
| 107 |
+
breakdown["vehicle_choice"] = -0.3 * abs(sev_delta)
|
| 108 |
+
# Bonus: picked the soonest-to-free unit
|
| 109 |
+
if hold_vehicle_is_soonest:
|
| 110 |
+
breakdown["vehicle_choice"] += 0.3
|
| 111 |
+
elif not vehicle_exists:
|
| 112 |
+
breakdown["vehicle_choice"] = -5.0
|
| 113 |
+
elif not vehicle_is_free:
|
| 114 |
+
breakdown["vehicle_choice"] = -2.0 # busy vehicle — as bad as hallucination
|
| 115 |
+
elif not vehicle_type_matches:
|
| 116 |
+
breakdown["vehicle_choice"] = -0.5
|
| 117 |
+
else:
|
| 118 |
+
prox = max(0.0, 1.0 - travel_time / MAX_TRAVEL_TIME)
|
| 119 |
+
mult = 1.0 if is_nearest else 0.5
|
| 120 |
+
breakdown["vehicle_choice"] = prox * mult
|
| 121 |
+
|
| 122 |
+
# ── 5. Reroute ───────────────────────────────────────────────────────
|
| 123 |
+
if hold_is_action:
|
| 124 |
+
breakdown["reroute"] = 0.0 # neutral for hold actions
|
| 125 |
+
elif not reroute_attempted:
|
| 126 |
+
breakdown["reroute"] = 0.0
|
| 127 |
+
elif not reroute_valid:
|
| 128 |
+
breakdown["reroute"] = -1.0
|
| 129 |
+
else:
|
| 130 |
+
r = 0.0
|
| 131 |
+
if reroute_severity_delta <= 0:
|
| 132 |
+
r = -0.5
|
| 133 |
+
elif reroute_severity_delta == 1:
|
| 134 |
+
r = 0.3
|
| 135 |
+
else:
|
| 136 |
+
r = 0.8
|
| 137 |
+
if reroute_faster:
|
| 138 |
+
r += 0.4
|
| 139 |
+
if replacement_valid is True:
|
| 140 |
+
r += 0.5
|
| 141 |
+
elif replacement_valid is False:
|
| 142 |
+
r -= 0.3
|
| 143 |
+
breakdown["reroute"] = r
|
| 144 |
+
|
| 145 |
+
raw = sum(breakdown.values())
|
| 146 |
+
breakdown["raw_total"] = raw
|
| 147 |
+
breakdown["total"] = raw - STEP_REWARD_BASELINE
|
| 148 |
+
return breakdown
|
| 149 |
+
|
server/smart_emergency_environment.py
ADDED
|
@@ -0,0 +1,559 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Dispatch911 Environment — OpenEnv-compatible Gym environment.
|
| 3 |
+
|
| 4 |
+
Handles reset/step loop, vehicle lifecycle, event registry,
|
| 5 |
+
observation formatting, and reward integration.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import random
|
| 9 |
+
from typing import Dict, List, Optional
|
| 10 |
+
from uuid import uuid4
|
| 11 |
+
|
| 12 |
+
from openenv.core.env_server.interfaces import Environment
|
| 13 |
+
from openenv.core.env_server.types import State
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
from ..models import SmartEmergencyAction, SmartEmergencyObservation
|
| 17 |
+
except ImportError:
|
| 18 |
+
from models import SmartEmergencyAction, SmartEmergencyObservation
|
| 19 |
+
|
| 20 |
+
from .city import City, Destination, Vehicle, dijkstra, generate_city
|
| 21 |
+
from .calls import Call, generate_call
|
| 22 |
+
from .reward import PARSE_FAILURE_PENALTY, compute_reward
|
| 23 |
+
|
| 24 |
+
# ── Config defaults ──────────────────────────────────────────────────────────
|
| 25 |
+
|
| 26 |
+
MAX_STEPS = 20
|
| 27 |
+
DUPLICATE_PROB = 0.30
|
| 28 |
+
ON_SCENE_STEPS = 2
|
| 29 |
+
RETURN_STEPS = 2
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class SmartEmergencyEnvironment(Environment):
|
| 33 |
+
"""
|
| 34 |
+
Dispatch911 RL environment.
|
| 35 |
+
|
| 36 |
+
Each episode = one procedurally generated city.
|
| 37 |
+
Each step = one incoming 911 call.
|
| 38 |
+
The agent outputs a structured JSON action; the environment
|
| 39 |
+
evaluates it against hidden ground truth and returns a shaped reward.
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
SUPPORTS_CONCURRENT_SESSIONS: bool = True
|
| 43 |
+
|
| 44 |
+
# Class-level tracking for /grader since create_app hides the instance
|
| 45 |
+
latest_history = []
|
| 46 |
+
latest_steps = 0
|
| 47 |
+
latest_episode_id = ""
|
| 48 |
+
|
| 49 |
+
def __init__(self):
|
| 50 |
+
self._state = State(episode_id=str(uuid4()), step_count=0)
|
| 51 |
+
self._city: Optional[City] = None
|
| 52 |
+
self._rng = random.Random()
|
| 53 |
+
self._active_events: Dict[str, dict] = {}
|
| 54 |
+
self._event_counter = 1
|
| 55 |
+
self._current_call: Optional[Call] = None
|
| 56 |
+
self._dispatcher_notes: List[str] = []
|
| 57 |
+
self._seed = 0
|
| 58 |
+
self._reward_history: List[dict] = [] # for /grader aggregation
|
| 59 |
+
|
| 60 |
+
# ── Reset ────────────────────────────────────────────────────────────
|
| 61 |
+
|
| 62 |
+
def reset(self, task_id: int = 1, seed: Optional[int] = None) -> SmartEmergencyObservation:
|
| 63 |
+
self._seed = seed if seed is not None else random.randint(0, 999999)
|
| 64 |
+
self._rng = random.Random(self._seed)
|
| 65 |
+
self._city = generate_city(self._seed, difficulty=task_id)
|
| 66 |
+
self._state = State(episode_id=str(uuid4()), step_count=0)
|
| 67 |
+
self._active_events = {}
|
| 68 |
+
self._event_counter = 1
|
| 69 |
+
self._dispatcher_notes = []
|
| 70 |
+
self._reward_history = []
|
| 71 |
+
|
| 72 |
+
# Reset class-level tracker
|
| 73 |
+
SmartEmergencyEnvironment.latest_history = []
|
| 74 |
+
SmartEmergencyEnvironment.latest_steps = 0
|
| 75 |
+
SmartEmergencyEnvironment.latest_episode_id = self._state.episode_id
|
| 76 |
+
|
| 77 |
+
self._task_id = task_id
|
| 78 |
+
if task_id == 1:
|
| 79 |
+
self._max_steps = 10
|
| 80 |
+
self._duplicate_prob = 0.10
|
| 81 |
+
elif task_id == 2:
|
| 82 |
+
self._max_steps = 15
|
| 83 |
+
self._duplicate_prob = 0.30
|
| 84 |
+
else:
|
| 85 |
+
self._max_steps = 20
|
| 86 |
+
self._duplicate_prob = 0.50
|
| 87 |
+
|
| 88 |
+
# Generate first call
|
| 89 |
+
self._current_call, self._event_counter = generate_call(
|
| 90 |
+
self._city, 1, self._active_events,
|
| 91 |
+
self._duplicate_prob, self._rng, self._event_counter,
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
obs_text = self._build_observation()
|
| 95 |
+
return SmartEmergencyObservation(
|
| 96 |
+
prompt=obs_text,
|
| 97 |
+
step=0,
|
| 98 |
+
call_id=self._current_call.call_id,
|
| 99 |
+
reward_breakdown={},
|
| 100 |
+
active_event_ids=list(self._active_events.keys()),
|
| 101 |
+
fleet_utilisation=self._fleet_util(),
|
| 102 |
+
done=False,
|
| 103 |
+
reward=0.0,
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
# ── Step ─────────────────────────────────────────────────────────────
|
| 107 |
+
|
| 108 |
+
def step(self, action: SmartEmergencyAction) -> SmartEmergencyObservation:
|
| 109 |
+
# Auto-reset if step is called before reset
|
| 110 |
+
if self._current_call is None or self._city is None:
|
| 111 |
+
self.reset()
|
| 112 |
+
|
| 113 |
+
self._state.step_count += 1
|
| 114 |
+
call = self._current_call
|
| 115 |
+
city = self._city
|
| 116 |
+
assert call is not None and city is not None
|
| 117 |
+
|
| 118 |
+
# ── Evaluate action ──────────────────────────────────────────────
|
| 119 |
+
reward_kwargs = self._evaluate_action(action, call)
|
| 120 |
+
breakdown = compute_reward(**reward_kwargs)
|
| 121 |
+
self._reward_history.append(breakdown)
|
| 122 |
+
|
| 123 |
+
# Update class-level tracker for grader
|
| 124 |
+
SmartEmergencyEnvironment.latest_history.append(breakdown)
|
| 125 |
+
SmartEmergencyEnvironment.latest_steps = self._state.step_count
|
| 126 |
+
|
| 127 |
+
# ── Update state ──────────────���──────────────────────────────────
|
| 128 |
+
self._apply_action(action, call)
|
| 129 |
+
|
| 130 |
+
# ── Advance simulation clock ─────────────────────────────────────
|
| 131 |
+
self._tick_vehicles()
|
| 132 |
+
|
| 133 |
+
# ── Log dispatcher note ──────────────────────────────────────────
|
| 134 |
+
note = f"Step {self._state.step_count}: {call.call_id}"
|
| 135 |
+
if action.is_duplicate:
|
| 136 |
+
note += f" → Duplicate of {action.duplicate_of_event_id or '?'}"
|
| 137 |
+
elif action.action_type == "hold":
|
| 138 |
+
note += f" → HOLD ({action.vehicle_type}, waiting for {action.vehicle_id or '?'})"
|
| 139 |
+
elif action.action_type == "dispatch":
|
| 140 |
+
note += f" → {action.vehicle_type} {action.vehicle_id or '?'}"
|
| 141 |
+
self._dispatcher_notes.append(note)
|
| 142 |
+
if len(self._dispatcher_notes) > 3:
|
| 143 |
+
self._dispatcher_notes = self._dispatcher_notes[-3:]
|
| 144 |
+
|
| 145 |
+
# ── Check done ───────────────────────────────────────────────────
|
| 146 |
+
done = self._state.step_count >= getattr(self, "_max_steps", MAX_STEPS)
|
| 147 |
+
|
| 148 |
+
# ── Generate next call ───────────────────────────────────────────
|
| 149 |
+
if not done:
|
| 150 |
+
self._current_call, self._event_counter = generate_call(
|
| 151 |
+
city, self._state.step_count + 1,
|
| 152 |
+
self._active_events, getattr(self, "_duplicate_prob", DUPLICATE_PROB),
|
| 153 |
+
self._rng, self._event_counter,
|
| 154 |
+
)
|
| 155 |
+
obs_text = self._build_observation() if not done else "Episode complete."
|
| 156 |
+
|
| 157 |
+
gt = {
|
| 158 |
+
"severity": call.severity,
|
| 159 |
+
"emergency_type": call.emergency_type,
|
| 160 |
+
"is_duplicate": call.is_duplicate_of is not None,
|
| 161 |
+
"required_vehicle_type": call.required_vehicle_type,
|
| 162 |
+
}
|
| 163 |
+
return SmartEmergencyObservation(
|
| 164 |
+
prompt=obs_text,
|
| 165 |
+
step=self._state.step_count,
|
| 166 |
+
call_id=call.call_id,
|
| 167 |
+
reward_breakdown=breakdown,
|
| 168 |
+
active_event_ids=list(self._active_events.keys()),
|
| 169 |
+
fleet_utilisation=self._fleet_util(),
|
| 170 |
+
done=done,
|
| 171 |
+
reward=breakdown.get("total", 0.0),
|
| 172 |
+
ground_truth=gt,
|
| 173 |
+
metadata={
|
| 174 |
+
"ground_truth": gt,
|
| 175 |
+
"city_seed": self._seed,
|
| 176 |
+
},
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
# ── Evaluate ─────────────────────────────────────────────────────────
|
| 180 |
+
|
| 181 |
+
def _evaluate_action(self, action: SmartEmergencyAction, call: Call) -> dict:
|
| 182 |
+
"""Build kwargs for compute_reward."""
|
| 183 |
+
city = self._city
|
| 184 |
+
assert city is not None
|
| 185 |
+
|
| 186 |
+
gt_is_dup = call.is_duplicate_of is not None
|
| 187 |
+
gt_eid = call.is_duplicate_of
|
| 188 |
+
|
| 189 |
+
# Vehicle checks
|
| 190 |
+
v_exists = True
|
| 191 |
+
v_free = True
|
| 192 |
+
v_type_match = True
|
| 193 |
+
travel = 0.0
|
| 194 |
+
is_nearest = False
|
| 195 |
+
|
| 196 |
+
# Hold checks
|
| 197 |
+
hold_is_action = action.action_type == "hold"
|
| 198 |
+
hold_free_exists = False
|
| 199 |
+
hold_min_busy_sev = 0
|
| 200 |
+
hold_vehicle_soonest = False
|
| 201 |
+
|
| 202 |
+
if hold_is_action:
|
| 203 |
+
# Check if the named vehicle exists and its state
|
| 204 |
+
if action.vehicle_id:
|
| 205 |
+
veh = self._find_vehicle(action.vehicle_id)
|
| 206 |
+
if veh is None:
|
| 207 |
+
v_exists = False
|
| 208 |
+
else:
|
| 209 |
+
v_free = veh.status == "FREE"
|
| 210 |
+
v_type_match = veh.vehicle_type == (action.vehicle_type or "")
|
| 211 |
+
else:
|
| 212 |
+
v_exists = False
|
| 213 |
+
|
| 214 |
+
# Check if any free unit of the correct type exists
|
| 215 |
+
vtype = action.vehicle_type or call.required_vehicle_type
|
| 216 |
+
free_of_type = [
|
| 217 |
+
v for v in city.vehicles
|
| 218 |
+
if v.status == "FREE" and v.vehicle_type == vtype
|
| 219 |
+
]
|
| 220 |
+
hold_free_exists = len(free_of_type) > 0
|
| 221 |
+
|
| 222 |
+
# Find min severity among busy units of this type
|
| 223 |
+
busy_of_type = [
|
| 224 |
+
v for v in city.vehicles
|
| 225 |
+
if v.status != "FREE" and v.vehicle_type == vtype
|
| 226 |
+
and v.assigned_event is not None
|
| 227 |
+
]
|
| 228 |
+
if busy_of_type:
|
| 229 |
+
busy_sevs = []
|
| 230 |
+
for bv in busy_of_type:
|
| 231 |
+
evt = self._active_events.get(bv.assigned_event, {})
|
| 232 |
+
busy_sevs.append(evt.get("severity", 5))
|
| 233 |
+
hold_min_busy_sev = min(busy_sevs)
|
| 234 |
+
|
| 235 |
+
# Check if named vehicle is the soonest to free
|
| 236 |
+
if v_exists and not v_free and action.vehicle_id:
|
| 237 |
+
veh = self._find_vehicle(action.vehicle_id)
|
| 238 |
+
if veh and veh.eta is not None:
|
| 239 |
+
min_eta = min(
|
| 240 |
+
(bv.eta for bv in busy_of_type if bv.eta is not None),
|
| 241 |
+
default=999,
|
| 242 |
+
)
|
| 243 |
+
hold_vehicle_soonest = veh.eta <= min_eta
|
| 244 |
+
|
| 245 |
+
elif not action.is_duplicate and action.vehicle_id:
|
| 246 |
+
veh = self._find_vehicle(action.vehicle_id)
|
| 247 |
+
if veh is None:
|
| 248 |
+
v_exists = False
|
| 249 |
+
else:
|
| 250 |
+
v_free = veh.status == "FREE"
|
| 251 |
+
v_type_match = veh.vehicle_type == action.vehicle_type
|
| 252 |
+
if v_exists and v_free:
|
| 253 |
+
travel, _ = dijkstra(city, veh.current_node, call.origin_node_id)
|
| 254 |
+
# Check if nearest
|
| 255 |
+
free_same = [
|
| 256 |
+
v for v in city.vehicles
|
| 257 |
+
if v.status == "FREE" and v.vehicle_type == call.required_vehicle_type
|
| 258 |
+
]
|
| 259 |
+
if free_same:
|
| 260 |
+
min_t = min(dijkstra(city, v.current_node, call.origin_node_id)[0] for v in free_same)
|
| 261 |
+
is_nearest = abs(travel - min_t) < 0.1
|
| 262 |
+
|
| 263 |
+
# Reroute checks
|
| 264 |
+
reroute_attempted = action.reroute is not None and not hold_is_action
|
| 265 |
+
reroute_valid = False
|
| 266 |
+
reroute_sev_delta = 0
|
| 267 |
+
reroute_faster = False
|
| 268 |
+
replacement_valid = None
|
| 269 |
+
|
| 270 |
+
if reroute_attempted and action.reroute is not None:
|
| 271 |
+
rv = self._find_vehicle(action.reroute.vehicle_to_reroute)
|
| 272 |
+
if rv and rv.status == "DISPATCHED" and rv.assigned_event == action.reroute.from_event_id:
|
| 273 |
+
reroute_valid = True
|
| 274 |
+
old_evt = self._active_events.get(action.reroute.from_event_id, {})
|
| 275 |
+
reroute_sev_delta = call.severity - old_evt.get("severity", call.severity)
|
| 276 |
+
if action.reroute.replacement_vehicle_id:
|
| 277 |
+
rep = self._find_vehicle(action.reroute.replacement_vehicle_id)
|
| 278 |
+
replacement_valid = (
|
| 279 |
+
rep is not None and rep.status == "FREE"
|
| 280 |
+
and rep.vehicle_type == old_evt.get("vehicle", "")
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
return dict(
|
| 284 |
+
gt_severity=call.severity,
|
| 285 |
+
gt_is_duplicate=gt_is_dup,
|
| 286 |
+
gt_event_id=gt_eid,
|
| 287 |
+
gt_vehicle_type=call.required_vehicle_type,
|
| 288 |
+
gt_origin_node=call.origin_node_id,
|
| 289 |
+
severity_pred=action.severity_pred,
|
| 290 |
+
is_duplicate_pred=action.is_duplicate,
|
| 291 |
+
duplicate_of_event_id=action.duplicate_of_event_id,
|
| 292 |
+
vehicle_type_pred=action.vehicle_type,
|
| 293 |
+
vehicle_id_pred=action.vehicle_id,
|
| 294 |
+
vehicle_exists=v_exists,
|
| 295 |
+
vehicle_is_free=v_free,
|
| 296 |
+
vehicle_type_matches=v_type_match,
|
| 297 |
+
travel_time=travel,
|
| 298 |
+
is_nearest=is_nearest,
|
| 299 |
+
reroute_attempted=reroute_attempted,
|
| 300 |
+
reroute_valid=reroute_valid,
|
| 301 |
+
reroute_severity_delta=reroute_sev_delta,
|
| 302 |
+
reroute_faster=reroute_faster,
|
| 303 |
+
replacement_valid=replacement_valid,
|
| 304 |
+
hold_is_action=hold_is_action,
|
| 305 |
+
hold_free_unit_exists=hold_free_exists,
|
| 306 |
+
hold_min_busy_severity=hold_min_busy_sev,
|
| 307 |
+
hold_vehicle_is_soonest=hold_vehicle_soonest,
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
# ── Apply action to state ────────────────────────────────────────────
|
| 311 |
+
|
| 312 |
+
def _apply_action(self, action: SmartEmergencyAction, call: Call):
|
| 313 |
+
city = self._city
|
| 314 |
+
assert city is not None
|
| 315 |
+
|
| 316 |
+
if action.is_duplicate:
|
| 317 |
+
# Link call to existing event
|
| 318 |
+
eid = action.duplicate_of_event_id or call.event_id
|
| 319 |
+
if eid in self._active_events:
|
| 320 |
+
self._active_events[eid].setdefault("calls", []).append(call.call_id)
|
| 321 |
+
return
|
| 322 |
+
|
| 323 |
+
# Register new event (only if not already active)
|
| 324 |
+
eid = call.event_id
|
| 325 |
+
if eid not in self._active_events:
|
| 326 |
+
self._active_events[eid] = {
|
| 327 |
+
"type": call.emergency_type,
|
| 328 |
+
"severity": call.severity,
|
| 329 |
+
"vehicle": call.required_vehicle_type,
|
| 330 |
+
"node_id": call.origin_node_id,
|
| 331 |
+
"node_name": call.origin_node_name,
|
| 332 |
+
"assigned_unit": None,
|
| 333 |
+
"unit_eta": None,
|
| 334 |
+
"held_for_unit": None,
|
| 335 |
+
"step_opened": self._state.step_count,
|
| 336 |
+
"calls": [call.call_id],
|
| 337 |
+
}
|
| 338 |
+
else:
|
| 339 |
+
# Event already exists — just link this call
|
| 340 |
+
self._active_events[eid].setdefault("calls", []).append(call.call_id)
|
| 341 |
+
|
| 342 |
+
# ── Hold action ──────────────────────────────────────────────────
|
| 343 |
+
if action.action_type == "hold" and action.vehicle_id:
|
| 344 |
+
veh = self._find_vehicle(action.vehicle_id)
|
| 345 |
+
if veh is not None and veh.status != "FREE":
|
| 346 |
+
# Queue this event as a future destination for the vehicle
|
| 347 |
+
veh.destinations.append(
|
| 348 |
+
Destination(node_id=call.origin_node_id, event_id=eid)
|
| 349 |
+
)
|
| 350 |
+
self._active_events[eid]["held_for_unit"] = action.vehicle_id
|
| 351 |
+
return
|
| 352 |
+
|
| 353 |
+
# Handle reroute
|
| 354 |
+
if action.reroute is not None:
|
| 355 |
+
rv = self._find_vehicle(action.reroute.vehicle_to_reroute)
|
| 356 |
+
if rv and rv.status == "DISPATCHED":
|
| 357 |
+
# Unassign from old event
|
| 358 |
+
old_eid = action.reroute.from_event_id
|
| 359 |
+
if old_eid in self._active_events:
|
| 360 |
+
self._active_events[old_eid]["assigned_unit"] = None
|
| 361 |
+
self._active_events[old_eid]["unit_eta"] = None
|
| 362 |
+
# Dispatch rerouted vehicle to new event
|
| 363 |
+
travel, path = dijkstra(city, rv.current_node, call.origin_node_id)
|
| 364 |
+
rv.status = "DISPATCHED"
|
| 365 |
+
rv.assigned_event = eid
|
| 366 |
+
rv.eta = max(1, int(travel))
|
| 367 |
+
rv.path = path
|
| 368 |
+
self._active_events[eid]["assigned_unit"] = rv.unit_id
|
| 369 |
+
self._active_events[eid]["unit_eta"] = rv.eta
|
| 370 |
+
# Handle replacement
|
| 371 |
+
if action.reroute.replacement_vehicle_id:
|
| 372 |
+
rep = self._find_vehicle(action.reroute.replacement_vehicle_id)
|
| 373 |
+
if rep and rep.status == "FREE" and old_eid in self._active_events:
|
| 374 |
+
old_node = self._active_events[old_eid]["node_id"]
|
| 375 |
+
t, p = dijkstra(city, rep.current_node, old_node)
|
| 376 |
+
rep.status = "DISPATCHED"
|
| 377 |
+
rep.assigned_event = old_eid
|
| 378 |
+
rep.eta = max(1, int(t))
|
| 379 |
+
rep.path = p
|
| 380 |
+
self._active_events[old_eid]["assigned_unit"] = rep.unit_id
|
| 381 |
+
self._active_events[old_eid]["unit_eta"] = rep.eta
|
| 382 |
+
return
|
| 383 |
+
|
| 384 |
+
# Normal dispatch
|
| 385 |
+
if action.vehicle_id:
|
| 386 |
+
veh = self._find_vehicle(action.vehicle_id)
|
| 387 |
+
if veh is None:
|
| 388 |
+
# Hallucinated vehicle — event stays UNASSIGNED
|
| 389 |
+
return
|
| 390 |
+
if veh.status != "FREE":
|
| 391 |
+
# Vehicle is busy — auto-convert to hold.
|
| 392 |
+
# Penalty still applied in reward, but the event gets queued.
|
| 393 |
+
veh.destinations.append(
|
| 394 |
+
Destination(node_id=call.origin_node_id, event_id=eid)
|
| 395 |
+
)
|
| 396 |
+
self._active_events[eid]["held_for_unit"] = veh.unit_id
|
| 397 |
+
return
|
| 398 |
+
# Vehicle is free — dispatch it
|
| 399 |
+
travel, path = dijkstra(city, veh.current_node, call.origin_node_id)
|
| 400 |
+
veh.status = "DISPATCHED"
|
| 401 |
+
veh.assigned_event = eid
|
| 402 |
+
veh.eta = max(1, int(travel))
|
| 403 |
+
veh.path = path
|
| 404 |
+
self._active_events[eid]["assigned_unit"] = veh.unit_id
|
| 405 |
+
self._active_events[eid]["unit_eta"] = veh.eta
|
| 406 |
+
|
| 407 |
+
# ── Vehicle tick ─────────────────────────────────────────────────────
|
| 408 |
+
|
| 409 |
+
def _tick_vehicles(self):
|
| 410 |
+
city = self._city
|
| 411 |
+
assert city is not None
|
| 412 |
+
resolved = []
|
| 413 |
+
|
| 414 |
+
for v in city.vehicles:
|
| 415 |
+
if v.status == "DISPATCHED":
|
| 416 |
+
v.eta -= 1
|
| 417 |
+
if v.eta <= 0:
|
| 418 |
+
v.status = "ON_SCENE"
|
| 419 |
+
v.on_scene_remaining = ON_SCENE_STEPS
|
| 420 |
+
if v.path:
|
| 421 |
+
v.current_node = v.path[-1]
|
| 422 |
+
elif v.status == "ON_SCENE":
|
| 423 |
+
v.on_scene_remaining -= 1
|
| 424 |
+
if v.on_scene_remaining <= 0:
|
| 425 |
+
v.status = "RETURNING"
|
| 426 |
+
v.return_remaining = RETURN_STEPS
|
| 427 |
+
# Mark event resolved
|
| 428 |
+
if v.assigned_event and v.assigned_event in self._active_events:
|
| 429 |
+
resolved.append(v.assigned_event)
|
| 430 |
+
elif v.status == "RETURNING":
|
| 431 |
+
v.return_remaining -= 1
|
| 432 |
+
if v.return_remaining <= 0:
|
| 433 |
+
v.status = "FREE"
|
| 434 |
+
v.current_node = v.home_node
|
| 435 |
+
v.assigned_event = None
|
| 436 |
+
# Auto-dispatch to next queued destination (from hold)
|
| 437 |
+
self._dispatch_next_destination(v)
|
| 438 |
+
|
| 439 |
+
for eid in resolved:
|
| 440 |
+
self._active_events.pop(eid, None)
|
| 441 |
+
|
| 442 |
+
# Clean up stale unassigned events (no unit, no hold, open > 3 steps)
|
| 443 |
+
stale = []
|
| 444 |
+
for eid, evt in self._active_events.items():
|
| 445 |
+
if (evt.get("assigned_unit") is None
|
| 446 |
+
and evt.get("held_for_unit") is None
|
| 447 |
+
and self._state.step_count - evt.get("step_opened", 0) > 3):
|
| 448 |
+
stale.append(eid)
|
| 449 |
+
for eid in stale:
|
| 450 |
+
self._active_events.pop(eid, None)
|
| 451 |
+
|
| 452 |
+
def _dispatch_next_destination(self, v: Vehicle):
|
| 453 |
+
"""If the vehicle has queued destinations, pop the first and dispatch."""
|
| 454 |
+
city = self._city
|
| 455 |
+
assert city is not None
|
| 456 |
+
|
| 457 |
+
while v.destinations:
|
| 458 |
+
dest = v.destinations.pop(0)
|
| 459 |
+
# Only dispatch if the event is still active and unassigned
|
| 460 |
+
evt = self._active_events.get(dest.event_id)
|
| 461 |
+
if evt is not None and evt.get("assigned_unit") is None:
|
| 462 |
+
travel, path = dijkstra(city, v.current_node, dest.node_id)
|
| 463 |
+
v.status = "DISPATCHED"
|
| 464 |
+
v.assigned_event = dest.event_id
|
| 465 |
+
v.eta = max(1, int(travel))
|
| 466 |
+
v.path = path
|
| 467 |
+
evt["assigned_unit"] = v.unit_id
|
| 468 |
+
evt["unit_eta"] = v.eta
|
| 469 |
+
return
|
| 470 |
+
# No valid destinations left — vehicle stays FREE
|
| 471 |
+
|
| 472 |
+
# ── Observation builder ──────────────────────────────────────────────
|
| 473 |
+
|
| 474 |
+
def _build_observation(self) -> str:
|
| 475 |
+
call = self._current_call
|
| 476 |
+
city = self._city
|
| 477 |
+
if call is None or city is None:
|
| 478 |
+
return ""
|
| 479 |
+
|
| 480 |
+
parts = []
|
| 481 |
+
|
| 482 |
+
# 1. Incoming call
|
| 483 |
+
parts.append(f"=== INCOMING CALL [{call.call_id}] ===")
|
| 484 |
+
parts.append(call.transcript)
|
| 485 |
+
parts.append("")
|
| 486 |
+
|
| 487 |
+
# 2. Active events
|
| 488 |
+
parts.append("=== ACTIVE EVENTS ===")
|
| 489 |
+
if self._active_events:
|
| 490 |
+
for eid, evt in self._active_events.items():
|
| 491 |
+
unit = evt.get("assigned_unit")
|
| 492 |
+
held = evt.get("held_for_unit")
|
| 493 |
+
eta = evt.get("unit_eta")
|
| 494 |
+
if unit:
|
| 495 |
+
eta_str = f"ETA {eta} min" if eta else "ON SCENE"
|
| 496 |
+
status_str = f"{unit} {eta_str}"
|
| 497 |
+
elif held:
|
| 498 |
+
status_str = f"HELD → {held}"
|
| 499 |
+
else:
|
| 500 |
+
status_str = "UNASSIGNED"
|
| 501 |
+
sev = evt.get("severity", "?")
|
| 502 |
+
parts.append(
|
| 503 |
+
f"{eid} | {evt['type']:10s} | {evt['node_name']:30s} | "
|
| 504 |
+
f"sev {sev} | {status_str} | opened step {evt['step_opened']}"
|
| 505 |
+
)
|
| 506 |
+
else:
|
| 507 |
+
parts.append("(none)")
|
| 508 |
+
parts.append("")
|
| 509 |
+
|
| 510 |
+
# 3. Unit status
|
| 511 |
+
parts.append("=== UNIT STATUS ===")
|
| 512 |
+
for v in city.vehicles:
|
| 513 |
+
loc = city.nodes[v.current_node].name if v.current_node in city.nodes else v.current_node
|
| 514 |
+
status = v.status
|
| 515 |
+
if v.assigned_event:
|
| 516 |
+
status += f" → {v.assigned_event}"
|
| 517 |
+
parts.append(f"{v.unit_id:15s} | {v.vehicle_type:10s} | {loc:30s} | {status}")
|
| 518 |
+
parts.append("")
|
| 519 |
+
|
| 520 |
+
# 4. City reference (compact adjacency)
|
| 521 |
+
parts.append("=== CITY REFERENCE ===")
|
| 522 |
+
for nid, node in city.nodes.items():
|
| 523 |
+
neighbours = []
|
| 524 |
+
for oid, w in city.edges.get(nid, {}).items():
|
| 525 |
+
oname = city.nodes[oid].name
|
| 526 |
+
neighbours.append(f"{oname} [{w:.0f} min]")
|
| 527 |
+
parts.append(f"{node.name} ({node.node_type}) → {', '.join(neighbours)}")
|
| 528 |
+
parts.append("")
|
| 529 |
+
|
| 530 |
+
# 5. Dispatcher notes
|
| 531 |
+
parts.append("=== DISPATCHER NOTES ===")
|
| 532 |
+
if self._dispatcher_notes:
|
| 533 |
+
for n in self._dispatcher_notes:
|
| 534 |
+
parts.append(n)
|
| 535 |
+
else:
|
| 536 |
+
parts.append("(first call)")
|
| 537 |
+
parts.append("")
|
| 538 |
+
|
| 539 |
+
return "\n".join(parts)
|
| 540 |
+
|
| 541 |
+
# ── Helpers ──────────────────────────────────────────────────────────
|
| 542 |
+
|
| 543 |
+
def _find_vehicle(self, unit_id: str) -> Optional[Vehicle]:
|
| 544 |
+
if self._city is None:
|
| 545 |
+
return None
|
| 546 |
+
for v in self._city.vehicles:
|
| 547 |
+
if v.unit_id == unit_id:
|
| 548 |
+
return v
|
| 549 |
+
return None
|
| 550 |
+
|
| 551 |
+
def _fleet_util(self) -> float:
|
| 552 |
+
if self._city is None or not self._city.vehicles:
|
| 553 |
+
return 0.0
|
| 554 |
+
busy = sum(1 for v in self._city.vehicles if v.status != "FREE")
|
| 555 |
+
return busy / len(self._city.vehicles)
|
| 556 |
+
|
| 557 |
+
@property
|
| 558 |
+
def state(self) -> State:
|
| 559 |
+
return self._state
|
train_sft_grpo.py
ADDED
|
@@ -0,0 +1,661 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %% [markdown]
|
| 2 |
+
# # 🚨 Smart Emergency Dispatch — SFT → GRPO Training (Colab + Unsloth)
|
| 3 |
+
#
|
| 4 |
+
# Fine-tunes **Qwen3-1.7B** as an emergency 911 dispatcher using **Unsloth** for 2× faster training:
|
| 5 |
+
# 1. **Phase 1 — SFT**: Teach the model the JSON output format
|
| 6 |
+
# 2. **Phase 2 — GRPO**: Improve dispatch strategy via RL against the live HF Space environment
|
| 7 |
+
#
|
| 8 |
+
# **Runtime**: Google Colab with T4 or A100 GPU
|
| 9 |
+
|
| 10 |
+
# %% [markdown]
|
| 11 |
+
# ## 0 · Install Dependencies
|
| 12 |
+
|
| 13 |
+
# %%
|
| 14 |
+
!pip install -Uq unsloth vllm
|
| 15 |
+
!pip install -Uq git+https://github.com/huggingface/trl.git
|
| 16 |
+
!pip install -Uq git+https://github.com/meta-pytorch/OpenEnv.git
|
| 17 |
+
!pip install -Uq git+https://github.com/rishiraj38/Smart_Emergency.git datasets requests
|
| 18 |
+
|
| 19 |
+
# %%
|
| 20 |
+
from huggingface_hub import notebook_login
|
| 21 |
+
notebook_login()
|
| 22 |
+
|
| 23 |
+
# %% [markdown]
|
| 24 |
+
# ## 1 · Configuration
|
| 25 |
+
|
| 26 |
+
# %%
|
| 27 |
+
import os, json, re, random, requests, time
|
| 28 |
+
from collections import defaultdict
|
| 29 |
+
|
| 30 |
+
MODEL_NAME = "unsloth/Qwen3-1.7B-unsloth-bnb-4bit"
|
| 31 |
+
SFT_OUTPUT_DIR = "smart-emergency-sft"
|
| 32 |
+
GRPO_OUTPUT_DIR = "smart-emergency-grpo"
|
| 33 |
+
MAX_SEQ_LENGTH = 3072
|
| 34 |
+
|
| 35 |
+
# HuggingFace Space URL for the environment server
|
| 36 |
+
HF_SPACE_URL = "https://rishi38-eme-enviro.hf.space"
|
| 37 |
+
|
| 38 |
+
# %% [markdown]
|
| 39 |
+
# ## 2 · Connect to Environment
|
| 40 |
+
#
|
| 41 |
+
# Wake the HF Space if sleeping, then connect directly using `SmartEmergencyEnv`.
|
| 42 |
+
|
| 43 |
+
# %%
|
| 44 |
+
import requests, time
|
| 45 |
+
from smart_emergency import SmartEmergencyEnv, SmartEmergencyAction
|
| 46 |
+
|
| 47 |
+
# Ping the Space health endpoint until it wakes up (free Spaces sleep after inactivity)
|
| 48 |
+
print("⏳ Waking up HF Space (may take 30-60s if sleeping) …")
|
| 49 |
+
for _attempt in range(60):
|
| 50 |
+
try:
|
| 51 |
+
r = requests.get(f"{HF_SPACE_URL}/health", timeout=5)
|
| 52 |
+
if r.status_code == 200:
|
| 53 |
+
print(f"✅ Space awake at {HF_SPACE_URL}")
|
| 54 |
+
break
|
| 55 |
+
except Exception:
|
| 56 |
+
pass
|
| 57 |
+
time.sleep(2)
|
| 58 |
+
else:
|
| 59 |
+
raise RuntimeError("HF Space did not respond after 2 minutes. Check the URL.")
|
| 60 |
+
|
| 61 |
+
# Direct WebSocket connection via the official client
|
| 62 |
+
env = SmartEmergencyEnv(base_url=HF_SPACE_URL).sync()
|
| 63 |
+
_test = env.reset()
|
| 64 |
+
print(f"✅ Connected — first call: {_test.observation.call_id}")
|
| 65 |
+
|
| 66 |
+
# %% [markdown]
|
| 67 |
+
# ## 3 · System Prompt
|
| 68 |
+
|
| 69 |
+
# %%
|
| 70 |
+
SYSTEM_PROMPT = """\
|
| 71 |
+
You are an expert 911 emergency dispatcher. You receive incoming calls and must make rapid, structured dispatch decisions.
|
| 72 |
+
|
| 73 |
+
## RULES
|
| 74 |
+
1. Each step you see: an incoming call transcript, active events, unit status, and a city map.
|
| 75 |
+
2. You must respond with a single JSON object — nothing else.
|
| 76 |
+
|
| 77 |
+
## ACTION TYPES
|
| 78 |
+
You have three action types: `dispatch`, `duplicate`, and `hold`.
|
| 79 |
+
|
| 80 |
+
### 1. dispatch — Handle a new emergency
|
| 81 |
+
Use when a FREE vehicle of the correct type is available.
|
| 82 |
+
```json
|
| 83 |
+
{
|
| 84 |
+
"action_type": "dispatch",
|
| 85 |
+
"severity_pred": <int 1-5>,
|
| 86 |
+
"is_duplicate": false,
|
| 87 |
+
"duplicate_of_event_id": null,
|
| 88 |
+
"vehicle_type": "police" | "ambulance" | "fire",
|
| 89 |
+
"vehicle_id": "<unit_id of a FREE vehicle>",
|
| 90 |
+
"reroute": null
|
| 91 |
+
}
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
### 2. duplicate — Flag a repeat call
|
| 95 |
+
Use when the incoming call matches an existing active event (same location/type).
|
| 96 |
+
```json
|
| 97 |
+
{
|
| 98 |
+
"action_type": "duplicate",
|
| 99 |
+
"severity_pred": <int 1-5>,
|
| 100 |
+
"is_duplicate": true,
|
| 101 |
+
"duplicate_of_event_id": "<EVT-NNNN>",
|
| 102 |
+
"vehicle_type": null,
|
| 103 |
+
"vehicle_id": null,
|
| 104 |
+
"reroute": null
|
| 105 |
+
}
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
### 3. hold — Queue for a busy vehicle
|
| 109 |
+
Use ONLY when ALL vehicles of the required type are busy (none are FREE).
|
| 110 |
+
```json
|
| 111 |
+
{
|
| 112 |
+
"action_type": "hold",
|
| 113 |
+
"severity_pred": <int 1-5>,
|
| 114 |
+
"is_duplicate": false,
|
| 115 |
+
"duplicate_of_event_id": null,
|
| 116 |
+
"vehicle_type": "police" | "ambulance" | "fire",
|
| 117 |
+
"vehicle_id": "<unit_id of a BUSY vehicle to queue behind>",
|
| 118 |
+
"reroute": null
|
| 119 |
+
}
|
| 120 |
+
```
|
| 121 |
+
**Hold rules:** NEVER hold if a free unit exists. Pick the vehicle with the lowest ETA.
|
| 122 |
+
|
| 123 |
+
## REROUTE (optional, only with dispatch)
|
| 124 |
+
Redirect an in-flight vehicle from a LOWER-severity event to this HIGHER-severity one:
|
| 125 |
+
```json
|
| 126 |
+
"reroute": {
|
| 127 |
+
"vehicle_to_reroute": "<DISPATCHED unit_id>",
|
| 128 |
+
"from_event_id": "<EVT-NNNN>",
|
| 129 |
+
"replacement_vehicle_id": "<FREE unit or null>"
|
| 130 |
+
}
|
| 131 |
+
```
|
| 132 |
+
Only reroute DISPATCHED vehicles. Only reroute from lower to higher severity.
|
| 133 |
+
|
| 134 |
+
## SEVERITY GUIDE
|
| 135 |
+
1=minor, 2=moderate, 3=serious, 4=critical, 5=catastrophic
|
| 136 |
+
|
| 137 |
+
## VEHICLE GUIDE
|
| 138 |
+
- **fire** → fire, smoke, flames, gas leak
|
| 139 |
+
- **police** → shooting, robbery, fight, break-in
|
| 140 |
+
- **ambulance** → medical, crash, accident, injury, collapse
|
| 141 |
+
|
| 142 |
+
## STRATEGY
|
| 143 |
+
- Pick the nearest FREE vehicle (use CITY REFERENCE distances).
|
| 144 |
+
- If call matches an ACTIVE EVENT, flag as duplicate.
|
| 145 |
+
- No free units → use `hold`. Higher severity than busy units → consider `reroute`.
|
| 146 |
+
"""
|
| 147 |
+
|
| 148 |
+
# %% [markdown]
|
| 149 |
+
# ---
|
| 150 |
+
# # Phase 1 — Supervised Fine-Tuning (SFT)
|
| 151 |
+
|
| 152 |
+
# %% [markdown]
|
| 153 |
+
# ### Observation Parsing Helpers
|
| 154 |
+
|
| 155 |
+
# %%
|
| 156 |
+
def parse_free_vehicles(obs_text: str) -> dict:
|
| 157 |
+
"""Return {unit_id: vehicle_type} for FREE vehicles."""
|
| 158 |
+
vehicles = {}
|
| 159 |
+
in_section = False
|
| 160 |
+
for line in obs_text.split("\n"):
|
| 161 |
+
if "=== UNIT STATUS ===" in line:
|
| 162 |
+
in_section = True; continue
|
| 163 |
+
if in_section and line.startswith("==="):
|
| 164 |
+
break
|
| 165 |
+
if in_section and "|" in line and "FREE" in line:
|
| 166 |
+
parts = [p.strip() for p in line.split("|")]
|
| 167 |
+
if len(parts) >= 2:
|
| 168 |
+
vehicles[parts[0]] = parts[1]
|
| 169 |
+
return vehicles
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def parse_all_vehicles(obs_text: str) -> list:
|
| 173 |
+
"""Return list of {id, type, status} for ALL vehicles."""
|
| 174 |
+
vehicles = []
|
| 175 |
+
in_section = False
|
| 176 |
+
for line in obs_text.split("\n"):
|
| 177 |
+
if "=== UNIT STATUS ===" in line:
|
| 178 |
+
in_section = True; continue
|
| 179 |
+
if in_section and line.startswith("==="):
|
| 180 |
+
break
|
| 181 |
+
if in_section and "|" in line:
|
| 182 |
+
parts = [p.strip() for p in line.split("|")]
|
| 183 |
+
if len(parts) >= 4:
|
| 184 |
+
status = parts[3].split()[0] if parts[3] else "UNKNOWN"
|
| 185 |
+
vehicles.append({"id": parts[0], "type": parts[1], "status": status})
|
| 186 |
+
return vehicles
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def parse_active_events(obs_text: str) -> dict:
|
| 190 |
+
events = {}
|
| 191 |
+
in_section = False
|
| 192 |
+
for line in obs_text.split("\n"):
|
| 193 |
+
if "=== ACTIVE EVENTS ===" in line:
|
| 194 |
+
in_section = True; continue
|
| 195 |
+
if in_section and line.startswith("==="):
|
| 196 |
+
break
|
| 197 |
+
if in_section and "|" in line and "EVT-" in line:
|
| 198 |
+
parts = [p.strip() for p in line.split("|")]
|
| 199 |
+
if len(parts) >= 2:
|
| 200 |
+
events[parts[0]] = parts[1]
|
| 201 |
+
return events
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
TYPE_TO_VEHICLE = {"fire": "fire", "medical": "ambulance", "crime": "police", "accident": "ambulance"}
|
| 205 |
+
|
| 206 |
+
SEV_KW = {
|
| 207 |
+
5: ["not breathing", "active shooter", "trapped", "mass incident", "whole block", "pileup", "send everything"],
|
| 208 |
+
4: ["won't wake", "gunshots", "flipped", "blood everywhere", "kids are upstairs", "not responding"],
|
| 209 |
+
3: ["chest pain", "fight", "mugged", "knife", "crash", "bleeding", "fire at", "flames", "cyclist"],
|
| 210 |
+
2: ["fainted", "break-in", "dumpster", "fender", "small fire", "ankle", "shoplifter"],
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def heuristic_severity(text):
|
| 215 |
+
t = text.lower()
|
| 216 |
+
for sev in [5, 4, 3, 2]:
|
| 217 |
+
if any(kw in t for kw in SEV_KW[sev]):
|
| 218 |
+
return sev
|
| 219 |
+
return 1
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def heuristic_vehicle_type(text):
|
| 223 |
+
t = text.lower()
|
| 224 |
+
if any(w in t for w in ["fire", "flames", "smoke", "burning", "gas leak"]):
|
| 225 |
+
return "fire"
|
| 226 |
+
if any(w in t for w in ["shooter", "gunshot", "mugged", "knife", "break-in", "fight", "shoplifter"]):
|
| 227 |
+
return "police"
|
| 228 |
+
return "ambulance"
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def pick_free(free_vehicles, vtype):
|
| 232 |
+
for vid, vt in free_vehicles.items():
|
| 233 |
+
if vt == vtype:
|
| 234 |
+
return vid
|
| 235 |
+
return None
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def pick_busy(all_vehicles, vtype):
|
| 239 |
+
for v in all_vehicles:
|
| 240 |
+
if v["type"] == vtype and v["status"] != "FREE":
|
| 241 |
+
return v["id"]
|
| 242 |
+
return None
|
| 243 |
+
|
| 244 |
+
# %% [markdown]
|
| 245 |
+
# ### Generate SFT Dataset
|
| 246 |
+
|
| 247 |
+
# %%
|
| 248 |
+
def build_ideal_action(gt, obs_text):
|
| 249 |
+
"""Build ideal JSON action dict from ground truth + observation."""
|
| 250 |
+
sev = gt.get("severity", 1)
|
| 251 |
+
vtype = gt.get("required_vehicle_type", "ambulance")
|
| 252 |
+
is_dup = gt.get("is_duplicate", False)
|
| 253 |
+
|
| 254 |
+
if is_dup:
|
| 255 |
+
active = parse_active_events(obs_text)
|
| 256 |
+
etype = gt.get("emergency_type", "")
|
| 257 |
+
dup_eid = None
|
| 258 |
+
for eid, et in active.items():
|
| 259 |
+
if et.strip() == etype:
|
| 260 |
+
dup_eid = eid; break
|
| 261 |
+
if dup_eid is None and active:
|
| 262 |
+
dup_eid = list(active.keys())[0]
|
| 263 |
+
return {"action_type": "duplicate", "severity_pred": sev, "is_duplicate": True,
|
| 264 |
+
"duplicate_of_event_id": dup_eid, "vehicle_type": None, "vehicle_id": None, "reroute": None}
|
| 265 |
+
|
| 266 |
+
free = parse_free_vehicles(obs_text)
|
| 267 |
+
vid = pick_free(free, vtype)
|
| 268 |
+
if vid:
|
| 269 |
+
return {"action_type": "dispatch", "severity_pred": sev, "is_duplicate": False,
|
| 270 |
+
"duplicate_of_event_id": None, "vehicle_type": vtype, "vehicle_id": vid, "reroute": None}
|
| 271 |
+
|
| 272 |
+
busy_vid = pick_busy(parse_all_vehicles(obs_text), vtype)
|
| 273 |
+
if busy_vid:
|
| 274 |
+
return {"action_type": "hold", "severity_pred": sev, "is_duplicate": False,
|
| 275 |
+
"duplicate_of_event_id": None, "vehicle_type": vtype, "vehicle_id": busy_vid, "reroute": None}
|
| 276 |
+
|
| 277 |
+
return {"action_type": "dispatch", "severity_pred": sev, "is_duplicate": False,
|
| 278 |
+
"duplicate_of_event_id": None, "vehicle_type": vtype, "vehicle_id": f"{vtype}_0", "reroute": None}
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def generate_sft_data(env, num_episodes=60):
|
| 282 |
+
examples = []
|
| 283 |
+
for ep in range(num_episodes):
|
| 284 |
+
task_id = (ep % 3) + 1
|
| 285 |
+
result = env.reset(task_id=task_id)
|
| 286 |
+
prev_obs = result.observation.prompt
|
| 287 |
+
|
| 288 |
+
while not result.done:
|
| 289 |
+
free = parse_free_vehicles(prev_obs)
|
| 290 |
+
vtype = heuristic_vehicle_type(prev_obs)
|
| 291 |
+
vid = pick_free(free, vtype)
|
| 292 |
+
action = SmartEmergencyAction(
|
| 293 |
+
action_type="dispatch",
|
| 294 |
+
severity_pred=heuristic_severity(prev_obs),
|
| 295 |
+
is_duplicate=False,
|
| 296 |
+
vehicle_type=vtype,
|
| 297 |
+
vehicle_id=vid,
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
result = env.step(action)
|
| 301 |
+
# ground_truth is now a first-class field on the observation;
|
| 302 |
+
# fall back to metadata for backward compatibility with older servers.
|
| 303 |
+
gt = result.observation.ground_truth or result.observation.metadata.get("ground_truth")
|
| 304 |
+
if gt:
|
| 305 |
+
ideal = build_ideal_action(gt, prev_obs)
|
| 306 |
+
examples.append({
|
| 307 |
+
"messages": [
|
| 308 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 309 |
+
{"role": "user", "content": prev_obs},
|
| 310 |
+
{"role": "assistant", "content": json.dumps(ideal)},
|
| 311 |
+
]
|
| 312 |
+
})
|
| 313 |
+
prev_obs = result.observation.prompt
|
| 314 |
+
|
| 315 |
+
if (ep + 1) % 10 == 0:
|
| 316 |
+
print(f" Episodes: {ep+1}/{num_episodes} | examples: {len(examples)}")
|
| 317 |
+
return examples
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
print("📝 Generating SFT data …")
|
| 321 |
+
sft_examples = generate_sft_data(env, num_episodes=60)
|
| 322 |
+
print(f"✅ Collected {len(sft_examples)} SFT examples")
|
| 323 |
+
|
| 324 |
+
# %%
|
| 325 |
+
from datasets import Dataset
|
| 326 |
+
sft_dataset = Dataset.from_list(sft_examples)
|
| 327 |
+
print(sft_dataset)
|
| 328 |
+
|
| 329 |
+
# %% [markdown]
|
| 330 |
+
# ### SFT Training with Unsloth
|
| 331 |
+
|
| 332 |
+
# %%
|
| 333 |
+
from unsloth import FastLanguageModel
|
| 334 |
+
from trl import SFTTrainer, SFTConfig
|
| 335 |
+
|
| 336 |
+
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 337 |
+
model_name=MODEL_NAME,
|
| 338 |
+
max_seq_length=MAX_SEQ_LENGTH,
|
| 339 |
+
load_in_4bit=True,
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
model = FastLanguageModel.get_peft_model(
|
| 343 |
+
model,
|
| 344 |
+
r=16,
|
| 345 |
+
lora_alpha=32,
|
| 346 |
+
lora_dropout=0.05,
|
| 347 |
+
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
|
| 348 |
+
"gate_proj", "up_proj", "down_proj"],
|
| 349 |
+
use_gradient_checkpointing="unsloth",
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
sft_config = SFTConfig(
|
| 353 |
+
output_dir=SFT_OUTPUT_DIR,
|
| 354 |
+
num_train_epochs=3,
|
| 355 |
+
per_device_train_batch_size=2,
|
| 356 |
+
gradient_accumulation_steps=8,
|
| 357 |
+
learning_rate=2e-4,
|
| 358 |
+
lr_scheduler_type="cosine",
|
| 359 |
+
warmup_ratio=0.1,
|
| 360 |
+
logging_steps=5,
|
| 361 |
+
save_steps=50,
|
| 362 |
+
max_seq_length=MAX_SEQ_LENGTH,
|
| 363 |
+
bf16=True,
|
| 364 |
+
report_to="none",
|
| 365 |
+
)
|
| 366 |
+
|
| 367 |
+
sft_trainer = SFTTrainer(
|
| 368 |
+
model=model,
|
| 369 |
+
processing_class=tokenizer,
|
| 370 |
+
train_dataset=sft_dataset,
|
| 371 |
+
args=sft_config,
|
| 372 |
+
)
|
| 373 |
+
|
| 374 |
+
# %%
|
| 375 |
+
print("🏋️ Starting SFT training …")
|
| 376 |
+
sft_trainer.train()
|
| 377 |
+
print("✅ SFT complete")
|
| 378 |
+
|
| 379 |
+
# %%
|
| 380 |
+
sft_trainer.save_model(SFT_OUTPUT_DIR)
|
| 381 |
+
tokenizer.save_pretrained(SFT_OUTPUT_DIR)
|
| 382 |
+
print(f"✅ SFT model saved to {SFT_OUTPUT_DIR}/")
|
| 383 |
+
|
| 384 |
+
# Free memory
|
| 385 |
+
import torch, gc
|
| 386 |
+
del model, sft_trainer
|
| 387 |
+
gc.collect()
|
| 388 |
+
torch.cuda.empty_cache()
|
| 389 |
+
|
| 390 |
+
# %% [markdown]
|
| 391 |
+
# ---
|
| 392 |
+
# # Phase 2 — GRPO with Unsloth
|
| 393 |
+
|
| 394 |
+
# %% [markdown]
|
| 395 |
+
# ### Action Parsing
|
| 396 |
+
|
| 397 |
+
# %%
|
| 398 |
+
def parse_llm_action(text):
|
| 399 |
+
"""Extract action dict from LLM output."""
|
| 400 |
+
m = re.search(r"```json\s*(.*?)```", text, re.DOTALL)
|
| 401 |
+
if m:
|
| 402 |
+
text = m.group(1)
|
| 403 |
+
else:
|
| 404 |
+
m = re.search(r"\{.*\}", text, re.DOTALL)
|
| 405 |
+
if m:
|
| 406 |
+
text = m.group(0)
|
| 407 |
+
try:
|
| 408 |
+
d = json.loads(text)
|
| 409 |
+
# Validate required fields
|
| 410 |
+
assert d.get("action_type") in ("dispatch", "duplicate", "hold")
|
| 411 |
+
assert 1 <= int(d.get("severity_pred", 0)) <= 5
|
| 412 |
+
return d
|
| 413 |
+
except Exception:
|
| 414 |
+
return None
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
def fallback_action(obs_text):
|
| 418 |
+
free = parse_free_vehicles(obs_text)
|
| 419 |
+
vtype = heuristic_vehicle_type(obs_text)
|
| 420 |
+
vid = pick_free(free, vtype)
|
| 421 |
+
if vid:
|
| 422 |
+
return {"action_type": "dispatch", "severity_pred": heuristic_severity(obs_text),
|
| 423 |
+
"is_duplicate": False, "vehicle_type": vtype, "vehicle_id": vid}
|
| 424 |
+
busy_vid = pick_busy(parse_all_vehicles(obs_text), vtype)
|
| 425 |
+
return {"action_type": "hold" if busy_vid else "dispatch",
|
| 426 |
+
"severity_pred": heuristic_severity(obs_text), "is_duplicate": False,
|
| 427 |
+
"vehicle_type": vtype, "vehicle_id": busy_vid or f"{vtype}_0"}
|
| 428 |
+
|
| 429 |
+
# %% [markdown]
|
| 430 |
+
# ### Rollout Functions
|
| 431 |
+
|
| 432 |
+
# %%
|
| 433 |
+
from unsloth import FastLanguageModel, PatchFastRL
|
| 434 |
+
from trl import GRPOConfig, GRPOTrainer
|
| 435 |
+
|
| 436 |
+
# Patch TRL for Unsloth compatibility
|
| 437 |
+
PatchFastRL("GRPO", FastLanguageModel)
|
| 438 |
+
|
| 439 |
+
# Load the SFT model for GRPO with fast inference (vLLM)
|
| 440 |
+
grpo_model, grpo_tokenizer = FastLanguageModel.from_pretrained(
|
| 441 |
+
model_name=SFT_OUTPUT_DIR,
|
| 442 |
+
max_seq_length=MAX_SEQ_LENGTH,
|
| 443 |
+
load_in_4bit=True,
|
| 444 |
+
fast_inference=True, # enables vLLM for GRPO generation
|
| 445 |
+
)
|
| 446 |
+
|
| 447 |
+
grpo_model = FastLanguageModel.get_peft_model(
|
| 448 |
+
grpo_model,
|
| 449 |
+
r=16,
|
| 450 |
+
lora_alpha=32,
|
| 451 |
+
lora_dropout=0.05,
|
| 452 |
+
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
|
| 453 |
+
"gate_proj", "up_proj", "down_proj"],
|
| 454 |
+
use_gradient_checkpointing="unsloth",
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
# %%
|
| 458 |
+
def make_user_prompt(obs_text):
|
| 459 |
+
return f"You are the dispatcher. Read the situation and respond with a single JSON action.\n\n{obs_text}\n\nRespond ONLY with a JSON object."
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
def action_dict_to_obj(d):
|
| 463 |
+
"""Convert a plain dict action to SmartEmergencyAction."""
|
| 464 |
+
from smart_emergency import RerouteAction
|
| 465 |
+
reroute = None
|
| 466 |
+
if d.get("reroute") and isinstance(d["reroute"], dict):
|
| 467 |
+
rd = d["reroute"]
|
| 468 |
+
reroute = RerouteAction(
|
| 469 |
+
vehicle_to_reroute=rd["vehicle_to_reroute"],
|
| 470 |
+
from_event_id=rd["from_event_id"],
|
| 471 |
+
replacement_vehicle_id=rd.get("replacement_vehicle_id"),
|
| 472 |
+
)
|
| 473 |
+
return SmartEmergencyAction(
|
| 474 |
+
action_type=d.get("action_type", "dispatch"),
|
| 475 |
+
severity_pred=int(d.get("severity_pred", 1)),
|
| 476 |
+
is_duplicate=bool(d.get("is_duplicate", False)),
|
| 477 |
+
duplicate_of_event_id=d.get("duplicate_of_event_id"),
|
| 478 |
+
vehicle_type=d.get("vehicle_type"),
|
| 479 |
+
vehicle_id=d.get("vehicle_id"),
|
| 480 |
+
reroute=reroute,
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
def rollout_once(trainer, env, tokenizer, system_prompt, max_turns=15):
|
| 485 |
+
"""Run one full episode."""
|
| 486 |
+
from trl.experimental.openenv import generate_rollout_completions
|
| 487 |
+
|
| 488 |
+
result = env.reset()
|
| 489 |
+
prompt_ids, completion_ids, logprobs = [], [], []
|
| 490 |
+
rewards = {k: [] for k in ["severity", "duplicate", "vehicle_type", "vehicle_choice", "reroute", "format"]}
|
| 491 |
+
|
| 492 |
+
for _ in range(max_turns):
|
| 493 |
+
if result.done:
|
| 494 |
+
break
|
| 495 |
+
|
| 496 |
+
obs_text = result.observation.prompt
|
| 497 |
+
messages = [
|
| 498 |
+
{"role": "system", "content": system_prompt},
|
| 499 |
+
{"role": "user", "content": make_user_prompt(obs_text)},
|
| 500 |
+
]
|
| 501 |
+
prompt_text = tokenizer.apply_chat_template(
|
| 502 |
+
messages, add_generation_prompt=True, tokenize=False, enable_thinking=False,
|
| 503 |
+
)
|
| 504 |
+
|
| 505 |
+
out = generate_rollout_completions(trainer, [prompt_text])[0]
|
| 506 |
+
prompt_ids.extend(out["prompt_ids"])
|
| 507 |
+
completion_ids.extend(out["completion_ids"])
|
| 508 |
+
logprobs.extend(out["logprobs"])
|
| 509 |
+
|
| 510 |
+
comp_text = out.get("text") or tokenizer.decode(out["completion_ids"], skip_special_tokens=True)
|
| 511 |
+
action_d = parse_llm_action(comp_text)
|
| 512 |
+
parse_ok = action_d is not None
|
| 513 |
+
if action_d is None:
|
| 514 |
+
action_d = fallback_action(obs_text)
|
| 515 |
+
action = action_dict_to_obj(action_d)
|
| 516 |
+
|
| 517 |
+
result = env.step(action)
|
| 518 |
+
bd = result.observation.reward_breakdown
|
| 519 |
+
|
| 520 |
+
rewards["severity"].append(bd.get("severity", 0.0))
|
| 521 |
+
rewards["duplicate"].append(bd.get("duplicate", 0.0))
|
| 522 |
+
rewards["vehicle_type"].append(bd.get("vehicle_type", 0.0))
|
| 523 |
+
rewards["vehicle_choice"].append(bd.get("vehicle_choice", 0.0))
|
| 524 |
+
rewards["reroute"].append(bd.get("reroute", 0.0))
|
| 525 |
+
rewards["format"].append(1.0 if parse_ok else -2.0)
|
| 526 |
+
|
| 527 |
+
return {
|
| 528 |
+
"prompt_ids": prompt_ids,
|
| 529 |
+
"completion_ids": completion_ids,
|
| 530 |
+
"logprobs": logprobs,
|
| 531 |
+
**{f"{k}_reward": v[-1] if v else 0.0 for k, v in rewards.items()},
|
| 532 |
+
}
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
def rollout_func(prompts, trainer=None):
|
| 536 |
+
"""GRPO rollout — called by GRPOTrainer each step."""
|
| 537 |
+
results = {k: [] for k in ["prompt_ids", "completion_ids", "logprobs",
|
| 538 |
+
"severity_reward", "duplicate_reward", "vehicle_type_reward",
|
| 539 |
+
"vehicle_choice_reward", "reroute_reward", "format_reward"]}
|
| 540 |
+
|
| 541 |
+
for _ in prompts:
|
| 542 |
+
ep = rollout_once(trainer, env, grpo_tokenizer, SYSTEM_PROMPT)
|
| 543 |
+
for k in results:
|
| 544 |
+
results[k].append(ep[k])
|
| 545 |
+
return results
|
| 546 |
+
|
| 547 |
+
# %% [markdown]
|
| 548 |
+
# ### Reward Wrappers & Config
|
| 549 |
+
|
| 550 |
+
# %%
|
| 551 |
+
def _make_reward_fn(key):
|
| 552 |
+
def fn(completions, **kwargs):
|
| 553 |
+
r = kwargs.get(key)
|
| 554 |
+
return [float(x) for x in r] if r else [0.0] * len(completions)
|
| 555 |
+
fn.__name__ = f"reward_{key.replace('_reward', '')}"
|
| 556 |
+
return fn
|
| 557 |
+
|
| 558 |
+
reward_fns = [_make_reward_fn(k) for k in
|
| 559 |
+
["severity_reward", "duplicate_reward", "vehicle_type_reward",
|
| 560 |
+
"vehicle_choice_reward", "reroute_reward", "format_reward"]]
|
| 561 |
+
|
| 562 |
+
# %%
|
| 563 |
+
grpo_dataset = Dataset.from_dict({
|
| 564 |
+
"prompt": ["Dispatch emergency services for incoming 911 calls."] * 500
|
| 565 |
+
})
|
| 566 |
+
|
| 567 |
+
grpo_config = GRPOConfig(
|
| 568 |
+
num_train_epochs=1,
|
| 569 |
+
learning_rate=5e-6,
|
| 570 |
+
gradient_accumulation_steps=32,
|
| 571 |
+
per_device_train_batch_size=1,
|
| 572 |
+
warmup_steps=10,
|
| 573 |
+
num_generations=4,
|
| 574 |
+
max_completion_length=128,
|
| 575 |
+
max_prompt_length=MAX_SEQ_LENGTH,
|
| 576 |
+
use_vllm=True,
|
| 577 |
+
output_dir=GRPO_OUTPUT_DIR,
|
| 578 |
+
logging_steps=1,
|
| 579 |
+
save_steps=10,
|
| 580 |
+
push_to_hub=True,
|
| 581 |
+
)
|
| 582 |
+
|
| 583 |
+
# %% [markdown]
|
| 584 |
+
# ### Train GRPO
|
| 585 |
+
|
| 586 |
+
# %%
|
| 587 |
+
grpo_trainer = GRPOTrainer(
|
| 588 |
+
model=grpo_model,
|
| 589 |
+
processing_class=grpo_tokenizer,
|
| 590 |
+
reward_funcs=reward_fns,
|
| 591 |
+
train_dataset=grpo_dataset,
|
| 592 |
+
args=grpo_config,
|
| 593 |
+
rollout_func=rollout_func,
|
| 594 |
+
)
|
| 595 |
+
|
| 596 |
+
import torch
|
| 597 |
+
gpu = torch.cuda.get_device_properties(0)
|
| 598 |
+
print(f"GPU: {gpu.name} | {round(gpu.total_memory/1024**3, 1)} GB")
|
| 599 |
+
print(f"Reserved: {round(torch.cuda.max_memory_reserved()/1024**3, 2)} GB")
|
| 600 |
+
|
| 601 |
+
# %%
|
| 602 |
+
print("🏋️ Starting GRPO training …")
|
| 603 |
+
stats = grpo_trainer.train()
|
| 604 |
+
print("✅ GRPO complete")
|
| 605 |
+
|
| 606 |
+
# %%
|
| 607 |
+
peak = round(torch.cuda.max_memory_reserved() / 1024**3, 2)
|
| 608 |
+
print(f"Peak memory: {peak} GB | Time: {round(stats.metrics['train_runtime']/60, 1)} min")
|
| 609 |
+
|
| 610 |
+
grpo_trainer.save_model(GRPO_OUTPUT_DIR)
|
| 611 |
+
grpo_trainer.push_to_hub()
|
| 612 |
+
print(f"✅ Model saved & pushed to Hub")
|
| 613 |
+
|
| 614 |
+
# %% [markdown]
|
| 615 |
+
# ---
|
| 616 |
+
# # Phase 3 — Inference & Evaluation
|
| 617 |
+
|
| 618 |
+
# %%
|
| 619 |
+
from unsloth import FastLanguageModel as FLM
|
| 620 |
+
|
| 621 |
+
inf_model, inf_tokenizer = FLM.from_pretrained(
|
| 622 |
+
model_name=GRPO_OUTPUT_DIR, max_seq_length=MAX_SEQ_LENGTH, load_in_4bit=True,
|
| 623 |
+
)
|
| 624 |
+
FLM.for_inference(inf_model)
|
| 625 |
+
|
| 626 |
+
def run_episode(env, model, tokenizer, task_id=1):
|
| 627 |
+
result = env.reset(task_id=task_id)
|
| 628 |
+
total_reward = 0.0
|
| 629 |
+
|
| 630 |
+
for step in range(20):
|
| 631 |
+
if result.done:
|
| 632 |
+
break
|
| 633 |
+
obs_text = result.observation.prompt
|
| 634 |
+
messages = [
|
| 635 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 636 |
+
{"role": "user", "content": make_user_prompt(obs_text)},
|
| 637 |
+
]
|
| 638 |
+
prompt_text = tokenizer.apply_chat_template(
|
| 639 |
+
messages, add_generation_prompt=True, tokenize=False, enable_thinking=False,
|
| 640 |
+
)
|
| 641 |
+
inputs = tokenizer([prompt_text], return_tensors="pt").to(model.device)
|
| 642 |
+
gen = model.generate(**inputs, max_new_tokens=256, temperature=0.1)
|
| 643 |
+
output = tokenizer.decode(gen[0][len(inputs.input_ids[0]):], skip_special_tokens=True)
|
| 644 |
+
|
| 645 |
+
action_d = parse_llm_action(output) or fallback_action(obs_text)
|
| 646 |
+
action = action_dict_to_obj(action_d)
|
| 647 |
+
tag = "✅" if parse_llm_action(output) else "⚠️"
|
| 648 |
+
print(f" Step {step}: {tag} {action_d.get('action_type')} sev={action_d.get('severity_pred')}")
|
| 649 |
+
|
| 650 |
+
result = env.step(action)
|
| 651 |
+
total_reward += result.observation.reward_breakdown.get("total", 0.0)
|
| 652 |
+
|
| 653 |
+
print(f"\n Done — reward: {total_reward:.2f} over {step+1} steps")
|
| 654 |
+
return total_reward
|
| 655 |
+
|
| 656 |
+
# %%
|
| 657 |
+
print("=" * 50)
|
| 658 |
+
print("Evaluation — Task 1 (Easy)")
|
| 659 |
+
print("=" * 50)
|
| 660 |
+
run_episode(env, inf_model, inf_tokenizer, task_id=1)
|
| 661 |
+
env.close()
|
train_sft_grpo_graph.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|