Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- Dockerfile +81 -0
- README.md +250 -5
- __init__.py +16 -0
- client.py +99 -0
- inference.py +168 -0
- models.py +33 -0
- openenv.yaml +7 -0
- pyproject.toml +37 -0
- server/__init__.py +11 -0
- server/app.py +91 -0
- server/data_wrangler_environment.py +178 -0
- server/requirements.txt +8 -0
- uv.lock +0 -0
Dockerfile
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Multi-stage build using openenv-base
|
| 8 |
+
# This Dockerfile is flexible and works for both:
|
| 9 |
+
# - In-repo environments (with local OpenEnv sources)
|
| 10 |
+
# - Standalone environments (with openenv from PyPI/Git)
|
| 11 |
+
# The build script (openenv build) handles context detection and sets appropriate build args.
|
| 12 |
+
|
| 13 |
+
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
|
| 14 |
+
FROM ${BASE_IMAGE} AS builder
|
| 15 |
+
|
| 16 |
+
WORKDIR /app
|
| 17 |
+
|
| 18 |
+
# Ensure git is available (required for installing dependencies from VCS)
|
| 19 |
+
RUN apt-get update && \
|
| 20 |
+
apt-get install -y --no-install-recommends git && \
|
| 21 |
+
rm -rf /var/lib/apt/lists/*
|
| 22 |
+
|
| 23 |
+
# Build argument to control whether we're building standalone or in-repo
|
| 24 |
+
ARG BUILD_MODE=in-repo
|
| 25 |
+
ARG ENV_NAME=data_wrangler
|
| 26 |
+
|
| 27 |
+
# Copy environment code (always at root of build context)
|
| 28 |
+
COPY . /app/env
|
| 29 |
+
|
| 30 |
+
# For in-repo builds, openenv is already vendored in the build context
|
| 31 |
+
# For standalone builds, openenv will be installed via pyproject.toml
|
| 32 |
+
WORKDIR /app/env
|
| 33 |
+
|
| 34 |
+
# Ensure uv is available (for local builds where base image lacks it)
|
| 35 |
+
RUN if ! command -v uv >/dev/null 2>&1; then \
|
| 36 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
| 37 |
+
mv /root/.local/bin/uv /usr/local/bin/uv && \
|
| 38 |
+
mv /root/.local/bin/uvx /usr/local/bin/uvx; \
|
| 39 |
+
fi
|
| 40 |
+
|
| 41 |
+
# Install dependencies using uv sync
|
| 42 |
+
# If uv.lock exists, use it; otherwise resolve on the fly
|
| 43 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 44 |
+
if [ -f uv.lock ]; then \
|
| 45 |
+
uv sync --frozen --no-install-project --no-editable; \
|
| 46 |
+
else \
|
| 47 |
+
uv sync --no-install-project --no-editable; \
|
| 48 |
+
fi
|
| 49 |
+
|
| 50 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 51 |
+
if [ -f uv.lock ]; then \
|
| 52 |
+
uv sync --frozen --no-editable; \
|
| 53 |
+
else \
|
| 54 |
+
uv sync --no-editable; \
|
| 55 |
+
fi
|
| 56 |
+
|
| 57 |
+
# Final runtime stage
|
| 58 |
+
FROM ${BASE_IMAGE}
|
| 59 |
+
|
| 60 |
+
WORKDIR /app
|
| 61 |
+
|
| 62 |
+
# Copy the virtual environment from builder
|
| 63 |
+
COPY --from=builder /app/env/.venv /app/.venv
|
| 64 |
+
|
| 65 |
+
# Copy the environment code
|
| 66 |
+
COPY --from=builder /app/env /app/env
|
| 67 |
+
|
| 68 |
+
# Set PATH to use the virtual environment
|
| 69 |
+
ENV PATH="/app/.venv/bin:$PATH"
|
| 70 |
+
|
| 71 |
+
# Set PYTHONPATH so imports work correctly
|
| 72 |
+
ENV PYTHONPATH="/app/env:$PYTHONPATH"
|
| 73 |
+
|
| 74 |
+
# Health check
|
| 75 |
+
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
| 76 |
+
CMD curl -f http://localhost:8000/health || exit 1
|
| 77 |
+
|
| 78 |
+
# Run the FastAPI server
|
| 79 |
+
# The module path is constructed to work with the /app/env structure
|
| 80 |
+
ENV ENABLE_WEB_INTERFACE=true
|
| 81 |
+
CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
|
README.md
CHANGED
|
@@ -1,10 +1,255 @@
|
|
| 1 |
---
|
| 2 |
-
title: Data Wrangler
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Data Wrangler Environment Server
|
| 3 |
+
emoji: 🎹
|
| 4 |
+
colorFrom: green
|
| 5 |
+
colorTo: yellow
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
+
app_port: 8000
|
| 9 |
+
base_path: /web
|
| 10 |
+
tags:
|
| 11 |
+
- openenv
|
| 12 |
---
|
| 13 |
|
| 14 |
+
# Data Wrangler Environment
|
| 15 |
+
|
| 16 |
+
A simple test environment that echoes back messages. Perfect for testing the env APIs as well as demonstrating environment usage patterns.
|
| 17 |
+
|
| 18 |
+
## Quick Start
|
| 19 |
+
|
| 20 |
+
The simplest way to use the Data Wrangler environment is through the `DataWranglerEnv` class:
|
| 21 |
+
|
| 22 |
+
```python
|
| 23 |
+
from data_wrangler import DataWranglerAction, DataWranglerEnv
|
| 24 |
+
|
| 25 |
+
try:
|
| 26 |
+
# Create environment from Docker image
|
| 27 |
+
data_wranglerenv = DataWranglerEnv.from_docker_image("data_wrangler-env:latest")
|
| 28 |
+
|
| 29 |
+
# Reset
|
| 30 |
+
result = data_wranglerenv.reset()
|
| 31 |
+
print(f"Reset: {result.observation.echoed_message}")
|
| 32 |
+
|
| 33 |
+
# Send multiple messages
|
| 34 |
+
messages = ["Hello, World!", "Testing echo", "Final message"]
|
| 35 |
+
|
| 36 |
+
for msg in messages:
|
| 37 |
+
result = data_wranglerenv.step(DataWranglerAction(message=msg))
|
| 38 |
+
print(f"Sent: '{msg}'")
|
| 39 |
+
print(f" → Echoed: '{result.observation.echoed_message}'")
|
| 40 |
+
print(f" → Length: {result.observation.message_length}")
|
| 41 |
+
print(f" → Reward: {result.reward}")
|
| 42 |
+
|
| 43 |
+
finally:
|
| 44 |
+
# Always clean up
|
| 45 |
+
data_wranglerenv.close()
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
That's it! The `DataWranglerEnv.from_docker_image()` method handles:
|
| 49 |
+
- Starting the Docker container
|
| 50 |
+
- Waiting for the server to be ready
|
| 51 |
+
- Connecting to the environment
|
| 52 |
+
- Container cleanup when you call `close()`
|
| 53 |
+
|
| 54 |
+
## Building the Docker Image
|
| 55 |
+
|
| 56 |
+
Before using the environment, you need to build the Docker image:
|
| 57 |
+
|
| 58 |
+
```bash
|
| 59 |
+
# From project root
|
| 60 |
+
docker build -t data_wrangler-env:latest -f server/Dockerfile .
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
## Deploying to Hugging Face Spaces
|
| 64 |
+
|
| 65 |
+
You can easily deploy your OpenEnv environment to Hugging Face Spaces using the `openenv push` command:
|
| 66 |
+
|
| 67 |
+
```bash
|
| 68 |
+
# From the environment directory (where openenv.yaml is located)
|
| 69 |
+
openenv push
|
| 70 |
+
|
| 71 |
+
# Or specify options
|
| 72 |
+
openenv push --namespace my-org --private
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
The `openenv push` command will:
|
| 76 |
+
1. Validate that the directory is an OpenEnv environment (checks for `openenv.yaml`)
|
| 77 |
+
2. Prepare a custom build for Hugging Face Docker space (enables web interface)
|
| 78 |
+
3. Upload to Hugging Face (ensuring you're logged in)
|
| 79 |
+
|
| 80 |
+
### Prerequisites
|
| 81 |
+
|
| 82 |
+
- Authenticate with Hugging Face: The command will prompt for login if not already authenticated
|
| 83 |
+
|
| 84 |
+
### Options
|
| 85 |
+
|
| 86 |
+
- `--directory`, `-d`: Directory containing the OpenEnv environment (defaults to current directory)
|
| 87 |
+
- `--repo-id`, `-r`: Repository ID in format 'username/repo-name' (defaults to 'username/env-name' from openenv.yaml)
|
| 88 |
+
- `--base-image`, `-b`: Base Docker image to use (overrides Dockerfile FROM)
|
| 89 |
+
- `--private`: Deploy the space as private (default: public)
|
| 90 |
+
|
| 91 |
+
### Examples
|
| 92 |
+
|
| 93 |
+
```bash
|
| 94 |
+
# Push to your personal namespace (defaults to username/env-name from openenv.yaml)
|
| 95 |
+
openenv push
|
| 96 |
+
|
| 97 |
+
# Push to a specific repository
|
| 98 |
+
openenv push --repo-id my-org/my-env
|
| 99 |
+
|
| 100 |
+
# Push with a custom base image
|
| 101 |
+
openenv push --base-image ghcr.io/meta-pytorch/openenv-base:latest
|
| 102 |
+
|
| 103 |
+
# Push as a private space
|
| 104 |
+
openenv push --private
|
| 105 |
+
|
| 106 |
+
# Combine options
|
| 107 |
+
openenv push --repo-id my-org/my-env --base-image custom-base:latest --private
|
| 108 |
+
```
|
| 109 |
+
|
| 110 |
+
After deployment, your space will be available at:
|
| 111 |
+
`https://huggingface.co/spaces/<repo-id>`
|
| 112 |
+
|
| 113 |
+
The deployed space includes:
|
| 114 |
+
- **Web Interface** at `/web` - Interactive UI for exploring the environment
|
| 115 |
+
- **API Documentation** at `/docs` - Full OpenAPI/Swagger interface
|
| 116 |
+
- **Health Check** at `/health` - Container health monitoring
|
| 117 |
+
- **WebSocket** at `/ws` - Persistent session endpoint for low-latency interactions
|
| 118 |
+
|
| 119 |
+
## Environment Details
|
| 120 |
+
|
| 121 |
+
### Action
|
| 122 |
+
**DataWranglerAction**: Contains a single field
|
| 123 |
+
- `message` (str) - The message to echo back
|
| 124 |
+
|
| 125 |
+
### Observation
|
| 126 |
+
**DataWranglerObservation**: Contains the echo response and metadata
|
| 127 |
+
- `echoed_message` (str) - The message echoed back
|
| 128 |
+
- `message_length` (int) - Length of the message
|
| 129 |
+
- `reward` (float) - Reward based on message length (length × 0.1)
|
| 130 |
+
- `done` (bool) - Always False for echo environment
|
| 131 |
+
- `metadata` (dict) - Additional info like step count
|
| 132 |
+
|
| 133 |
+
### Reward
|
| 134 |
+
The reward is calculated as: `message_length × 0.1`
|
| 135 |
+
- "Hi" → reward: 0.2
|
| 136 |
+
- "Hello, World!" → reward: 1.3
|
| 137 |
+
- Empty message → reward: 0.0
|
| 138 |
+
|
| 139 |
+
## Advanced Usage
|
| 140 |
+
|
| 141 |
+
### Connecting to an Existing Server
|
| 142 |
+
|
| 143 |
+
If you already have a Data Wrangler environment server running, you can connect directly:
|
| 144 |
+
|
| 145 |
+
```python
|
| 146 |
+
from data_wrangler import DataWranglerEnv
|
| 147 |
+
|
| 148 |
+
# Connect to existing server
|
| 149 |
+
data_wranglerenv = DataWranglerEnv(base_url="<ENV_HTTP_URL_HERE>")
|
| 150 |
+
|
| 151 |
+
# Use as normal
|
| 152 |
+
result = data_wranglerenv.reset()
|
| 153 |
+
result = data_wranglerenv.step(DataWranglerAction(message="Hello!"))
|
| 154 |
+
```
|
| 155 |
+
|
| 156 |
+
Note: When connecting to an existing server, `data_wranglerenv.close()` will NOT stop the server.
|
| 157 |
+
|
| 158 |
+
### Using the Context Manager
|
| 159 |
+
|
| 160 |
+
The client supports context manager usage for automatic connection management:
|
| 161 |
+
|
| 162 |
+
```python
|
| 163 |
+
from data_wrangler import DataWranglerAction, DataWranglerEnv
|
| 164 |
+
|
| 165 |
+
# Connect with context manager (auto-connects and closes)
|
| 166 |
+
with DataWranglerEnv(base_url="http://localhost:8000") as env:
|
| 167 |
+
result = env.reset()
|
| 168 |
+
print(f"Reset: {result.observation.echoed_message}")
|
| 169 |
+
# Multiple steps with low latency
|
| 170 |
+
for msg in ["Hello", "World", "!"]:
|
| 171 |
+
result = env.step(DataWranglerAction(message=msg))
|
| 172 |
+
print(f"Echoed: {result.observation.echoed_message}")
|
| 173 |
+
```
|
| 174 |
+
|
| 175 |
+
The client uses WebSocket connections for:
|
| 176 |
+
- **Lower latency**: No HTTP connection overhead per request
|
| 177 |
+
- **Persistent session**: Server maintains your environment state
|
| 178 |
+
- **Efficient for episodes**: Better for many sequential steps
|
| 179 |
+
|
| 180 |
+
### Concurrent WebSocket Sessions
|
| 181 |
+
|
| 182 |
+
The server supports multiple concurrent WebSocket connections. To enable this,
|
| 183 |
+
modify `server/app.py` to use factory mode:
|
| 184 |
+
|
| 185 |
+
```python
|
| 186 |
+
# In server/app.py - use factory mode for concurrent sessions
|
| 187 |
+
app = create_app(
|
| 188 |
+
DataWranglerEnvironment, # Pass class, not instance
|
| 189 |
+
DataWranglerAction,
|
| 190 |
+
DataWranglerObservation,
|
| 191 |
+
max_concurrent_envs=4, # Allow 4 concurrent sessions
|
| 192 |
+
)
|
| 193 |
+
```
|
| 194 |
+
|
| 195 |
+
Then multiple clients can connect simultaneously:
|
| 196 |
+
|
| 197 |
+
```python
|
| 198 |
+
from data_wrangler import DataWranglerAction, DataWranglerEnv
|
| 199 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 200 |
+
|
| 201 |
+
def run_episode(client_id: int):
|
| 202 |
+
with DataWranglerEnv(base_url="http://localhost:8000") as env:
|
| 203 |
+
result = env.reset()
|
| 204 |
+
for i in range(10):
|
| 205 |
+
result = env.step(DataWranglerAction(message=f"Client {client_id}, step {i}"))
|
| 206 |
+
return client_id, result.observation.message_length
|
| 207 |
+
|
| 208 |
+
# Run 4 episodes concurrently
|
| 209 |
+
with ThreadPoolExecutor(max_workers=4) as executor:
|
| 210 |
+
results = list(executor.map(run_episode, range(4)))
|
| 211 |
+
```
|
| 212 |
+
|
| 213 |
+
## Development & Testing
|
| 214 |
+
|
| 215 |
+
### Direct Environment Testing
|
| 216 |
+
|
| 217 |
+
Test the environment logic directly without starting the HTTP server:
|
| 218 |
+
|
| 219 |
+
```bash
|
| 220 |
+
# From the server directory
|
| 221 |
+
python3 server/data_wrangler_environment.py
|
| 222 |
+
```
|
| 223 |
+
|
| 224 |
+
This verifies that:
|
| 225 |
+
- Environment resets correctly
|
| 226 |
+
- Step executes actions properly
|
| 227 |
+
- State tracking works
|
| 228 |
+
- Rewards are calculated correctly
|
| 229 |
+
|
| 230 |
+
### Running Locally
|
| 231 |
+
|
| 232 |
+
Run the server locally for development:
|
| 233 |
+
|
| 234 |
+
```bash
|
| 235 |
+
uvicorn server.app:app --reload
|
| 236 |
+
```
|
| 237 |
+
|
| 238 |
+
## Project Structure
|
| 239 |
+
|
| 240 |
+
```
|
| 241 |
+
data_wrangler/
|
| 242 |
+
├── .dockerignore # Docker build exclusions
|
| 243 |
+
├── __init__.py # Module exports
|
| 244 |
+
├── README.md # This file
|
| 245 |
+
├── openenv.yaml # OpenEnv manifest
|
| 246 |
+
├── pyproject.toml # Project metadata and dependencies
|
| 247 |
+
├── uv.lock # Locked dependencies (generated)
|
| 248 |
+
├── client.py # DataWranglerEnv client
|
| 249 |
+
├── models.py # Action and Observation models
|
| 250 |
+
└── server/
|
| 251 |
+
├── __init__.py # Server module exports
|
| 252 |
+
├── data_wrangler_environment.py # Core environment logic
|
| 253 |
+
├── app.py # FastAPI application (HTTP + WebSocket endpoints)
|
| 254 |
+
└── Dockerfile # Container image definition
|
| 255 |
+
```
|
__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Data Wrangler Environment."""
|
| 8 |
+
|
| 9 |
+
from .client import DataWranglerEnv
|
| 10 |
+
from .models import DataWranglerAction, DataWranglerObservation
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
"DataWranglerAction",
|
| 14 |
+
"DataWranglerObservation",
|
| 15 |
+
"DataWranglerEnv",
|
| 16 |
+
]
|
client.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Data Wrangler Environment Client."""
|
| 8 |
+
|
| 9 |
+
from typing import Dict
|
| 10 |
+
|
| 11 |
+
from openenv.core import EnvClient
|
| 12 |
+
from openenv.core.client_types import StepResult
|
| 13 |
+
from openenv.core.env_server.types import State
|
| 14 |
+
|
| 15 |
+
from .models import DataWranglerAction, DataWranglerObservation
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class DataWranglerEnv(
|
| 19 |
+
EnvClient[DataWranglerAction, DataWranglerObservation, State]
|
| 20 |
+
):
|
| 21 |
+
"""
|
| 22 |
+
Client for the Data Wrangler Environment.
|
| 23 |
+
|
| 24 |
+
This client maintains a persistent WebSocket connection to the environment server,
|
| 25 |
+
enabling efficient multi-step interactions with lower latency.
|
| 26 |
+
Each client instance has its own dedicated environment session on the server.
|
| 27 |
+
|
| 28 |
+
Example:
|
| 29 |
+
>>> # Connect to a running server
|
| 30 |
+
>>> with DataWranglerEnv(base_url="http://localhost:8000") as client:
|
| 31 |
+
... result = client.reset()
|
| 32 |
+
... print(result.observation.echoed_message)
|
| 33 |
+
...
|
| 34 |
+
... result = client.step(DataWranglerAction(message="Hello!"))
|
| 35 |
+
... print(result.observation.echoed_message)
|
| 36 |
+
|
| 37 |
+
Example with Docker:
|
| 38 |
+
>>> # Automatically start container and connect
|
| 39 |
+
>>> client = DataWranglerEnv.from_docker_image("data_wrangler-env:latest")
|
| 40 |
+
>>> try:
|
| 41 |
+
... result = client.reset()
|
| 42 |
+
... result = client.step(DataWranglerAction(message="Test"))
|
| 43 |
+
... finally:
|
| 44 |
+
... client.close()
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
def _step_payload(self, action: DataWranglerAction) -> Dict:
|
| 48 |
+
"""
|
| 49 |
+
Convert DataWranglerAction to JSON payload for step message.
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
action: DataWranglerAction instance
|
| 53 |
+
|
| 54 |
+
Returns:
|
| 55 |
+
Dictionary representation suitable for JSON encoding
|
| 56 |
+
"""
|
| 57 |
+
return {
|
| 58 |
+
"message": action.message,
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
def _parse_result(self, payload: Dict) -> StepResult[DataWranglerObservation]:
|
| 62 |
+
"""
|
| 63 |
+
Parse server response into StepResult[DataWranglerObservation].
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
payload: JSON response data from server
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
StepResult with DataWranglerObservation
|
| 70 |
+
"""
|
| 71 |
+
obs_data = payload.get("observation", {})
|
| 72 |
+
observation = DataWranglerObservation(
|
| 73 |
+
echoed_message=obs_data.get("echoed_message", ""),
|
| 74 |
+
message_length=obs_data.get("message_length", 0),
|
| 75 |
+
done=payload.get("done", False),
|
| 76 |
+
reward=payload.get("reward"),
|
| 77 |
+
metadata=obs_data.get("metadata", {}),
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
return StepResult(
|
| 81 |
+
observation=observation,
|
| 82 |
+
reward=payload.get("reward"),
|
| 83 |
+
done=payload.get("done", False),
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
def _parse_state(self, payload: Dict) -> State:
|
| 87 |
+
"""
|
| 88 |
+
Parse server response into State object.
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
payload: JSON response from state request
|
| 92 |
+
|
| 93 |
+
Returns:
|
| 94 |
+
State object with episode_id and step_count
|
| 95 |
+
"""
|
| 96 |
+
return State(
|
| 97 |
+
episode_id=payload.get("episode_id"),
|
| 98 |
+
step_count=payload.get("step_count", 0),
|
| 99 |
+
)
|
inference.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import asyncio
|
| 4 |
+
from openai import AsyncOpenAI
|
| 5 |
+
|
| 6 |
+
# OpenEnv V5 specific client components
|
| 7 |
+
# We import directly since OpenEnv varies slightly in versions, but this mirrors the validator script expectations.
|
| 8 |
+
from openenv.core.client import EnvClient
|
| 9 |
+
|
| 10 |
+
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
|
| 11 |
+
API_KEY = os.environ.get("OPENAI_API_KEY", "")
|
| 12 |
+
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-3.5-turbo")
|
| 13 |
+
IMAGE_NAME = "data_wrangler"
|
| 14 |
+
TASK_NAME = "Data Writer Level 1"
|
| 15 |
+
BENCHMARK = "data_wrangler"
|
| 16 |
+
MAX_STEPS = 10
|
| 17 |
+
MAX_TOTAL_REWARD = 1.0
|
| 18 |
+
SUCCESS_SCORE_THRESHOLD = 0.5
|
| 19 |
+
|
| 20 |
+
system_prompt = """\
|
| 21 |
+
SYSTEM INSTRUCTIONS: ELITE DATA ENGINEER AGENT
|
| 22 |
+
|
| 23 |
+
ROLE AND PERSONA
|
| 24 |
+
You are an elite Data Engineering AI Agent operating within an automated data-wrangling pipeline. Your core function is to autonomously clean, format, and standardize messy, real-world datasets until they perfectly match a hidden "ground truth" target. You operate systematically, analytically, and with absolute precision.
|
| 25 |
+
|
| 26 |
+
MISSION OBJECTIVE
|
| 27 |
+
At each step, you will receive an Observation of the current data state. You must analyze the data anomalies (missing values, bad schemas, incorrect data types) and issue exactly ONE valid operation from your Action Space. You will iterate on this process until the dataset is perfectly clean, at which point you will issue the submit action.
|
| 28 |
+
|
| 29 |
+
THE OBSERVATION
|
| 30 |
+
You will receive a state dictionary detailing the dataset's current form:
|
| 31 |
+
columns: Current list of headers.
|
| 32 |
+
row_count: Total number of rows in the dataset.
|
| 33 |
+
column_stats: Dictionary mapping column names to {dtype, missing_count, sample_values}.
|
| 34 |
+
last_action_feedback: Status/error message resulting from your previous action.
|
| 35 |
+
is_done: Boolean termination flag.
|
| 36 |
+
|
| 37 |
+
ACTION SPACE (AVAILABLE TOOLS)
|
| 38 |
+
You have a strict, highly constrained toolset. Your chosen action MUST be a valid JSON object matching exactly ONE of the schemas:
|
| 39 |
+
1. Drop Column: {"action_type": "drop_column", "target_column": "..."}
|
| 40 |
+
2. Rename Column: {"action_type": "rename_column", "target_column": "...", "new_name": "..."}
|
| 41 |
+
3. Fill Missing Values: {"action_type": "fill_missing", "target_column": "...", "fill_value": "..."}
|
| 42 |
+
4. Cast Data Type: {"action_type": "cast_type", "target_column": "...", "cast_to": "..."}
|
| 43 |
+
5. Submit: {"action_type": "submit"}
|
| 44 |
+
|
| 45 |
+
REQUIRED OUTPUT FORMAT (CHAIN OF THOUGHT)
|
| 46 |
+
<thinking>
|
| 47 |
+
Analyze Observation: What is the current state? What did the last action do?
|
| 48 |
+
Identify Anomalies: Which columns have wrong types, bad names, or missing data?
|
| 49 |
+
Formulate Plan: What is the highest priority fix right now?
|
| 50 |
+
Select Action: Which action type and parameters will execute this fix?
|
| 51 |
+
</thinking>
|
| 52 |
+
{
|
| 53 |
+
"action_type": "...",
|
| 54 |
+
...
|
| 55 |
+
}
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
async def get_model_message(client, step, obs_dict, last_reward, history):
|
| 59 |
+
obs_text = str(obs_dict)
|
| 60 |
+
prompt = f"Step {step}.\nObservation: {obs_text}\nLast Reward: {last_reward}\nHistory: {history}\nChoose your next action (JSON matching schema)."
|
| 61 |
+
try:
|
| 62 |
+
response = await client.chat.completions.create(
|
| 63 |
+
model=MODEL_NAME,
|
| 64 |
+
messages=[
|
| 65 |
+
{"role": "system", "content": system_prompt},
|
| 66 |
+
{"role": "user", "content": prompt}
|
| 67 |
+
],
|
| 68 |
+
temperature=0.0
|
| 69 |
+
)
|
| 70 |
+
content = response.choices[0].message.content
|
| 71 |
+
import json
|
| 72 |
+
import re
|
| 73 |
+
# Basic parsing of the JSON structure that follows the thinking tags
|
| 74 |
+
match = re.search(r'(\{.*\})', content, re.DOTALL)
|
| 75 |
+
if match:
|
| 76 |
+
return json.loads(match.group(1))
|
| 77 |
+
# Fallback if unparseable
|
| 78 |
+
return {"action_type": "submit"}
|
| 79 |
+
except Exception as e:
|
| 80 |
+
return {"action_type": "submit"}
|
| 81 |
+
|
| 82 |
+
def log_start(task, env, model):
|
| 83 |
+
print(f"[START] task={task} env={env} model={model}")
|
| 84 |
+
|
| 85 |
+
def log_step(step, action, reward, done, error):
|
| 86 |
+
print(f"[STEP] step={step} action={action} reward={reward} done={done} error={error}")
|
| 87 |
+
|
| 88 |
+
def log_end(success, steps, score, rewards):
|
| 89 |
+
print(f"[END] success={success} steps={steps} score={score} rewards={rewards}")
|
| 90 |
+
|
| 91 |
+
async def main():
|
| 92 |
+
if not API_KEY:
|
| 93 |
+
print("Missing OPENAI_API_KEY environment variable.")
|
| 94 |
+
return
|
| 95 |
+
|
| 96 |
+
client = AsyncOpenAI(base_url=API_BASE_URL, api_key=API_KEY)
|
| 97 |
+
|
| 98 |
+
# Needs EnvClient or appropriate environment factory setup depending on OpenEnv validator logic
|
| 99 |
+
# Following generic OpenEnv V4/V5 inference loop
|
| 100 |
+
|
| 101 |
+
try:
|
| 102 |
+
from server.data_wrangler_environment import DataWranglerEnvironment
|
| 103 |
+
env = DataWranglerEnvironment() # Using local environment logic directly to mock tests for now
|
| 104 |
+
except BaseException as e:
|
| 105 |
+
print("Could not load local environment for test.", e)
|
| 106 |
+
return
|
| 107 |
+
|
| 108 |
+
history = []
|
| 109 |
+
rewards = []
|
| 110 |
+
steps_taken = 0
|
| 111 |
+
score = 0.0
|
| 112 |
+
success = False
|
| 113 |
+
|
| 114 |
+
log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
|
| 115 |
+
|
| 116 |
+
try:
|
| 117 |
+
result = env.reset()
|
| 118 |
+
obs_dict = {
|
| 119 |
+
"columns": result.columns,
|
| 120 |
+
"row_count": result.row_count,
|
| 121 |
+
"column_stats": result.column_stats,
|
| 122 |
+
"last_action_feedback": result.last_action_feedback,
|
| 123 |
+
"is_done": result.is_done
|
| 124 |
+
}
|
| 125 |
+
last_reward = result.reward
|
| 126 |
+
|
| 127 |
+
for step in range(1, MAX_STEPS + 1):
|
| 128 |
+
if result.is_done:
|
| 129 |
+
break
|
| 130 |
+
|
| 131 |
+
action_data = await get_model_message(client, step, obs_dict, last_reward, history)
|
| 132 |
+
|
| 133 |
+
from models import DataWranglerAction
|
| 134 |
+
action_obj = DataWranglerAction(**action_data)
|
| 135 |
+
result = env.step(action_obj)
|
| 136 |
+
|
| 137 |
+
obs_dict = {
|
| 138 |
+
"columns": result.columns,
|
| 139 |
+
"row_count": result.row_count,
|
| 140 |
+
"column_stats": result.column_stats,
|
| 141 |
+
"last_action_feedback": result.last_action_feedback,
|
| 142 |
+
"is_done": result.is_done
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
reward = result.reward or 0.0
|
| 146 |
+
done = result.done or result.is_done
|
| 147 |
+
error = None
|
| 148 |
+
|
| 149 |
+
rewards.append(reward)
|
| 150 |
+
steps_taken = step
|
| 151 |
+
last_reward = reward
|
| 152 |
+
|
| 153 |
+
log_step(step=step, action=action_data, reward=reward, done=done, error=error)
|
| 154 |
+
|
| 155 |
+
history.append(f"Step {step}: {action_data} -> reward {reward:+.2f}")
|
| 156 |
+
|
| 157 |
+
if done:
|
| 158 |
+
break
|
| 159 |
+
|
| 160 |
+
score = sum(rewards) / MAX_TOTAL_REWARD if MAX_TOTAL_REWARD > 0 else 0.0
|
| 161 |
+
score = min(max(score, 0.0), 1.0)
|
| 162 |
+
success = score >= SUCCESS_SCORE_THRESHOLD
|
| 163 |
+
|
| 164 |
+
finally:
|
| 165 |
+
log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
|
| 166 |
+
|
| 167 |
+
if __name__ == "__main__":
|
| 168 |
+
asyncio.run(main())
|
models.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Data models for the Data Wrangler Environment.
|
| 9 |
+
|
| 10 |
+
The data_wrangler environment is a simple test environment that echoes back messages.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from typing import Dict, List, Optional, Any
|
| 14 |
+
from openenv.core.env_server.types import Action, Observation
|
| 15 |
+
from pydantic import Field
|
| 16 |
+
|
| 17 |
+
class DataWranglerAction(Action):
|
| 18 |
+
"""Action for the Data Wrangler environment."""
|
| 19 |
+
action_type: str = Field(..., description="Type of action: drop_column, rename_column, fill_missing, cast_type, submit")
|
| 20 |
+
|
| 21 |
+
# Specifics depending on action_type
|
| 22 |
+
target_column: Optional[str] = Field(None, description="The name of the column to act upon.")
|
| 23 |
+
new_name: Optional[str] = Field(None, description="New name of the column (for rename_column).")
|
| 24 |
+
fill_value: Optional[str] = Field(None, description="Value to fill missing data with (for fill_missing).")
|
| 25 |
+
cast_to: Optional[str] = Field(None, description="Target data type (for cast_type, e.g. 'int', 'float', 'datetime', 'string').")
|
| 26 |
+
|
| 27 |
+
class DataWranglerObservation(Observation):
|
| 28 |
+
"""Observation representing the state of the dataset."""
|
| 29 |
+
columns: List[str] = Field(default_factory=list, description="Current list of headers.")
|
| 30 |
+
row_count: int = Field(default=0, description="Total number of rows in the dataset.")
|
| 31 |
+
column_stats: Dict[str, Dict[str, Any]] = Field(default_factory=dict, description="Stats per column: dtype, missing_count, sample_values.")
|
| 32 |
+
last_action_feedback: str = Field(default="Environment initialized.", description="Feedback from the last executed action.")
|
| 33 |
+
is_done: bool = Field(default=False, description="Whether the task has terminated.")
|
openenv.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: data_wrangler
|
| 3 |
+
type: space
|
| 4 |
+
runtime: fastapi
|
| 5 |
+
app: server.app:app
|
| 6 |
+
port: 8000
|
| 7 |
+
|
pyproject.toml
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
[build-system]
|
| 8 |
+
requires = ["setuptools>=45", "wheel"]
|
| 9 |
+
build-backend = "setuptools.build_meta"
|
| 10 |
+
|
| 11 |
+
[project]
|
| 12 |
+
name = "openenv-data_wrangler"
|
| 13 |
+
version = "0.1.0"
|
| 14 |
+
description = "Data Wrangler environment for OpenEnv"
|
| 15 |
+
requires-python = ">=3.10"
|
| 16 |
+
dependencies = [
|
| 17 |
+
# Core OpenEnv runtime (provides FastAPI server + HTTP client types)
|
| 18 |
+
"openenv-core[core]>=0.2.2",
|
| 19 |
+
"pandas>=2.0.0",
|
| 20 |
+
"openai>=1.0.0",
|
| 21 |
+
]
|
| 22 |
+
|
| 23 |
+
[project.optional-dependencies]
|
| 24 |
+
dev = [
|
| 25 |
+
"pytest>=8.0.0",
|
| 26 |
+
"pytest-cov>=4.0.0",
|
| 27 |
+
]
|
| 28 |
+
|
| 29 |
+
[project.scripts]
|
| 30 |
+
# Server entry point - enables running via: uv run --project . server
|
| 31 |
+
# or: python -m data_wrangler.server.app
|
| 32 |
+
server = "data_wrangler.server.app:main"
|
| 33 |
+
|
| 34 |
+
[tool.setuptools]
|
| 35 |
+
include-package-data = true
|
| 36 |
+
packages = ["data_wrangler", "data_wrangler.server"]
|
| 37 |
+
package-dir = { "data_wrangler" = ".", "data_wrangler.server" = "server" }
|
server/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Data Wrangler environment server components."""
|
| 8 |
+
|
| 9 |
+
from .data_wrangler_environment import DataWranglerEnvironment
|
| 10 |
+
|
| 11 |
+
__all__ = ["DataWranglerEnvironment"]
|
server/app.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
FastAPI application for the Data Wrangler Environment.
|
| 9 |
+
|
| 10 |
+
This module creates an HTTP server that exposes the DataWranglerEnvironment
|
| 11 |
+
over HTTP and WebSocket endpoints, compatible with EnvClient.
|
| 12 |
+
|
| 13 |
+
Endpoints:
|
| 14 |
+
- POST /reset: Reset the environment
|
| 15 |
+
- POST /step: Execute an action
|
| 16 |
+
- GET /state: Get current environment state
|
| 17 |
+
- GET /schema: Get action/observation schemas
|
| 18 |
+
- WS /ws: WebSocket endpoint for persistent sessions
|
| 19 |
+
|
| 20 |
+
Usage:
|
| 21 |
+
# Development (with auto-reload):
|
| 22 |
+
uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
|
| 23 |
+
|
| 24 |
+
# Production:
|
| 25 |
+
uvicorn server.app:app --host 0.0.0.0 --port 8000 --workers 4
|
| 26 |
+
|
| 27 |
+
# Or run directly:
|
| 28 |
+
python -m server.app
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
try:
|
| 32 |
+
from openenv.core.env_server.http_server import create_app
|
| 33 |
+
except Exception as e: # pragma: no cover
|
| 34 |
+
raise ImportError(
|
| 35 |
+
"openenv is required for the web interface. Install dependencies with '\n uv sync\n'"
|
| 36 |
+
) from e
|
| 37 |
+
|
| 38 |
+
try:
|
| 39 |
+
from ..models import DataWranglerAction, DataWranglerObservation
|
| 40 |
+
from .data_wrangler_environment import DataWranglerEnvironment
|
| 41 |
+
except (ImportError, ValueError, ModuleNotFoundError):
|
| 42 |
+
import sys
|
| 43 |
+
import os
|
| 44 |
+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
| 45 |
+
from models import DataWranglerAction, DataWranglerObservation
|
| 46 |
+
from server.data_wrangler_environment import DataWranglerEnvironment
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# Create the app with web interface and README integration
|
| 50 |
+
app = create_app(
|
| 51 |
+
DataWranglerEnvironment,
|
| 52 |
+
DataWranglerAction,
|
| 53 |
+
DataWranglerObservation,
|
| 54 |
+
env_name="data_wrangler",
|
| 55 |
+
max_concurrent_envs=1, # increase this number to allow more concurrent WebSocket sessions
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def main(host: str = "0.0.0.0", port: int = 8000):
|
| 60 |
+
"""
|
| 61 |
+
Entry point for direct execution via uv run or python -m.
|
| 62 |
+
|
| 63 |
+
This function enables running the server without Docker:
|
| 64 |
+
uv run --project . server
|
| 65 |
+
uv run --project . server --port 8001
|
| 66 |
+
python -m data_wrangler.server.app
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
host: Host address to bind to (default: "0.0.0.0")
|
| 70 |
+
port: Port number to listen on (default: 8000)
|
| 71 |
+
|
| 72 |
+
For production deployments, consider using uvicorn directly with
|
| 73 |
+
multiple workers:
|
| 74 |
+
uvicorn data_wrangler.server.app:app --workers 4
|
| 75 |
+
"""
|
| 76 |
+
import uvicorn
|
| 77 |
+
|
| 78 |
+
uvicorn.run(app, host=host, port=port)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
if __name__ == '__main__':
|
| 82 |
+
# ensures openenv validate passes main() check
|
| 83 |
+
import argparse
|
| 84 |
+
|
| 85 |
+
parser = argparse.ArgumentParser()
|
| 86 |
+
parser.add_argument("--port", type=int, default=8000)
|
| 87 |
+
args = parser.parse_args()
|
| 88 |
+
if not args.port or args.port == 8000:
|
| 89 |
+
main()
|
| 90 |
+
else:
|
| 91 |
+
main(port=args.port)
|
server/data_wrangler_environment.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pandas as pd
|
| 3 |
+
from uuid import uuid4
|
| 4 |
+
from openenv.core.env_server.interfaces import Environment
|
| 5 |
+
from openenv.core.env_server.types import State
|
| 6 |
+
|
| 7 |
+
try:
|
| 8 |
+
from ..models import DataWranglerAction, DataWranglerObservation
|
| 9 |
+
except (ImportError, ValueError, ModuleNotFoundError):
|
| 10 |
+
import sys
|
| 11 |
+
import os
|
| 12 |
+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
| 13 |
+
from models import DataWranglerAction, DataWranglerObservation
|
| 14 |
+
|
| 15 |
+
class DataWranglerEnvironment(Environment):
|
| 16 |
+
SUPPORTS_CONCURRENT_SESSIONS: bool = True
|
| 17 |
+
|
| 18 |
+
def __init__(self):
|
| 19 |
+
self._state = State(episode_id=str(uuid4()), step_count=0)
|
| 20 |
+
self._reset_count = 0
|
| 21 |
+
self.df = None
|
| 22 |
+
self.target_df = None
|
| 23 |
+
self.task_level = int(os.environ.get("TASK_LEVEL", "1"))
|
| 24 |
+
self._initialize_task()
|
| 25 |
+
|
| 26 |
+
def _initialize_task(self):
|
| 27 |
+
self.df = pd.DataFrame()
|
| 28 |
+
self.target_df = pd.DataFrame()
|
| 29 |
+
if self.task_level == 1:
|
| 30 |
+
# Easy: Just drop a column and rename one
|
| 31 |
+
self.df = pd.DataFrame({
|
| 32 |
+
"User Name": ["Alice", "Bob", "Charlie"],
|
| 33 |
+
"Unnamed: 0": [0, 1, 2],
|
| 34 |
+
"Age": [25, 30, 35]
|
| 35 |
+
})
|
| 36 |
+
self.target_df = pd.DataFrame({
|
| 37 |
+
"user_name": ["Alice", "Bob", "Charlie"],
|
| 38 |
+
"age": [25, 30, 35]
|
| 39 |
+
})
|
| 40 |
+
elif self.task_level == 2:
|
| 41 |
+
# Medium: fill missing and cast type
|
| 42 |
+
self.df = pd.DataFrame({
|
| 43 |
+
"product_ID ": ["101", "102", "103"],
|
| 44 |
+
"price": ["10.5", None, "12.0"],
|
| 45 |
+
"bad_col": [None, None, None]
|
| 46 |
+
})
|
| 47 |
+
self.target_df = pd.DataFrame({
|
| 48 |
+
"product_id": [101.0, 102.0, 103.0],
|
| 49 |
+
"price": [10.5, 0.0, 12.0]
|
| 50 |
+
})
|
| 51 |
+
else:
|
| 52 |
+
# Hard: Multiple issues
|
| 53 |
+
self.df = pd.DataFrame({
|
| 54 |
+
"date_joined ": ["2020-01-01", "2021-05-15", None],
|
| 55 |
+
"Sales_total": ["100", "200", "300"],
|
| 56 |
+
"IsActive": [True, False, None],
|
| 57 |
+
"DROPME_1": [1,2,3]
|
| 58 |
+
})
|
| 59 |
+
self.target_df = pd.DataFrame({
|
| 60 |
+
"date_joined": [pd.Timestamp("2020-01-01"), pd.Timestamp("2021-05-15"), pd.Timestamp("1970-01-01")],
|
| 61 |
+
"sales_total": [100.0, 200.0, 300.0],
|
| 62 |
+
"is_active": [True, False, False]
|
| 63 |
+
})
|
| 64 |
+
|
| 65 |
+
def _get_obs(self, feedback: str = "Environment initialized.", done: bool = False, reward: float = 0.0) -> DataWranglerObservation:
|
| 66 |
+
stats = {}
|
| 67 |
+
for col in self.df.columns:
|
| 68 |
+
stats[col] = {
|
| 69 |
+
"dtype": str(self.df[col].dtype),
|
| 70 |
+
"missing_count": int(self.df[col].isna().sum()),
|
| 71 |
+
"sample_values": self.df[col].dropna().astype(str).tolist()[:3]
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
return DataWranglerObservation(
|
| 75 |
+
columns=list(self.df.columns),
|
| 76 |
+
row_count=len(self.df),
|
| 77 |
+
column_stats=stats,
|
| 78 |
+
last_action_feedback=feedback,
|
| 79 |
+
is_done=done,
|
| 80 |
+
reward=reward,
|
| 81 |
+
done=done,
|
| 82 |
+
metadata={"step": self._state.step_count}
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
def reset(self) -> DataWranglerObservation:
|
| 86 |
+
self._state = State(episode_id=str(uuid4()), step_count=0)
|
| 87 |
+
self._reset_count += 1
|
| 88 |
+
self._initialize_task()
|
| 89 |
+
return self._get_obs()
|
| 90 |
+
|
| 91 |
+
def step(self, action: DataWranglerAction) -> DataWranglerObservation: # type: ignore
|
| 92 |
+
self._state.step_count += 1
|
| 93 |
+
feedback = "Action executed successfully."
|
| 94 |
+
reward = 0.0
|
| 95 |
+
done = False
|
| 96 |
+
|
| 97 |
+
try:
|
| 98 |
+
if action.action_type == "drop_column":
|
| 99 |
+
col = action.target_column
|
| 100 |
+
if col in self.df.columns:
|
| 101 |
+
self.df.drop(columns=[col], inplace=True)
|
| 102 |
+
if col not in self.target_df.columns:
|
| 103 |
+
reward = 0.2
|
| 104 |
+
else:
|
| 105 |
+
reward = -0.5
|
| 106 |
+
feedback = f"Warning: dropped targeting column {col}"
|
| 107 |
+
else:
|
| 108 |
+
feedback = f"Error: Column '{col}' not found."
|
| 109 |
+
|
| 110 |
+
elif action.action_type == "rename_column":
|
| 111 |
+
col = action.target_column
|
| 112 |
+
new_col = action.new_name
|
| 113 |
+
if col in self.df.columns:
|
| 114 |
+
self.df.rename(columns={col: new_col}, inplace=True)
|
| 115 |
+
if new_col in self.target_df.columns:
|
| 116 |
+
reward = 0.2
|
| 117 |
+
else:
|
| 118 |
+
feedback = f"Error: Column '{col}' not found."
|
| 119 |
+
|
| 120 |
+
elif action.action_type == "fill_missing":
|
| 121 |
+
col = action.target_column
|
| 122 |
+
if col in self.df.columns:
|
| 123 |
+
self.df[col].fillna(action.fill_value, inplace=True)
|
| 124 |
+
reward = 0.1
|
| 125 |
+
else:
|
| 126 |
+
feedback = f"Error: Column '{col}' not found."
|
| 127 |
+
|
| 128 |
+
elif action.action_type == "cast_type":
|
| 129 |
+
col = action.target_column
|
| 130 |
+
to_type = action.cast_to
|
| 131 |
+
if col in self.df.columns:
|
| 132 |
+
if to_type == 'int':
|
| 133 |
+
self.df = self.df.astype({col: int})
|
| 134 |
+
elif to_type == 'float':
|
| 135 |
+
self.df = self.df.astype({col: float})
|
| 136 |
+
elif to_type == 'datetime':
|
| 137 |
+
self.df[col] = pd.to_datetime(self.df[col])
|
| 138 |
+
elif to_type == 'string':
|
| 139 |
+
self.df = self.df.astype({col: str})
|
| 140 |
+
reward = 0.2
|
| 141 |
+
else:
|
| 142 |
+
feedback = f"Error: Column '{col}' not found."
|
| 143 |
+
|
| 144 |
+
elif action.action_type == "submit":
|
| 145 |
+
score = self._grade()
|
| 146 |
+
reward = score
|
| 147 |
+
feedback = f"Submitted. Final Score: {score}"
|
| 148 |
+
done = True
|
| 149 |
+
else:
|
| 150 |
+
feedback = f"Error: Unknown action type {action.action_type}"
|
| 151 |
+
|
| 152 |
+
except Exception as e:
|
| 153 |
+
feedback = f"Exception occurred: {str(e)}"
|
| 154 |
+
reward = -0.1
|
| 155 |
+
|
| 156 |
+
return self._get_obs(feedback=feedback, done=done, reward=reward)
|
| 157 |
+
|
| 158 |
+
def _grade(self) -> float:
|
| 159 |
+
score = 0.0
|
| 160 |
+
if list(self.df.columns) == list(self.target_df.columns):
|
| 161 |
+
score += 0.5
|
| 162 |
+
# Match types and values
|
| 163 |
+
value_matches = 0
|
| 164 |
+
for col in self.df.columns:
|
| 165 |
+
try:
|
| 166 |
+
# simple match check
|
| 167 |
+
match = (self.df[col] == self.target_df[col]).all()
|
| 168 |
+
if match:
|
| 169 |
+
value_matches += 1
|
| 170 |
+
except:
|
| 171 |
+
pass
|
| 172 |
+
score += 0.5 * (value_matches / max(len(self.target_df.columns), 1))
|
| 173 |
+
|
| 174 |
+
return score
|
| 175 |
+
|
| 176 |
+
@property
|
| 177 |
+
def state(self) -> State:
|
| 178 |
+
return self._state
|
server/requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv[core]>=0.2.0
|
| 2 |
+
fastapi>=0.115.0
|
| 3 |
+
uvicorn>=0.24.0
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
pandas>=2.0.0
|
| 8 |
+
openai>=1.0.0
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|