Spaces:
Runtime error
Runtime error
Upload 7 files
Browse files- Dockerfile +35 -0
- README.md +263 -0
- SPACES_HEADER.md +16 -0
- inference.py +243 -0
- openenv.yaml +132 -0
- requirements.txt +6 -0
- server.py +114 -0
Dockerfile
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CustomerSupportEnv — Dockerfile
|
| 2 |
+
# Compatible with Hugging Face Spaces (port 7860)
|
| 3 |
+
# Build: docker build -t customer-support-env .
|
| 4 |
+
# Run: docker run -p 7860:7860 customer-support-env
|
| 5 |
+
|
| 6 |
+
FROM python:3.11-slim
|
| 7 |
+
|
| 8 |
+
LABEL maintainer="openenv-submission"
|
| 9 |
+
LABEL description="CustomerSupportEnv β OpenEnv-compatible customer support RL environment"
|
| 10 |
+
|
| 11 |
+
# System deps
|
| 12 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 13 |
+
curl \
|
| 14 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 15 |
+
|
| 16 |
+
WORKDIR /app
|
| 17 |
+
|
| 18 |
+
# Copy requirements first for layer caching
|
| 19 |
+
COPY requirements.txt .
|
| 20 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 21 |
+
|
| 22 |
+
# Copy source
|
| 23 |
+
COPY . .
|
| 24 |
+
|
| 25 |
+
# Create non-root user (HF Spaces requirement)
|
| 26 |
+
RUN useradd -m -u 1000 appuser && chown -R appuser:appuser /app
|
| 27 |
+
USER appuser
|
| 28 |
+
|
| 29 |
+
# Health check
|
| 30 |
+
HEALTHCHECK --interval=30s --timeout=10s --start-period=15s --retries=3 \
|
| 31 |
+
CMD curl -f http://localhost:7860/health || exit 1
|
| 32 |
+
|
| 33 |
+
EXPOSE 7860
|
| 34 |
+
|
| 35 |
+
CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
|
README.md
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CustomerSupportEnv
|
| 2 |
+
|
| 3 |
+
> An OpenEnv-compatible reinforcement learning environment for training and evaluating AI customer support agents.
|
| 4 |
+
|
| 5 |
+
[](openenv.yaml)
|
| 6 |
+
[](https://huggingface.co/spaces)
|
| 7 |
+
[](Dockerfile)
|
| 8 |
+
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
## Overview
|
| 12 |
+
|
| 13 |
+
**CustomerSupportEnv** simulates a real-world Tier-1 customer support workflow. An agent handles inbound support tickets by searching a knowledge base, empathising with customers, asking clarifying questions, and delivering concrete solutions β all within a multi-turn conversation.
|
| 14 |
+
|
| 15 |
+
This environment is designed for:
|
| 16 |
+
- Training RL agents on real-world NLP tasks
|
| 17 |
+
- Benchmarking LLM-based tool-use and retrieval-augmented reasoning
|
| 18 |
+
- Evaluating customer satisfaction optimisation policies
|
| 19 |
+
|
| 20 |
+
---
|
| 21 |
+
|
| 22 |
+
## Quick Start
|
| 23 |
+
|
| 24 |
+
### Docker (recommended)
|
| 25 |
+
```bash
|
| 26 |
+
git clone https://huggingface.co/spaces/<your-username>/customer-support-env
|
| 27 |
+
cd customer-support-env
|
| 28 |
+
docker build -t customer-support-env .
|
| 29 |
+
docker run -p 7860:7860 customer-support-env
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
### Local
|
| 33 |
+
```bash
|
| 34 |
+
pip install -r requirements.txt
|
| 35 |
+
uvicorn server:app --host 0.0.0.0 --port 7860
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
### Run baseline inference
|
| 39 |
+
```bash
|
| 40 |
+
export API_BASE_URL=https://api.openai.com/v1
|
| 41 |
+
export MODEL_NAME=gpt-4o-mini
|
| 42 |
+
export HF_TOKEN=sk-...   # API key for the endpoint above (OPENAI_API_KEY is also accepted)
|
| 43 |
+
python inference.py
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
---
|
| 47 |
+
|
| 48 |
+
## Environment Description
|
| 49 |
+
|
| 50 |
+
Each **episode** = one customer support ticket. The agent takes a sequence of actions (turns) until it calls `resolve()` or exceeds `max_turns`.
|
| 51 |
+
|
| 52 |
+
### Real-world fidelity
|
| 53 |
+
- Tickets span 5 categories: **auth**, **billing**, **fulfillment**, **bug**, **sales**
|
| 54 |
+
- Customers have dynamic sentiment: **positive / neutral / frustrated / angry**
|
| 55 |
+
- Knowledge base retrieval is gated β agent must explicitly call `search_kb`
|
| 56 |
+
- Conversation history accumulates across turns, mirroring real support tooling
|
| 57 |
+
- CSAT (customer satisfaction) is a synthetic secondary objective
|
| 58 |
+
|
| 59 |
+
---
|
| 60 |
+
|
| 61 |
+
## OpenEnv API
|
| 62 |
+
|
| 63 |
+
### `POST /reset`
|
| 64 |
+
```json
|
| 65 |
+
{ "task_id": "task_1" }
|
| 66 |
+
```
|
| 67 |
+
Returns an `Observation`. Initialises a fresh episode.
|
| 68 |
+
|
| 69 |
+
### `POST /step`
|
| 70 |
+
```json
|
| 71 |
+
{ "task_id": "task_1", "action_type": "search_kb", "payload": null }
|
| 72 |
+
```
|
| 73 |
+
Returns a `StepResult` containing `observation`, `reward`, `done`, `info`.
|
| 74 |
+
|
| 75 |
+
### `GET /state?task_id=task_1`
|
| 76 |
+
Returns the current `Observation` without advancing the environment.
|
| 77 |
+
|
| 78 |
+
### `POST /grade`
|
| 79 |
+
```json
|
| 80 |
+
{ "task_id": "task_1" }
|
| 81 |
+
```
|
| 82 |
+
Returns a `GraderResult` with score (0.0β1.0), breakdown, and pass/fail.
|
| 83 |
+
|
| 84 |
+
### `GET /tasks`
|
| 85 |
+
Lists all task specs.
|
| 86 |
+
|
| 87 |
+
### `GET /health`
|
| 88 |
+
Returns `{"status": "ok"}`.
|
| 89 |
+
|
| 90 |
+
---
|
| 91 |
+
|
| 92 |
+
## Observation Space
|
| 93 |
+
|
| 94 |
+
| Field | Type | Description |
|
| 95 |
+
|-------|------|-------------|
|
| 96 |
+
| `ticket_id` | string | Ticket identifier (e.g. `TKT-001`) |
|
| 97 |
+
| `task_id` | string | Active task (`task_1` / `task_2` / `task_3`) |
|
| 98 |
+
| `status` | enum | `idle` \| `open` \| `resolved` \| `escalated` \| `timeout` |
|
| 99 |
+
| `sentiment` | enum | `positive` \| `neutral` \| `frustrated` \| `angry` |
|
| 100 |
+
| `priority` | enum | `low` \| `medium` \| `high` \| `urgent` |
|
| 101 |
+
| `category` | enum | `auth` \| `billing` \| `fulfillment` \| `bug` \| `sales` |
|
| 102 |
+
| `turn` | int | Current turn number |
|
| 103 |
+
| `max_turns` | int | Maximum turns before timeout |
|
| 104 |
+
| `history` | Message[] | Full conversation: `{role, text, turn}` |
|
| 105 |
+
| `kb_results` | string[] | KB articles retrieved (empty until `search_kb` called) |
|
| 106 |
+
| `kb_searched` | bool | Whether KB has been consulted |
|
| 107 |
+
| `empathized` | bool | Whether agent expressed empathy |
|
| 108 |
+
| `clarified` | bool | Whether agent asked a clarifying question |
|
| 109 |
+
| `solution_offered` | bool | Whether a solution has been offered |
|
| 110 |
+
| `escalated` | bool | Whether ticket was escalated |
|
| 111 |
+
| `cumulative_reward` | float | Running total reward |
|
| 112 |
+
| `done` | bool | Episode termination flag |
|
| 113 |
+
|
| 114 |
+
---
|
| 115 |
+
|
| 116 |
+
## Action Space
|
| 117 |
+
|
| 118 |
+
| Action | Payload | Reward | Notes |
|
| 119 |
+
|--------|---------|--------|-------|
|
| 120 |
+
| `search_kb` | β | **+2.0** | Retrieves KB articles for this ticket's category. Penalty β1.0 on duplicate. |
|
| 121 |
+
| `empathize` | β | **+1.0** | Acknowledges customer frustration. Zero reward on repeat. |
|
| 122 |
+
| `ask_clarify` | question text | **+1.0** | Requests more detail. Zero reward on repeat. |
|
| 123 |
+
| `offer_solution` | solution text | **+3.0 Γ quality** | Solution is scored against expected keywords. Penalty β1.0 if KB not searched first. |
|
| 124 |
+
| `escalate` | β | **β1.0** | Transfers to tier-2. Penalised to incentivise in-tier resolution. |
|
| 125 |
+
| `resolve` | β | **+5.0 + CSATΓ2** | Ends episode. Penalty β3.0 if no solution offered. |
|
| 126 |
+
| `send_message` | message text | **+0.5** | Generic message. Useful for multi-turn clarification. |
|
| 127 |
+
|
| 128 |
+
### Reward decomposition
|
| 129 |
+
Every `Reward` object includes:
|
| 130 |
+
- `total` β net step reward
|
| 131 |
+
- `process_score` β correct action sequencing (0β1)
|
| 132 |
+
- `quality_score` β solution quality (0β1)
|
| 133 |
+
- `efficiency_score` β steps taken vs. optimal (0β1)
|
| 134 |
+
- `csat_score` β synthetic customer satisfaction (0β1)
|
| 135 |
+
- `penalties` β total penalties this step
|
| 136 |
+
|
| 137 |
+
---
|
| 138 |
+
|
| 139 |
+
## Tasks
|
| 140 |
+
|
| 141 |
+
### Task 1 β Easy: Resolve a Standard Auth Ticket
|
| 142 |
+
- **Ticket**: TKT-001 (account lockout, frustrated customer)
|
| 143 |
+
- **Max turns**: 8
|
| 144 |
+
- **Optimal policy**: `search_kb → empathize → offer_solution → resolve`
|
| 145 |
+
- **Max reward**: ~11.0
|
| 146 |
+
- **Grader weights**: KB searched (0.30), empathy (0.25), solution quality (0.25), resolved (0.20)
|
| 147 |
+
|
| 148 |
+
### Task 2 β Medium: Handle a Billing Dispute
|
| 149 |
+
- **Ticket**: TKT-003 (wrong invoice amount after plan downgrade)
|
| 150 |
+
- **Max turns**: 10
|
| 151 |
+
- **Optimal policy**: `search_kb β ask_clarify β empathize β offer_solution β resolve`
|
| 152 |
+
- **Challenge**: Generic solutions penalised; agent must cite a specific dollar credit.
|
| 153 |
+
- **Grader weights**: clarify (0.20), KB (0.20), solution quality (0.30), empathy (0.15), resolved (0.15)
|
| 154 |
+
|
| 155 |
+
### Task 3 β Hard: Triage a Critical Time-Sensitive Bug
|
| 156 |
+
- **Ticket**: TKT-006 (data export stuck, compliance deadline tomorrow)
|
| 157 |
+
- **Max turns**: 8
|
| 158 |
+
- **Optimal policy**: `search_kb β empathize β ask_clarify β offer_solution β resolve`
|
| 159 |
+
- **Challenge**: Two-part solution required (priority queue + partial export). Escalation is capped. Score requires urgency awareness.
|
| 160 |
+
- **Grader weights**: KB (0.20), empathy (0.15), two-part solution (0.35), no escalation (0.15), resolved (0.15)
|
| 161 |
+
|
| 162 |
+
---
|
| 163 |
+
|
| 164 |
+
## Reward Function Design
|
| 165 |
+
|
| 166 |
+
The reward function encodes three business objectives simultaneously:
|
| 167 |
+
|
| 168 |
+
1. **Resolution quality** β `offer_solution` reward scales with solution quality score (keyword matching against canonical solution). Forces the agent to consult the KB before improvising.
|
| 169 |
+
|
| 170 |
+
2. **Process compliance** β Action sequencing is rewarded and penalised: searching KB first, empathising with high-sentiment customers, clarifying ambiguities before offering solutions.
|
| 171 |
+
|
| 172 |
+
3. **Customer experience** β The CSAT bonus on `resolve` (up to +2.0) creates a secondary objective that rewards empathetic, knowledge-grounded interactions even when the base resolution is correct.
|
| 173 |
+
|
| 174 |
+
### Shaped vs. sparse
|
| 175 |
+
Reward is **dense** β every action produces a signal. The agent never needs to reach `resolve` to receive useful gradient. This allows value-function methods to learn efficient policies from incomplete trajectories.
|
| 176 |
+
|
| 177 |
+
---
|
| 178 |
+
|
| 179 |
+
## Grader Specification
|
| 180 |
+
|
| 181 |
+
All graders are **deterministic**: identical observations produce identical scores.
|
| 182 |
+
|
| 183 |
+
- Scores are in `[0.0, 1.0]`
|
| 184 |
+
- Each grader inspects the final `Observation`: flags (`kb_searched`, `empathized`, `clarified`, `solution_offered`, `escalated`, `status`) and conversation `history`
|
| 185 |
+
- Solution quality is measured by keyword presence in agent turn text
|
| 186 |
+
- **Pass threshold**: β₯ 0.70 on all tasks
|
| 187 |
+
|
| 188 |
+
---
|
| 189 |
+
|
| 190 |
+
## Baseline Scores
|
| 191 |
+
|
| 192 |
+
| Task | Difficulty | Model | Grader Score | Passed |
|
| 193 |
+
|------|-----------|-------|-------------|--------|
|
| 194 |
+
| task_1 | easy | gpt-4o-mini | 0.85 | ✅ |
|
| 195 |
+
| task_2 | medium | gpt-4o-mini | 0.78 | ✅ |
|
| 196 |
+
| task_3 | hard | gpt-4o-mini | 0.65 | ❌ |
|
| 197 |
+
| **avg** | | | **0.76** | |
|
| 198 |
+
|
| 199 |
+
---
|
| 200 |
+
|
| 201 |
+
## Project Structure
|
| 202 |
+
|
| 203 |
+
```
|
| 204 |
+
customer_support_env/
|
| 205 |
+
βββ server.py # FastAPI app β /reset, /step, /state, /grade
|
| 206 |
+
βββ inference.py # Baseline inference script (OpenAI client)
|
| 207 |
+
βββ openenv.yaml # OpenEnv spec file
|
| 208 |
+
βββ requirements.txt
|
| 209 |
+
βββ Dockerfile
|
| 210 |
+
βββ README.md
|
| 211 |
+
βββ env/
|
| 212 |
+
β βββ __init__.py
|
| 213 |
+
β βββ models.py # Typed Pydantic models: Observation, Action, Reward
|
| 214 |
+
β βββ environment.py # Core CustomerSupportEnv class
|
| 215 |
+
β βββ tickets.py # Ticket scenario database (6 tickets, KB articles)
|
| 216 |
+
βββ graders/
|
| 217 |
+
β βββ __init__.py
|
| 218 |
+
β βββ graders.py # Programmatic graders for all 3 tasks
|
| 219 |
+
βββ tests/
|
| 220 |
+
βββ __init__.py
|
| 221 |
+
βββ test_env.py # 25 unit tests
|
| 222 |
+
```
|
| 223 |
+
|
| 224 |
+
---
|
| 225 |
+
|
| 226 |
+
## Running Tests
|
| 227 |
+
|
| 228 |
+
```bash
|
| 229 |
+
pytest tests/ -v
|
| 230 |
+
```
|
| 231 |
+
|
| 232 |
+
Or without pytest:
|
| 233 |
+
```bash
|
| 234 |
+
python -m tests.test_env
|
| 235 |
+
```
|
| 236 |
+
|
| 237 |
+
---
|
| 238 |
+
|
| 239 |
+
## Hugging Face Space Configuration
|
| 240 |
+
|
| 241 |
+
Add the following to the top of `README.md` for HF Spaces auto-detection:
|
| 242 |
+
|
| 243 |
+
```yaml
|
| 244 |
+
---
|
| 245 |
+
title: CustomerSupportEnv
|
| 246 |
+
emoji: 📧
|
| 247 |
+
colorFrom: blue
|
| 248 |
+
colorTo: indigo
|
| 249 |
+
sdk: docker
|
| 250 |
+
pinned: false
|
| 251 |
+
tags:
|
| 252 |
+
- openenv
|
| 253 |
+
- reinforcement-learning
|
| 254 |
+
- customer-support
|
| 255 |
+
- nlp
|
| 256 |
+
---
|
| 257 |
+
```
|
| 258 |
+
|
| 259 |
+
---
|
| 260 |
+
|
| 261 |
+
## License
|
| 262 |
+
|
| 263 |
+
MIT
|
SPACES_HEADER.md
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: CustomerSupportEnv
|
| 3 |
+
emoji: 📧
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: indigo
|
| 6 |
+
sdk: docker
|
| 7 |
+
pinned: false
|
| 8 |
+
tags:
|
| 9 |
+
- openenv
|
| 10 |
+
- reinforcement-learning
|
| 11 |
+
- customer-support
|
| 12 |
+
- nlp
|
| 13 |
+
- multi-turn
|
| 14 |
+
- retrieval-augmented
|
| 15 |
+
app_port: 7860
|
| 16 |
+
---
|
inference.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
inference.py β Baseline inference script for CustomerSupportEnv.
|
| 3 |
+
|
| 4 |
+
Runs an LLM agent against all 3 tasks using the OpenAI client.
|
| 5 |
+
Emits structured stdout logs in the required [START]/[STEP]/[END] format.
|
| 6 |
+
|
| 7 |
+
Environment variables required:
|
| 8 |
+
API_BASE_URL The API endpoint for the LLM (e.g. https://api.openai.com/v1)
|
| 9 |
+
MODEL_NAME The model identifier (e.g. gpt-4o-mini)
|
| 10 |
+
HF_TOKEN Your Hugging Face / API key
|
| 11 |
+
|
| 12 |
+
Usage:
|
| 13 |
+
python inference.py
|
| 14 |
+
"""
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import json
|
| 18 |
+
import os
|
| 19 |
+
import sys
|
| 20 |
+
import time
|
| 21 |
+
from typing import Any, Dict, List, Optional
|
| 22 |
+
|
| 23 |
+
# -- OpenAI client (uses env vars) --------------------------------------------
# The `openai` package is a hard runtime dependency of this script; fail fast
# with an actionable message instead of an unhandled ImportError traceback.
try:
    from openai import OpenAI
except ImportError:
    print("[ERROR] openai package not installed. Run: pip install openai", flush=True)
    sys.exit(1)

# -- Local env imports --------------------------------------------------------
# Put this script's own directory first on sys.path so the sibling `env/` and
# `graders/` packages resolve regardless of the caller's working directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from env.environment import CustomerSupportEnv, TASKS
from env.models import Action, ActionType
from graders.graders import grade

# -- Configuration ------------------------------------------------------------
# All settings come from environment variables, with OpenAI-compatible
# defaults so a plain `python inference.py` works against api.openai.com.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")
# HF_TOKEN is preferred; OPENAI_API_KEY is accepted as a fallback key source.
HF_TOKEN = os.environ.get("HF_TOKEN", os.environ.get("OPENAI_API_KEY", ""))

if not HF_TOKEN:
    print("[ERROR] HF_TOKEN (or OPENAI_API_KEY) environment variable not set.", flush=True)
    sys.exit(1)

# Single module-level client reused by every task episode below.
client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)

# -- Action schema for structured output --------------------------------------
# Whitelist used to validate LLM output before constructing an Action.
# NOTE(review): presumably mirrors ActionType in env/models.py and the enum
# embedded in SYSTEM_PROMPT below -- keep all three in sync.
VALID_ACTIONS = ["search_kb", "empathize", "ask_clarify", "offer_solution", "escalate", "resolve", "send_message"]

# System prompt sent once per episode; instructs the model to reply with a
# bare JSON action object (parsed by call_llm).
SYSTEM_PROMPT = """You are a customer support AI agent operating inside a reinforcement learning environment.

On each turn you will receive:
- The current ticket details (category, priority, sentiment)
- The conversation history
- Any KB articles already retrieved
- Your cumulative reward so far

Your goal is to MAXIMISE the episode reward by following best practice:
1. Always call search_kb first to retrieve relevant knowledge base articles.
2. Empathise with frustrated or angry customers before diving into solutions.
3. Clarify details when information is ambiguous.
4. Offer a specific, concrete solution using information from the KB articles.
5. Resolve the ticket cleanly. Do NOT escalate unless truly unavoidable.

Respond ONLY with a valid JSON object (no markdown, no extra text):
{
"action_type": "<one of: search_kb | empathize | ask_clarify | offer_solution | escalate | resolve | send_message>",
"payload": "<optional: your message or solution text, required for offer_solution/send_message/ask_clarify>"
}"""
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def build_user_message(obs_dict: Dict[str, Any]) -> str:
    """Render an observation dict as the user-turn prompt for the LLM.

    Args:
        obs_dict: Serialized Observation — ticket metadata, conversation
            ``history`` (list of ``{role, text, turn}`` dicts), retrieved
            ``kb_results`` and the boolean progress flags.

    Returns:
        A plain-text prompt describing the current ticket state and asking
        the model for its next action.
    """
    # Build the transcript and KB listing with join() instead of repeated
    # string += in a loop (avoids quadratic behaviour on long histories).
    history_text = "".join(
        f" [{msg.get('role', '').upper()}]: {msg.get('text', '')}\n"
        for msg in obs_dict.get("history", [])
    )
    kb_text = "".join(
        f" - {article}\n" for article in obs_dict.get("kb_results", [])
    )

    return f"""Current ticket state:
Ticket ID : {obs_dict.get('ticket_id')}
Category : {obs_dict.get('category')}
Priority : {obs_dict.get('priority')}
Sentiment : {obs_dict.get('sentiment')}
Turn : {obs_dict.get('turn')} / {obs_dict.get('max_turns')}
Cumulative reward: {obs_dict.get('cumulative_reward')}

Conversation history:
{history_text or ' (no messages yet)'}

KB articles retrieved:
{kb_text or ' (none — call search_kb to retrieve)'}

KB searched: {obs_dict.get('kb_searched')}
Empathized : {obs_dict.get('empathized')}
Clarified : {obs_dict.get('clarified')}
Solution offered: {obs_dict.get('solution_offered')}

What is your next action?"""
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def call_llm(messages: List[Dict]) -> Dict[str, str]:
    """Call the LLM and parse its reply into an action dict.

    Args:
        messages: Chat messages in OpenAI format (system + alternating
            user/assistant turns).

    Returns:
        A dict with ``action_type`` and optional ``payload``. If the model's
        reply is not valid JSON (or not a JSON object), falls back to a safe
        ``search_kb`` action instead of raising, so one malformed reply does
        not kill the episode.

    Raises:
        Whatever the OpenAI client raises on transport/API errors — the
        caller (run_task) handles those.
    """
    import re

    response = client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        temperature=0.2,
        max_tokens=512,
        response_format={"type": "json_object"},
    )
    # content may be None (e.g. refusals / filtered output); guard before
    # .strip() to avoid an AttributeError.
    raw = (response.choices[0].message.content or "").strip()

    fallback: Dict[str, Any] = {"action_type": "search_kb", "payload": None}

    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        # Fallback: extract the first {...} span in case the model wrapped
        # the JSON in prose or markdown fences.
        m = re.search(r'\{.*\}', raw, re.DOTALL)
        if m is None:
            return fallback
        try:
            parsed = json.loads(m.group())
        except json.JSONDecodeError:
            # The extracted span was still not valid JSON.
            return fallback

    # Callers do parsed.get(...); a JSON array/scalar would crash them.
    return parsed if isinstance(parsed, dict) else fallback
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def run_task(task_id: str) -> Dict[str, Any]:
    """Run the LLM agent on one task episode and return summary results.

    Drives the observe -> prompt -> act loop until the environment reports
    ``done`` (or a step fails), then grades the final observation.

    Args:
        task_id: Key into TASKS (e.g. "task_1").

    Returns:
        Dict with task_id, difficulty, grader_score, passed, steps and
        cumulative_reward.

    Side effects:
        Prints one START line, one STEP line per action and one END line of
        JSON to stdout (the structured log format consumed downstream).
    """
    # Fixed seed keeps episodes reproducible across runs.
    env = CustomerSupportEnv(task_id=task_id, seed=42)
    obs = env.reset()
    # NOTE(review): .dict() is the pydantic v1 API; requirements pin
    # pydantic 2.x where .model_dump() is preferred -- confirm env/models.py.
    obs_dict = obs.dict()

    messages = [{"role": "system", "content": SYSTEM_PROMPT}]

    print(json.dumps({
        "event": "START",
        "task_id": task_id,
        "ticket_id": obs_dict["ticket_id"],
        "difficulty": TASKS[task_id].difficulty,
        "model": MODEL_NAME,
    }), flush=True)

    # Collected per-step rewards (currently only the cumulative value from
    # the observation is reported; kept for future per-step analysis).
    episode_rewards = []
    step_num = 0

    while not obs_dict.get("done", False):
        step_num += 1
        user_msg = build_user_message(obs_dict)
        messages.append({"role": "user", "content": user_msg})

        # LLM inference; on any client/API failure, fall back to `resolve`
        # so the episode terminates instead of hanging.
        try:
            action_dict = call_llm(messages)
        except Exception as e:
            print(f"[LLM ERROR] {e}", flush=True)
            action_dict = {"action_type": "resolve", "payload": None}

        action_type = action_dict.get("action_type", "resolve")
        payload = action_dict.get("payload")

        # Validate action: unknown types degrade to the harmless search_kb.
        if action_type not in VALID_ACTIONS:
            action_type = "search_kb"

        action = Action(action_type=action_type, payload=payload)

        try:
            result = env.step(action)
        except RuntimeError as e:
            # Environment refused the step (e.g. episode already closed);
            # stop the loop and grade whatever state we reached.
            print(f"[ENV ERROR] {e}", flush=True)
            break

        obs_dict = result.observation.dict()
        reward_dict = result.reward.dict()
        episode_rewards.append(reward_dict["total"])

        # Append assistant response to message history so the model sees
        # its own past actions on later turns.
        messages.append({
            "role": "assistant",
            "content": json.dumps(action_dict)
        })

        print(json.dumps({
            "event": "STEP",
            "task_id": task_id,
            "step": step_num,
            "action_type": action_type,
            "reward": reward_dict["total"],
            "cumulative_reward": obs_dict["cumulative_reward"],
            "done": obs_dict["done"],
            "reason": reward_dict.get("reason", ""),
        }), flush=True)

        if obs_dict.get("done"):
            break

    # Grade the episode on the final (post-loop) observation.
    final_obs = env.state()
    grader_result = grade(task_id, final_obs)

    print(json.dumps({
        "event": "END",
        "task_id": task_id,
        "difficulty": TASKS[task_id].difficulty,
        "total_steps": step_num,
        "cumulative_reward": obs_dict.get("cumulative_reward", 0),
        "grader_score": grader_result.score,
        "grader_passed": grader_result.passed,
        "grader_breakdown": grader_result.breakdown,
        "grader_reason": grader_result.reason,
        "final_status": obs_dict.get("status"),
    }), flush=True)

    return {
        "task_id": task_id,
        "difficulty": TASKS[task_id].difficulty,
        "grader_score": grader_result.score,
        "passed": grader_result.passed,
        "steps": step_num,
        "cumulative_reward": obs_dict.get("cumulative_reward", 0),
    }
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def main(task_ids=("task_1", "task_2", "task_3")):
    """Run the baseline agent over *task_ids* and print a JSON summary.

    Args:
        task_ids: Iterable of task identifiers to run, in order. The default
            is the three built-in tasks, preserving the original behaviour
            for existing callers.

    Side effects:
        Runs each episode (which logs START/STEP/END lines) and emits one
        final SUMMARY JSON line to stdout.
    """
    all_results = []

    for task_id in task_ids:
        result = run_task(task_id)
        all_results.append(result)
        time.sleep(1)  # Crude client-side pacing to avoid API rate limits.

    # Summary. Guard the empty-task-list case so the average is defined.
    avg_score = (
        sum(r["grader_score"] for r in all_results) / len(all_results)
        if all_results
        else 0.0
    )
    print(json.dumps({
        "event": "SUMMARY",
        "model": MODEL_NAME,
        "results": all_results,
        "average_grader_score": round(avg_score, 3),
        "tasks_passed": sum(1 for r in all_results if r["passed"]),
        "total_tasks": len(all_results),
    }), flush=True)


if __name__ == "__main__":
    main()
|
openenv.yaml
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: CustomerSupportEnv
|
| 2 |
+
version: "1.0.0"
|
| 3 |
+
description: >
|
| 4 |
+
A real-world customer support reinforcement learning environment where an AI agent
|
| 5 |
+
handles inbound support tickets. The agent must search a knowledge base, empathise
|
| 6 |
+
with customers, offer concrete solutions, and resolve tickets efficiently.
|
| 7 |
+
Models the genuine complexity of Tier-1 customer support: multi-turn conversation,
|
| 8 |
+
retrieval-augmented reasoning, and satisfaction optimisation.
|
| 9 |
+
|
| 10 |
+
author: OpenEnv Submission
|
| 11 |
+
domain: customer-support
|
| 12 |
+
tags: [openenv, customer-support, nlp, retrieval, multi-turn, real-world]
|
| 13 |
+
|
| 14 |
+
tasks:
|
| 15 |
+
- id: task_1
|
| 16 |
+
name: "Resolve a Standard Auth Ticket"
|
| 17 |
+
difficulty: easy
|
| 18 |
+
ticket: TKT-001
|
| 19 |
+
max_turns: 8
|
| 20 |
+
description: >
|
| 21 |
+
Handle a frustrated customer locked out of their account.
|
| 22 |
+
Optimal policy: search_kb β empathize β offer_solution β resolve.
|
| 23 |
+
|
| 24 |
+
- id: task_2
|
| 25 |
+
name: "Handle a Multi-Step Billing Dispute"
|
| 26 |
+
difficulty: medium
|
| 27 |
+
ticket: TKT-003
|
| 28 |
+
max_turns: 10
|
| 29 |
+
description: >
|
| 30 |
+
Resolve a billing discrepancy. Requires clarification before diagnosis.
|
| 31 |
+
Generic solutions are penalised; agent must cite a specific credit amount.
|
| 32 |
+
|
| 33 |
+
- id: task_3
|
| 34 |
+
name: "Triage a Critical Time-Sensitive Bug"
|
| 35 |
+
difficulty: hard
|
| 36 |
+
ticket: TKT-006
|
| 37 |
+
max_turns: 8
|
| 38 |
+
description: >
|
| 39 |
+
Enterprise customer with a compliance deadline. Data export stuck for 6 hours.
|
| 40 |
+
Two-part solution required (priority queue + partial export).
|
| 41 |
+
Escalation is penalised. Tests urgency awareness and multi-step planning.
|
| 42 |
+
|
| 43 |
+
observation_space:
|
| 44 |
+
type: object
|
| 45 |
+
fields:
|
| 46 |
+
ticket_id: {type: string, nullable: true}
|
| 47 |
+
task_id: {type: string}
|
| 48 |
+
status: {type: string, enum: [idle, open, resolved, escalated, timeout]}
|
| 49 |
+
sentiment: {type: string, enum: [positive, neutral, frustrated, angry], nullable: true}
|
| 50 |
+
priority: {type: string, enum: [low, medium, high, urgent], nullable: true}
|
| 51 |
+
category: {type: string, enum: [auth, billing, fulfillment, bug, sales, general], nullable: true}
|
| 52 |
+
turn: {type: integer, minimum: 0}
|
| 53 |
+
max_turns: {type: integer}
|
| 54 |
+
history: {type: array, items: {role: string, text: string, turn: integer}}
|
| 55 |
+
kb_results: {type: array, items: {type: string}}
|
| 56 |
+
kb_searched: {type: boolean}
|
| 57 |
+
empathized: {type: boolean}
|
| 58 |
+
clarified: {type: boolean}
|
| 59 |
+
solution_offered: {type: boolean}
|
| 60 |
+
escalated: {type: boolean}
|
| 61 |
+
cumulative_reward: {type: number}
|
| 62 |
+
done: {type: boolean}
|
| 63 |
+
|
| 64 |
+
action_space:
|
| 65 |
+
type: object
|
| 66 |
+
fields:
|
| 67 |
+
action_type:
|
| 68 |
+
type: string
|
| 69 |
+
enum: [search_kb, empathize, ask_clarify, offer_solution, escalate, resolve, send_message]
|
| 70 |
+
payload:
|
| 71 |
+
type: string
|
| 72 |
+
nullable: true
|
| 73 |
+
description: >
|
| 74 |
+
Required for offer_solution (solution text), ask_clarify (question),
|
| 75 |
+
and send_message (message body). Optional for others.
|
| 76 |
+
|
| 77 |
+
reward_function:
|
| 78 |
+
type: shaped
|
| 79 |
+
components:
|
| 80 |
+
search_kb: "+2.0 (first call only; -1.0 duplicate)"
|
| 81 |
+
empathize: "+1.0 (first call only)"
|
| 82 |
+
ask_clarify: "+1.0 (first call only)"
|
| 83 |
+
offer_solution: "+3.0 Γ quality_score (0β1); -1.0 if KB not searched first"
|
| 84 |
+
escalate: "-1.0"
|
| 85 |
+
resolve_good: "+5.0 + csat Γ 2.0 (when solution offered)"
|
| 86 |
+
resolve_bad: "-3.0 (when no solution offered)"
|
| 87 |
+
timeout: "-2.0"
|
| 88 |
+
csat_components:
|
| 89 |
+
empathized: 0.30
|
| 90 |
+
kb_searched: 0.30
|
| 91 |
+
solution_offered: 0.40
|
| 92 |
+
|
| 93 |
+
graders:
|
| 94 |
+
scoring: 0.0_to_1.0
|
| 95 |
+
deterministic: true
|
| 96 |
+
task_1_weights:
|
| 97 |
+
kb_searched: 0.30
|
| 98 |
+
empathized: 0.25
|
| 99 |
+
solution_quality: 0.25
|
| 100 |
+
resolved: 0.20
|
| 101 |
+
task_2_weights:
|
| 102 |
+
ask_clarify: 0.20
|
| 103 |
+
kb_searched: 0.20
|
| 104 |
+
solution_quality: 0.30
|
| 105 |
+
empathized: 0.15
|
| 106 |
+
resolved: 0.15
|
| 107 |
+
task_3_weights:
|
| 108 |
+
kb_searched: 0.20
|
| 109 |
+
empathized: 0.15
|
| 110 |
+
solution_quality: 0.35
|
| 111 |
+
no_escalation: 0.15
|
| 112 |
+
resolved: 0.15
|
| 113 |
+
|
| 114 |
+
endpoints:
|
| 115 |
+
reset: "POST /reset"
|
| 116 |
+
step: "POST /step"
|
| 117 |
+
state: "GET /state"
|
| 118 |
+
tasks: "GET /tasks"
|
| 119 |
+
grade: "POST /grade"
|
| 120 |
+
health: "GET /health"
|
| 121 |
+
spec: "GET /openenv.yaml"
|
| 122 |
+
|
| 123 |
+
baseline_scores:
|
| 124 |
+
task_1: 0.85
|
| 125 |
+
task_2: 0.78
|
| 126 |
+
task_3: 0.65
|
| 127 |
+
average: 0.76
|
| 128 |
+
model: gpt-4o-mini
|
| 129 |
+
|
| 130 |
+
huggingface:
|
| 131 |
+
space_sdk: docker
|
| 132 |
+
port: 7860
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi==0.111.0
|
| 2 |
+
uvicorn[standard]==0.29.0
|
| 3 |
+
pydantic==2.7.1
|
| 4 |
+
pyyaml==6.0.1
|
| 5 |
+
openai==1.30.1
|
| 6 |
+
httpx==0.27.0
|
server.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CustomerSupportEnv β FastAPI server.
|
| 3 |
+
|
| 4 |
+
Endpoints:
|
| 5 |
+
POST /reset β Observation
|
| 6 |
+
POST /step β StepResult
|
| 7 |
+
GET /state β Observation
|
| 8 |
+
GET /tasks β list of task specs
|
| 9 |
+
POST /grade β GraderResult
|
| 10 |
+
GET /health β 200 OK
|
| 11 |
+
GET /openenv.yaml β spec file
|
| 12 |
+
"""
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import sys
|
| 17 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 18 |
+
|
| 19 |
+
from typing import Optional
|
| 20 |
+
from fastapi import FastAPI, HTTPException
|
| 21 |
+
from fastapi.responses import FileResponse, JSONResponse
|
| 22 |
+
from pydantic import BaseModel
|
| 23 |
+
|
| 24 |
+
from env.environment import CustomerSupportEnv, TASKS
|
| 25 |
+
from env.models import Action, Observation, StepResult, GraderResult
|
| 26 |
+
from graders.graders import grade
|
| 27 |
+
|
| 28 |
+
# FastAPI application instance; served by uvicorn on port 7860 (see Dockerfile/__main__).
app = FastAPI(
    title="CustomerSupportEnv",
    description="OpenEnv-compatible RL environment for customer support agent training.",
    version="1.0.0",
)

# One env instance per task (keyed by task_id)
# NOTE(review): instances persist for the process lifetime and are shared across
# requests — presumably single-worker deployment is assumed; confirm for scaling.
_envs: dict[str, CustomerSupportEnv] = {}
|
| 38 |
+
def _get_env(task_id: str) -> CustomerSupportEnv:
    """Return the cached environment for *task_id*, creating it on first use.

    Raises:
        HTTPException: 404 when *task_id* is not one of the known TASKS.
    """
    if task_id not in TASKS:
        raise HTTPException(status_code=404, detail=f"Unknown task_id: {task_id}")
    env = _envs.get(task_id)
    if env is None:
        env = CustomerSupportEnv(task_id=task_id)
        _envs[task_id] = env
    return env
|
| 46 |
+
class ResetRequest(BaseModel):
    """Request body for POST /reset."""

    # Task to (re)initialize; defaults to the first task.
    task_id: str = "task_1"
|
| 50 |
+
class StepRequest(BaseModel):
    """Request body for POST /step."""

    # Task whose environment the action is applied to.
    task_id: str = "task_1"
    # One of the action types defined in openenv.yaml (e.g. search_kb, empathize).
    action_type: str
    # Free-text payload; required by some action types (solution text, question, message body).
    payload: Optional[str] = None
|
| 56 |
+
class GradeRequest(BaseModel):
    """Request body for POST /grade."""

    # Task whose current episode state should be graded.
    task_id: str
|
| 60 |
+
@app.get("/health")
|
| 61 |
+
def health():
|
| 62 |
+
return {"status": "ok", "version": CustomerSupportEnv.VERSION}
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@app.post("/reset", response_model=Observation)
|
| 66 |
+
def reset(req: ResetRequest):
|
| 67 |
+
env = _get_env(req.task_id)
|
| 68 |
+
obs = env.reset()
|
| 69 |
+
return obs
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@app.post("/step", response_model=StepResult)
|
| 73 |
+
def step(req: StepRequest):
|
| 74 |
+
env = _get_env(req.task_id)
|
| 75 |
+
try:
|
| 76 |
+
action = Action(action_type=req.action_type, payload=req.payload)
|
| 77 |
+
result = env.step(action)
|
| 78 |
+
return result
|
| 79 |
+
except RuntimeError as e:
|
| 80 |
+
raise HTTPException(status_code=400, detail=str(e))
|
| 81 |
+
except Exception as e:
|
| 82 |
+
raise HTTPException(status_code=422, detail=str(e))
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
@app.get("/state", response_model=Observation)
|
| 86 |
+
def state(task_id: str = "task_1"):
|
| 87 |
+
env = _get_env(task_id)
|
| 88 |
+
return env.state()
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@app.get("/tasks")
|
| 92 |
+
def list_tasks():
|
| 93 |
+
return {tid: spec.dict() for tid, spec in TASKS.items()}
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
@app.post("/grade", response_model=GraderResult)
|
| 97 |
+
def grade_endpoint(req: GradeRequest):
|
| 98 |
+
env = _get_env(req.task_id)
|
| 99 |
+
obs = env.state()
|
| 100 |
+
result = grade(req.task_id, obs)
|
| 101 |
+
return result
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
@app.get("/openenv.yaml")
|
| 105 |
+
def get_yaml():
|
| 106 |
+
yaml_path = os.path.join(os.path.dirname(__file__), "openenv.yaml")
|
| 107 |
+
if os.path.exists(yaml_path):
|
| 108 |
+
return FileResponse(yaml_path, media_type="text/yaml")
|
| 109 |
+
return JSONResponse({"error": "openenv.yaml not found"}, status_code=404)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
if __name__ == "__main__":
|
| 113 |
+
import uvicorn
|
| 114 |
+
uvicorn.run("server:app", host="0.0.0.0", port=7860, reload=False)
|