Spaces:
Running
Running
Zach Wentz committed on
Commit ·
088d017
1
Parent(s): 378e8a1
🤖 Deploy sumo_rl_env environment - 2025-10-22 10:17:10
Browse files — This view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +0 -35
- Dockerfile +73 -0
- README.md +51 -5
- src/core/__init__.py +19 -0
- src/core/__pycache__/__init__.cpython-311.pyc +0 -0
- src/core/__pycache__/__init__.cpython-313.pyc +0 -0
- src/core/__pycache__/http_env_client.cpython-311.pyc +0 -0
- src/core/__pycache__/http_env_client.cpython-313.pyc +0 -0
- src/core/__pycache__/types.cpython-311.pyc +0 -0
- src/core/__pycache__/types.cpython-313.pyc +0 -0
- src/core/containers/__init__.py +7 -0
- src/core/containers/__pycache__/__init__.cpython-311.pyc +0 -0
- src/core/containers/__pycache__/__init__.cpython-313.pyc +0 -0
- src/core/containers/images/Dockerfile +46 -0
- src/core/containers/images/README.md +92 -0
- src/core/containers/runtime/__init__.py +15 -0
- src/core/containers/runtime/__pycache__/__init__.cpython-311.pyc +0 -0
- src/core/containers/runtime/__pycache__/__init__.cpython-313.pyc +0 -0
- src/core/containers/runtime/__pycache__/providers.cpython-311.pyc +0 -0
- src/core/containers/runtime/__pycache__/providers.cpython-313.pyc +0 -0
- src/core/containers/runtime/providers.py +289 -0
- src/core/containers/test_local_docker_provider.py +258 -0
- src/core/env_server/__init__.py +35 -0
- src/core/env_server/__pycache__/__init__.cpython-311.pyc +0 -0
- src/core/env_server/__pycache__/__init__.cpython-313.pyc +0 -0
- src/core/env_server/__pycache__/base_transforms.cpython-311.pyc +0 -0
- src/core/env_server/__pycache__/base_transforms.cpython-313.pyc +0 -0
- src/core/env_server/__pycache__/http_server.cpython-311.pyc +0 -0
- src/core/env_server/__pycache__/http_server.cpython-313.pyc +0 -0
- src/core/env_server/__pycache__/interfaces.cpython-311.pyc +0 -0
- src/core/env_server/__pycache__/interfaces.cpython-313.pyc +0 -0
- src/core/env_server/__pycache__/types.cpython-311.pyc +0 -0
- src/core/env_server/__pycache__/types.cpython-313.pyc +0 -0
- src/core/env_server/__pycache__/web_interface.cpython-311.pyc +0 -0
- src/core/env_server/__pycache__/web_interface.cpython-313.pyc +0 -0
- src/core/env_server/base_transforms.py +29 -0
- src/core/env_server/http_server.py +233 -0
- src/core/env_server/interfaces.py +118 -0
- src/core/env_server/types.py +57 -0
- src/core/env_server/web_interface.py +1613 -0
- src/core/http_env_client.py +175 -0
- src/core/tools/__init__.py +11 -0
- src/core/tools/local_python_executor.py +105 -0
- src/core/types.py +22 -0
- src/envs/sumo_rl_env/README.md +341 -0
- src/envs/sumo_rl_env/__init__.py +31 -0
- src/envs/sumo_rl_env/client.py +145 -0
- src/envs/sumo_rl_env/models.py +110 -0
- src/envs/sumo_rl_env/nets/single-intersection/single-intersection.edg.xml +6 -0
- src/envs/sumo_rl_env/nets/single-intersection/single-intersection.net.xml +86 -0
.gitattributes
DELETED
|
@@ -1,35 +0,0 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Dockerfile
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Multi-stage build: First stage builds the base image
|
| 8 |
+
FROM python:3.11-slim as base-builder
|
| 9 |
+
|
| 10 |
+
# Install system dependencies
|
| 11 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 12 |
+
curl \
|
| 13 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 14 |
+
|
| 15 |
+
# Install Python dependencies that all environments need
|
| 16 |
+
RUN pip install --no-cache-dir \
|
| 17 |
+
fastapi>=0.104.0 \
|
| 18 |
+
"uvicorn[standard]>=0.24.0" \
|
| 19 |
+
requests>=2.25.0 \
|
| 20 |
+
wsproto>=1.0.0
|
| 21 |
+
|
| 22 |
+
# Set working directory
|
| 23 |
+
WORKDIR /app
|
| 24 |
+
|
| 25 |
+
# Default environment variables
|
| 26 |
+
ENV PYTHONPATH=/app/src
|
| 27 |
+
ENV PYTHONUNBUFFERED=1
|
| 28 |
+
|
| 29 |
+
# Second stage: Use the built base image and add environment-specific dependencies
|
| 30 |
+
FROM base-builder
|
| 31 |
+
|
| 32 |
+
# Install SUMO system dependencies
|
| 33 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 34 |
+
sumo \
|
| 35 |
+
sumo-tools \
|
| 36 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 37 |
+
|
| 38 |
+
# Set SUMO_HOME environment variable
|
| 39 |
+
ENV SUMO_HOME=/usr/share/sumo
|
| 40 |
+
|
| 41 |
+
# Install SUMO-RL and Python dependencies
|
| 42 |
+
RUN pip install --no-cache-dir \
|
| 43 |
+
gymnasium>=0.28 \
|
| 44 |
+
pettingzoo>=1.24.3 \
|
| 45 |
+
numpy>=1.24.0 \
|
| 46 |
+
pandas>=2.0.0 \
|
| 47 |
+
sumolib>=1.14.0 \
|
| 48 |
+
traci>=1.14.0 \
|
| 49 |
+
sumo-rl>=1.4.5
|
| 50 |
+
|
| 51 |
+
# SUMO environment variables (can be overridden at runtime)
|
| 52 |
+
ENV SUMO_NET_FILE=/app/nets/single-intersection/single-intersection.net.xml
|
| 53 |
+
ENV SUMO_ROUTE_FILE=/app/nets/single-intersection/single-intersection.rou.xml
|
| 54 |
+
ENV SUMO_NUM_SECONDS=20000
|
| 55 |
+
ENV SUMO_DELTA_TIME=5
|
| 56 |
+
ENV SUMO_YELLOW_TIME=2
|
| 57 |
+
ENV SUMO_MIN_GREEN=5
|
| 58 |
+
ENV SUMO_MAX_GREEN=50
|
| 59 |
+
ENV SUMO_REWARD_FN=diff-waiting-time
|
| 60 |
+
ENV SUMO_SEED=42
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# Copy only what's needed for this environment
|
| 64 |
+
COPY src/core/ /app/src/core/
|
| 65 |
+
COPY src/envs/sumo_rl_env/ /app/src/envs/sumo_rl_env/
|
| 66 |
+
|
| 67 |
+
# Health check
|
| 68 |
+
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
| 69 |
+
CMD curl -f http://localhost:8000/health || exit 1
|
| 70 |
+
|
| 71 |
+
# Run the FastAPI server
|
| 72 |
+
CMD ["uvicorn", "envs.sumo_rl_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
|
| 73 |
+
ENV ENABLE_WEB_INTERFACE=true
|
README.md
CHANGED
|
@@ -1,10 +1,56 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
|
|
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Sumo_rl_env Environment Server
|
| 3 |
+
emoji: 🐳
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
+
app_port: 8000
|
| 9 |
+
base_path: /web
|
| 10 |
---
|
| 11 |
|
| 12 |
+
# Sumo_rl_env Environment Server
|
| 13 |
+
|
| 14 |
+
FastAPI server for the sumo_rl_env environment, powered by Meta's OpenEnv.
|
| 15 |
+
|
| 16 |
+
## About
|
| 17 |
+
|
| 18 |
+
This Space provides a containerized environment for sumo_rl_env interactions.
|
| 19 |
+
Built with FastAPI and the OpenEnv framework.
|
| 20 |
+
|
| 21 |
+
## Web Interface
|
| 22 |
+
|
| 23 |
+
This deployment includes an interactive web interface for exploring the environment:
|
| 24 |
+
- **HumanAgent Interface**: Interact with the environment using a web form
|
| 25 |
+
- **State Observer**: Real-time view of environment state and action history
|
| 26 |
+
- **Live Updates**: WebSocket-based real-time updates
|
| 27 |
+
|
| 28 |
+
Access the web interface at: `/web`
|
| 29 |
+
|
| 30 |
+
## SUMO Environment
|
| 31 |
+
|
| 32 |
+
Provides traffic signal control via SUMO (Simulation of Urban MObility) for reinforcement learning.
|
| 33 |
+
|
| 34 |
+
### Usage
|
| 35 |
+
Send a POST request to `/step` with:
|
| 36 |
+
```json
|
| 37 |
+
{
|
| 38 |
+
"action": {
|
| 39 |
+
"phase_id": 1
|
| 40 |
+
}
|
| 41 |
+
}
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
### Features
|
| 45 |
+
- **Traffic Simulation**: Realistic traffic flow via SUMO
|
| 46 |
+
- **Signal Control**: Optimize traffic light timing
|
| 47 |
+
- **Multiple Networks**: Support for custom traffic networks
|
| 48 |
+
- **Configurable Rewards**: Waiting time, queue length, pressure metrics
|
| 49 |
+
|
| 50 |
+
## API Documentation
|
| 51 |
+
|
| 52 |
+
Visit `/docs` for interactive API documentation.
|
| 53 |
+
|
| 54 |
+
## Health Check
|
| 55 |
+
|
| 56 |
+
The environment provides a health check endpoint at `/health`.
|
src/core/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Core components for agentic environments."""
|
| 8 |
+
|
| 9 |
+
# Re-export main components from submodules for convenience
|
| 10 |
+
from .env_server import *
|
| 11 |
+
from .http_env_client import HTTPEnvClient
|
| 12 |
+
from .types import StepResult
|
| 13 |
+
|
| 14 |
+
# Note: MCP module doesn't export anything yet
|
| 15 |
+
|
| 16 |
+
__all__ = [
|
| 17 |
+
"HTTPEnvClient",
|
| 18 |
+
"StepResult",
|
| 19 |
+
]
|
src/core/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (400 Bytes). View file
|
|
|
src/core/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (383 Bytes). View file
|
|
|
src/core/__pycache__/http_env_client.cpython-311.pyc
ADDED
|
Binary file (7.68 kB). View file
|
|
|
src/core/__pycache__/http_env_client.cpython-313.pyc
ADDED
|
Binary file (6.93 kB). View file
|
|
|
src/core/__pycache__/types.cpython-311.pyc
ADDED
|
Binary file (1.09 kB). View file
|
|
|
src/core/__pycache__/types.cpython-313.pyc
ADDED
|
Binary file (993 Bytes). View file
|
|
|
src/core/containers/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Container management for environment servers."""
|
src/core/containers/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (206 Bytes). View file
|
|
|
src/core/containers/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (224 Bytes). View file
|
|
|
src/core/containers/images/Dockerfile
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
#
|
| 8 |
+
# OpenEnv Base Image
|
| 9 |
+
#
|
| 10 |
+
# This is the standard base image for all OpenEnv environment servers.
|
| 11 |
+
# It includes the minimal dependencies needed to run HTTP environment servers.
|
| 12 |
+
#
|
| 13 |
+
# Build: docker build -t openenv-base:latest -f src/core/containers/images/Dockerfile .
|
| 14 |
+
# Tag: docker tag openenv-base:latest openenv-base:0.1.0
|
| 15 |
+
#
|
| 16 |
+
|
| 17 |
+
FROM python:3.11-slim
|
| 18 |
+
|
| 19 |
+
# Set metadata
|
| 20 |
+
LABEL maintainer="OpenEnv Team"
|
| 21 |
+
LABEL description="Base image for OpenEnv based environment servers"
|
| 22 |
+
LABEL version="0.1.0"
|
| 23 |
+
|
| 24 |
+
# Install system dependencies
|
| 25 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 26 |
+
curl \
|
| 27 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 28 |
+
|
| 29 |
+
# Install Python dependencies that all environments need
|
| 30 |
+
RUN pip install --no-cache-dir \
|
| 31 |
+
fastapi>=0.104.0 \
|
| 32 |
+
"uvicorn[standard]>=0.24.0" \
|
| 33 |
+
requests>=2.25.0 \
|
| 34 |
+
wsproto>=1.0.0
|
| 35 |
+
|
| 36 |
+
# Set working directory
|
| 37 |
+
WORKDIR /app
|
| 38 |
+
|
| 39 |
+
# Default environment variables
|
| 40 |
+
ENV PYTHONPATH=/app/src
|
| 41 |
+
ENV PYTHONUNBUFFERED=1
|
| 42 |
+
|
| 43 |
+
# Default expose port (can be overridden)
|
| 44 |
+
EXPOSE 8000
|
| 45 |
+
|
| 46 |
+
# Note: CMD should be specified in child Dockerfiles
|
src/core/containers/images/README.md
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# OpenEnv Base Image
|
| 2 |
+
|
| 3 |
+
Standard base image for all OpenEnv environment servers.
|
| 4 |
+
|
| 5 |
+
## What's Included
|
| 6 |
+
|
| 7 |
+
| Layer | Size | Contents |
|
| 8 |
+
|-------|------|----------|
|
| 9 |
+
| python:3.11-slim | 200 MB | Base Python runtime |
|
| 10 |
+
| + Dependencies | 100 MB | FastAPI, uvicorn, requests |
|
| 11 |
+
| **Total** | **~300 MB** | Ready for environment servers |
|
| 12 |
+
|
| 13 |
+
## Image Sizes
|
| 14 |
+
|
| 15 |
+
```
|
| 16 |
+
openenv-base:latest 300 MB (python + fastapi + uvicorn)
|
| 17 |
+
```

### Without Base Images (❌ Problem)

```
|
| 18 |
+
echo-env:latest 500 MB (python + fastapi + uvicorn + app)
|
| 19 |
+
coding-env:latest 520 MB (python + fastapi + uvicorn + app + tools)
|
| 20 |
+
another-env:latest 510 MB (python + fastapi + uvicorn + app)
|
| 21 |
+
---
|
| 22 |
+
Total: 1.5 GB (with lots of duplication)
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
### With Base Images (✅ Solution)
|
| 26 |
+
```
|
| 27 |
+
openenv-base:latest 300 MB (python + fastapi + uvicorn)
|
| 28 |
+
echo-env:latest 50 MB (app only, uses base)
|
| 29 |
+
coding-env:latest 70 MB (app + tools, uses base)
|
| 30 |
+
another-env:latest 45 MB (app only, uses base)
|
| 31 |
+
---
|
| 32 |
+
Total: 465 MB (base shared, minimal duplication)
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
## Building the Base Image
|
| 36 |
+
|
| 37 |
+
```bash
|
| 38 |
+
# From project root
|
| 39 |
+
docker build -t openenv-base:latest -f src/core/containers/images/Dockerfile .
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
## Usage in Environment Dockerfiles
|
| 43 |
+
|
| 44 |
+
Each environment Dockerfile should start with:
|
| 45 |
+
|
| 46 |
+
```dockerfile
|
| 47 |
+
FROM openenv-base:latest
|
| 48 |
+
|
| 49 |
+
# Copy only environment-specific files
|
| 50 |
+
COPY src/core/ /app/src/core/
|
| 51 |
+
COPY src/envs/my_env/ /app/src/envs/my_env/
|
| 52 |
+
|
| 53 |
+
# Run the server
|
| 54 |
+
CMD ["uvicorn", "envs.my_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
## Base Image Contents
|
| 58 |
+
|
| 59 |
+
- Python 3.11-slim
|
| 60 |
+
- FastAPI >= 0.104.0
|
| 61 |
+
- Uvicorn >= 0.24.0
|
| 62 |
+
- Requests >= 2.25.0
|
| 63 |
+
- curl (for health checks)
|
| 64 |
+
|
| 65 |
+
## Example: Building Echo Environment
|
| 66 |
+
|
| 67 |
+
```bash
|
| 68 |
+
# Step 1: Build base image (do this once)
|
| 69 |
+
docker build -t openenv-base:latest -f src/core/containers/images/Dockerfile .
|
| 70 |
+
|
| 71 |
+
# Step 2: Build echo environment (uses base)
|
| 72 |
+
docker build -t echo-env:latest -f src/envs/echo_env/server/Dockerfile .
|
| 73 |
+
|
| 74 |
+
# Step 3: Run echo environment
|
| 75 |
+
docker run -p 8000:8000 echo-env:latest
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
## Updating the Base
|
| 79 |
+
|
| 80 |
+
When dependencies need updating:
|
| 81 |
+
|
| 82 |
+
1. Update `src/core/containers/images/Dockerfile`
|
| 83 |
+
2. Rebuild base image
|
| 84 |
+
3. Rebuild all environment images (they'll use new base)
|
| 85 |
+
|
| 86 |
+
```bash
|
| 87 |
+
# Update base
|
| 88 |
+
docker build -t openenv-base:latest -f src/core/containers/images/Dockerfile .
|
| 89 |
+
|
| 90 |
+
# Rebuild environments (they automatically use new base)
|
| 91 |
+
docker build -t echo-env:latest -f src/envs/echo_env/server/Dockerfile .
|
| 92 |
+
```
|
src/core/containers/runtime/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Container runtime providers."""
|
| 8 |
+
|
| 9 |
+
from .providers import ContainerProvider, KubernetesProvider, LocalDockerProvider
|
| 10 |
+
|
| 11 |
+
__all__ = [
|
| 12 |
+
"ContainerProvider",
|
| 13 |
+
"LocalDockerProvider",
|
| 14 |
+
"KubernetesProvider",
|
| 15 |
+
]
|
src/core/containers/runtime/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (389 Bytes). View file
|
|
|
src/core/containers/runtime/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (375 Bytes). View file
|
|
|
src/core/containers/runtime/__pycache__/providers.cpython-311.pyc
ADDED
|
Binary file (10.9 kB). View file
|
|
|
src/core/containers/runtime/__pycache__/providers.cpython-313.pyc
ADDED
|
Binary file (9.64 kB). View file
|
|
|
src/core/containers/runtime/providers.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Container provider abstractions for running environment servers.
|
| 9 |
+
|
| 10 |
+
This module provides a pluggable architecture for different container providers
|
| 11 |
+
(local Docker, Kubernetes, cloud providers, etc.) to be used with HTTPEnvClient.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
from abc import ABC, abstractmethod
|
| 17 |
+
from typing import Any, Dict, Optional
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class ContainerProvider(ABC):
|
| 21 |
+
"""
|
| 22 |
+
Abstract base class for container providers.
|
| 23 |
+
|
| 24 |
+
Providers implement this interface to support different container platforms:
|
| 25 |
+
- LocalDockerProvider: Runs containers on local Docker daemon
|
| 26 |
+
- KubernetesProvider: Runs containers in Kubernetes cluster
|
| 27 |
+
- FargateProvider: Runs containers on AWS Fargate
|
| 28 |
+
- CloudRunProvider: Runs containers on Google Cloud Run
|
| 29 |
+
|
| 30 |
+
The provider manages a single container lifecycle and provides the base URL
|
| 31 |
+
for connecting to it.
|
| 32 |
+
|
| 33 |
+
Example:
|
| 34 |
+
>>> provider = LocalDockerProvider()
|
| 35 |
+
>>> base_url = provider.start_container("echo-env:latest")
|
| 36 |
+
>>> print(base_url) # http://localhost:8000
|
| 37 |
+
>>> # Use the environment via base_url
|
| 38 |
+
>>> provider.stop_container()
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
@abstractmethod
|
| 42 |
+
def start_container(
|
| 43 |
+
self,
|
| 44 |
+
image: str,
|
| 45 |
+
port: Optional[int] = None,
|
| 46 |
+
env_vars: Optional[Dict[str, str]] = None,
|
| 47 |
+
**kwargs: Any,
|
| 48 |
+
) -> str:
|
| 49 |
+
"""
|
| 50 |
+
Start a container from the specified image.
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
image: Container image name (e.g., "echo-env:latest")
|
| 54 |
+
port: Port to expose (if None, provider chooses)
|
| 55 |
+
env_vars: Environment variables to pass to container
|
| 56 |
+
**kwargs: Provider-specific options
|
| 57 |
+
|
| 58 |
+
Returns:
|
| 59 |
+
Base URL to connect to the container (e.g., "http://localhost:8000")
|
| 60 |
+
|
| 61 |
+
Raises:
|
| 62 |
+
RuntimeError: If container fails to start
|
| 63 |
+
"""
|
| 64 |
+
pass
|
| 65 |
+
|
| 66 |
+
@abstractmethod
|
| 67 |
+
def stop_container(self) -> None:
|
| 68 |
+
"""
|
| 69 |
+
Stop and remove the running container.
|
| 70 |
+
|
| 71 |
+
This cleans up the container that was started by start_container().
|
| 72 |
+
"""
|
| 73 |
+
pass
|
| 74 |
+
|
| 75 |
+
@abstractmethod
|
| 76 |
+
def wait_for_ready(self, base_url: str, timeout_s: float = 30.0) -> None:
|
| 77 |
+
"""
|
| 78 |
+
Wait for the container to be ready to accept requests.
|
| 79 |
+
|
| 80 |
+
This typically polls the /health endpoint until it returns 200.
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
base_url: Base URL of the container
|
| 84 |
+
timeout_s: Maximum time to wait
|
| 85 |
+
|
| 86 |
+
Raises:
|
| 87 |
+
TimeoutError: If container doesn't become ready in time
|
| 88 |
+
"""
|
| 89 |
+
pass
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class LocalDockerProvider(ContainerProvider):
|
| 93 |
+
"""
|
| 94 |
+
Container provider for local Docker daemon.
|
| 95 |
+
|
| 96 |
+
This provider runs containers on the local machine using Docker.
|
| 97 |
+
Useful for development and testing.
|
| 98 |
+
|
| 99 |
+
Example:
|
| 100 |
+
>>> provider = LocalDockerProvider()
|
| 101 |
+
>>> base_url = provider.start_container("echo-env:latest")
|
| 102 |
+
>>> # Container running on http://localhost:<random-port>
|
| 103 |
+
>>> provider.stop_container()
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
def __init__(self):
|
| 107 |
+
"""Initialize the local Docker provider."""
|
| 108 |
+
self._container_id: Optional[str] = None
|
| 109 |
+
self._container_name: Optional[str] = None
|
| 110 |
+
|
| 111 |
+
# Check if Docker is available
|
| 112 |
+
import subprocess
|
| 113 |
+
|
| 114 |
+
try:
|
| 115 |
+
subprocess.run(
|
| 116 |
+
["docker", "version"],
|
| 117 |
+
check=True,
|
| 118 |
+
capture_output=True,
|
| 119 |
+
timeout=5,
|
| 120 |
+
)
|
| 121 |
+
except (subprocess.CalledProcessError, FileNotFoundError, subprocess.TimeoutExpired):
|
| 122 |
+
raise RuntimeError(
|
| 123 |
+
"Docker is not available. Please install Docker Desktop or Docker Engine."
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
def start_container(
|
| 127 |
+
self,
|
| 128 |
+
image: str,
|
| 129 |
+
port: Optional[int] = None,
|
| 130 |
+
env_vars: Optional[Dict[str, str]] = None,
|
| 131 |
+
**kwargs: Any,
|
| 132 |
+
) -> str:
|
| 133 |
+
"""
|
| 134 |
+
Start a Docker container locally.
|
| 135 |
+
|
| 136 |
+
Args:
|
| 137 |
+
image: Docker image name
|
| 138 |
+
port: Port to expose (if None, finds available port)
|
| 139 |
+
env_vars: Environment variables for the container
|
| 140 |
+
**kwargs: Additional Docker run options
|
| 141 |
+
|
| 142 |
+
Returns:
|
| 143 |
+
Base URL to connect to the container
|
| 144 |
+
"""
|
| 145 |
+
import subprocess
|
| 146 |
+
import time
|
| 147 |
+
|
| 148 |
+
# Find available port if not specified
|
| 149 |
+
if port is None:
|
| 150 |
+
port = self._find_available_port()
|
| 151 |
+
|
| 152 |
+
# Generate container name
|
| 153 |
+
self._container_name = self._generate_container_name(image)
|
| 154 |
+
|
| 155 |
+
# Build docker run command
|
| 156 |
+
cmd = [
|
| 157 |
+
"docker", "run",
|
| 158 |
+
"-d", # Detached
|
| 159 |
+
"--name", self._container_name,
|
| 160 |
+
"-p", f"{port}:8000", # Map port
|
| 161 |
+
]
|
| 162 |
+
|
| 163 |
+
# Add environment variables
|
| 164 |
+
if env_vars:
|
| 165 |
+
for key, value in env_vars.items():
|
| 166 |
+
cmd.extend(["-e", f"{key}={value}"])
|
| 167 |
+
|
| 168 |
+
# Add image
|
| 169 |
+
cmd.append(image)
|
| 170 |
+
|
| 171 |
+
# Run container
|
| 172 |
+
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
|
| 173 |
+
self._container_id = result.stdout.strip()
|
| 174 |
+
|
| 175 |
+
# Wait a moment for container to start
|
| 176 |
+
time.sleep(1)
|
| 177 |
+
|
| 178 |
+
base_url = f"http://localhost:{port}"
|
| 179 |
+
return base_url
|
| 180 |
+
|
| 181 |
+
def stop_container(self) -> None:
|
| 182 |
+
"""
|
| 183 |
+
Stop and remove the Docker container.
|
| 184 |
+
"""
|
| 185 |
+
if self._container_id is None:
|
| 186 |
+
return
|
| 187 |
+
|
| 188 |
+
import subprocess
|
| 189 |
+
|
| 190 |
+
try:
|
| 191 |
+
# Stop container
|
| 192 |
+
subprocess.run(
|
| 193 |
+
["docker", "stop", self._container_id],
|
| 194 |
+
capture_output=True,
|
| 195 |
+
check=True,
|
| 196 |
+
timeout=10,
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
# Remove container
|
| 200 |
+
subprocess.run(
|
| 201 |
+
["docker", "rm", self._container_id],
|
| 202 |
+
capture_output=True,
|
| 203 |
+
check=True,
|
| 204 |
+
timeout=10,
|
| 205 |
+
)
|
| 206 |
+
except subprocess.CalledProcessError:
|
| 207 |
+
# Container might already be stopped/removed
|
| 208 |
+
pass
|
| 209 |
+
finally:
|
| 210 |
+
self._container_id = None
|
| 211 |
+
self._container_name = None
|
| 212 |
+
|
| 213 |
+
def wait_for_ready(self, base_url: str, timeout_s: float = 30.0) -> None:
|
| 214 |
+
"""
|
| 215 |
+
Wait for container to be ready by polling /health endpoint.
|
| 216 |
+
|
| 217 |
+
Args:
|
| 218 |
+
base_url: Base URL of the container
|
| 219 |
+
timeout_s: Maximum time to wait
|
| 220 |
+
|
| 221 |
+
Raises:
|
| 222 |
+
TimeoutError: If container doesn't become ready
|
| 223 |
+
"""
|
| 224 |
+
import time
|
| 225 |
+
import requests
|
| 226 |
+
|
| 227 |
+
start_time = time.time()
|
| 228 |
+
health_url = f"{base_url}/health"
|
| 229 |
+
|
| 230 |
+
while time.time() - start_time < timeout_s:
|
| 231 |
+
try:
|
| 232 |
+
response = requests.get(health_url, timeout=2.0)
|
| 233 |
+
if response.status_code == 200:
|
| 234 |
+
return
|
| 235 |
+
except requests.RequestException:
|
| 236 |
+
pass
|
| 237 |
+
|
| 238 |
+
time.sleep(0.5)
|
| 239 |
+
|
| 240 |
+
raise TimeoutError(
|
| 241 |
+
f"Container at {base_url} did not become ready within {timeout_s}s"
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
def _find_available_port(self) -> int:
|
| 245 |
+
"""
|
| 246 |
+
Find an available port on localhost.
|
| 247 |
+
|
| 248 |
+
Returns:
|
| 249 |
+
An available port number
|
| 250 |
+
"""
|
| 251 |
+
import socket
|
| 252 |
+
|
| 253 |
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
| 254 |
+
s.bind(("", 0))
|
| 255 |
+
s.listen(1)
|
| 256 |
+
port = s.getsockname()[1]
|
| 257 |
+
return port
|
| 258 |
+
|
| 259 |
+
def _generate_container_name(self, image: str) -> str:
|
| 260 |
+
"""
|
| 261 |
+
Generate a unique container name based on image name and timestamp.
|
| 262 |
+
|
| 263 |
+
Args:
|
| 264 |
+
image: Docker image name
|
| 265 |
+
|
| 266 |
+
Returns:
|
| 267 |
+
A unique container name
|
| 268 |
+
"""
|
| 269 |
+
import time
|
| 270 |
+
|
| 271 |
+
clean_image = image.split("/")[-1].split(":")[0]
|
| 272 |
+
timestamp = int(time.time() * 1000)
|
| 273 |
+
return f"{clean_image}-{timestamp}"
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
class KubernetesProvider(ContainerProvider):
|
| 277 |
+
"""
|
| 278 |
+
Container provider for Kubernetes clusters.
|
| 279 |
+
|
| 280 |
+
This provider creates pods in a Kubernetes cluster and exposes them
|
| 281 |
+
via services or port-forwarding.
|
| 282 |
+
|
| 283 |
+
Example:
|
| 284 |
+
>>> provider = KubernetesProvider(namespace="envtorch-dev")
|
| 285 |
+
>>> base_url = provider.start_container("echo-env:latest")
|
| 286 |
+
>>> # Pod running in k8s, accessible via service or port-forward
|
| 287 |
+
>>> provider.stop_container()
|
| 288 |
+
"""
|
| 289 |
+
pass
|
src/core/containers/test_local_docker_provider.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
End-to-end test for LocalDockerProvider.
|
| 4 |
+
|
| 5 |
+
This script tests the complete flow:
|
| 6 |
+
1. Start a container using LocalDockerProvider
|
| 7 |
+
2. Wait for it to be ready
|
| 8 |
+
3. Make HTTP requests to test the environment
|
| 9 |
+
4. Clean up the container
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import sys
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
# Add src to path
|
| 16 |
+
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
| 17 |
+
|
| 18 |
+
import requests
|
| 19 |
+
|
| 20 |
+
from core.containers.runtime import LocalDockerProvider
|
| 21 |
+
|
| 22 |
+
# TODO: Remove this test or make it a functional test, since this will be covered by the e2e test for echo env
|
| 23 |
+
def test_local_docker_provider():
    """Run the LocalDockerProvider end-to-end flow against ``echo-env:latest``.

    Steps: create the provider, start the container, wait for readiness, then
    exercise the /health, /reset, /step and /state endpoints over HTTP, and
    finally stop the container.

    Returns:
        True if every step succeeded, False if any step raised (the error and
        traceback are printed). The ``__main__`` runner relies on this value.
    """
    # Bound every HTTP call so an unresponsive container fails the test
    # instead of blocking it forever (requests has no default timeout).
    request_timeout_s = 10.0

    print("=" * 60)
    print("LocalDockerProvider End-to-End Test")
    print("=" * 60)
    print()

    provider = None

    try:
        # Step 1: Create provider
        print("Step 1: Creating LocalDockerProvider...")
        provider = LocalDockerProvider()
        print("✓ Provider created\n")

        # Step 2: Start container
        print("Step 2: Starting echo-env container...")
        base_url = provider.start_container("echo-env:latest")
        print(f"✓ Container started at: {base_url}")
        if provider._container_id:
            print(f" Container ID: {provider._container_id[:12]}...")
        if provider._container_name:
            print(f" Container name: {provider._container_name}\n")

        # Step 3: Wait for ready
        print("Step 3: Waiting for container to be ready...")
        provider.wait_for_ready(base_url, timeout_s=30.0)
        print("✓ Container is ready!\n")

        # Step 4: Test health endpoint
        print("Step 4: Testing /health endpoint...")
        response = requests.get(f"{base_url}/health", timeout=request_timeout_s)
        print(f" Status: {response.status_code}")
        print(f" Response: {response.json()}")
        assert response.status_code == 200
        assert response.json()["status"] == "healthy"
        print("✓ Health check passed\n")

        # Step 5: Test reset endpoint
        print("Step 5: Testing /reset endpoint...")
        response = requests.post(
            f"{base_url}/reset",
            json={},
            headers={"Content-Type": "application/json"},
            timeout=request_timeout_s,
        )
        print(f" Status: {response.status_code}")
        data = response.json()
        print(f" Message: {data['observation']['echoed_message']}")
        print(f" Reward: {data['reward']}")
        print(f" Done: {data['done']}")
        assert response.status_code == 200
        assert data["observation"]["echoed_message"] == "Echo environment ready!"
        print("✓ Reset test passed\n")

        # Step 6: Test step endpoint
        print("Step 6: Testing /step endpoint...")
        response = requests.post(
            f"{base_url}/step",
            json={"action": {"message": "Hello from LocalDockerProvider!"}},
            headers={"Content-Type": "application/json"},
            timeout=request_timeout_s,
        )
        print(f" Status: {response.status_code}")
        data = response.json()
        print(f" Echoed: {data['observation']['echoed_message']}")
        print(f" Length: {data['observation']['message_length']}")
        print(f" Reward: {data['reward']}")
        assert response.status_code == 200
        assert data["observation"]["echoed_message"] == "Hello from LocalDockerProvider!"
        assert data["observation"]["message_length"] == 31
        print("✓ Step test passed\n")

        # Step 7: Test state endpoint
        print("Step 7: Testing /state endpoint...")
        response = requests.get(f"{base_url}/state", timeout=request_timeout_s)
        print(f" Status: {response.status_code}")
        data = response.json()
        print(f" Episode ID: {data['episode_id']}")
        print(f" Step count: {data['step_count']}")
        assert response.status_code == 200
        assert data["step_count"] == 1  # One step from above
        print("✓ State test passed\n")

        # Step 8: Multiple steps
        print("Step 8: Testing multiple steps...")
        for i in range(3):
            response = requests.post(
                f"{base_url}/step",
                json={"action": {"message": f"Message {i+1}"}},
                headers={"Content-Type": "application/json"},
                timeout=request_timeout_s,
            )
            assert response.status_code == 200
            print(f" Step {i+1}: ✓")

        # Check state updated
        response = requests.get(f"{base_url}/state", timeout=request_timeout_s)
        data = response.json()
        assert data["step_count"] == 4  # 1 + 3 more steps
        print(f" Final step count: {data['step_count']}")
        print("✓ Multiple steps test passed\n")

        print("=" * 60)
        print("✓ All tests passed!")
        print("=" * 60)
        print()

        return True

    except Exception as e:
        print(f"\n❌ Test failed: {e}")
        import traceback

        traceback.print_exc()
        return False

    finally:
        # Step 9: Cleanup (best-effort: a cleanup error is reported, not raised)
        if provider is not None:
            print("\nStep 9: Cleaning up container...")
            try:
                provider.stop_container()
                print("✓ Container stopped and removed\n")
            except Exception as e:
                print(f"⚠️ Cleanup warning: {e}\n")
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def test_provider_with_custom_port():
    """Start ``echo-env:latest`` on host port 8123 and check its /health endpoint.

    Returns:
        True if the container started on the requested port and answered the
        health check with HTTP 200, False if any step raised.
    """
    print("=" * 60)
    print("LocalDockerProvider with Custom Port Test")
    print("=" * 60)
    print()

    provider = None

    try:
        provider = LocalDockerProvider()

        print("Starting container on custom port 8123...")
        base_url = provider.start_container("echo-env:latest", port=8123)
        print(f"✓ Started at: {base_url}")
        assert ":8123" in base_url

        print("Waiting for ready...")
        provider.wait_for_ready(base_url)
        print("✓ Ready!")

        print("Testing health...")
        # Explicit timeout: requests would otherwise wait indefinitely.
        response = requests.get(f"{base_url}/health", timeout=10.0)
        assert response.status_code == 200
        print("✓ Health check passed")

        print("\n✓ Custom port test passed!\n")
        return True

    except Exception as e:
        print(f"\n❌ Test failed: {e}")
        return False

    finally:
        if provider is not None:
            # Guard cleanup so a stop failure is reported instead of replacing
            # the test's True/False result with an exception.
            try:
                provider.stop_container()
                print("✓ Cleaned up\n")
            except Exception as e:
                print(f"⚠️ Cleanup warning: {e}\n")
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def test_provider_with_env_vars():
    """Start ``echo-env:latest`` with extra environment variables and check /health.

    The container is started with ``DEBUG=true`` and ``LOG_LEVEL=info``; the
    test only verifies that it becomes ready and answers the health check.

    Returns:
        True if the health check returned HTTP 200, False if any step raised.
    """
    print("=" * 60)
    print("LocalDockerProvider with Environment Variables Test")
    print("=" * 60)
    print()

    provider = None

    try:
        provider = LocalDockerProvider()

        print("Starting container with environment variables...")
        base_url = provider.start_container(
            "echo-env:latest",
            env_vars={"DEBUG": "true", "LOG_LEVEL": "info"},
        )
        print(f"✓ Started at: {base_url}")

        print("Waiting for ready...")
        provider.wait_for_ready(base_url)
        print("✓ Ready!")

        print("Testing health...")
        # Explicit timeout: requests would otherwise wait indefinitely.
        response = requests.get(f"{base_url}/health", timeout=10.0)
        assert response.status_code == 200
        print("✓ Health check passed")

        print("\n✓ Environment variables test passed!\n")
        return True

    except Exception as e:
        print(f"\n❌ Test failed: {e}")
        return False

    finally:
        if provider is not None:
            # Guard cleanup so a stop failure is reported instead of replacing
            # the test's True/False result with an exception.
            try:
                provider.stop_container()
                print("✓ Cleaned up\n")
            except Exception as e:
                print(f"⚠️ Cleanup warning: {e}\n")
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
if __name__ == "__main__":
    # Script entry point: run the three scenarios in order, print a summary
    # table, and exit with status 0 only if all of them returned True.
    print()
    print("🐳 LocalDockerProvider Test Suite")
    print()

    results = []

    # Run basic test
    results.append(("Basic End-to-End", test_local_docker_provider()))

    # Run custom port test
    results.append(("Custom Port", test_provider_with_custom_port()))

    # Run environment variables test
    results.append(("Environment Variables", test_provider_with_env_vars()))

    # Summary
    print("=" * 60)
    print("Test Summary")
    print("=" * 60)
    for name, passed in results:
        status = "✓ PASSED" if passed else "✗ FAILED"
        print(f"{name:25} {status}")
    print("=" * 60)

    all_passed = all(result for _, result in results)
    # sys.exit is used rather than the bare exit() helper, which is injected by
    # the site module for interactive use and is not guaranteed to exist.
    if all_passed:
        print("\n🎉 All tests passed!")
        sys.exit(0)
    else:
        print("\n❌ Some tests failed")
        sys.exit(1)
|
src/core/env_server/__init__.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Core environment interfaces and types."""
|
| 8 |
+
|
| 9 |
+
from .base_transforms import CompositeTransform, NullTransform
|
| 10 |
+
from .http_server import HTTPEnvServer, create_app, create_fastapi_app
|
| 11 |
+
from .interfaces import Environment, Message, ModelTokenizer, Transform
|
| 12 |
+
from .types import Action, Observation, State
|
| 13 |
+
from .web_interface import create_web_interface_app, WebInterfaceManager
|
| 14 |
+
|
| 15 |
+
__all__ = [
|
| 16 |
+
# Core interfaces
|
| 17 |
+
"Environment",
|
| 18 |
+
"Transform",
|
| 19 |
+
"Message",
|
| 20 |
+
"ModelTokenizer",
|
| 21 |
+
# Types
|
| 22 |
+
"Action",
|
| 23 |
+
"Observation",
|
| 24 |
+
"State",
|
| 25 |
+
# Base transforms
|
| 26 |
+
"CompositeTransform",
|
| 27 |
+
"NullTransform",
|
| 28 |
+
# HTTP Server
|
| 29 |
+
"HTTPEnvServer",
|
| 30 |
+
"create_app",
|
| 31 |
+
"create_fastapi_app",
|
| 32 |
+
# Web Interface
|
| 33 |
+
"create_web_interface_app",
|
| 34 |
+
"WebInterfaceManager",
|
| 35 |
+
]
|
src/core/env_server/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (898 Bytes). View file
|
|
|
src/core/env_server/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (788 Bytes). View file
|
|
|
src/core/env_server/__pycache__/base_transforms.cpython-311.pyc
ADDED
|
Binary file (1.67 kB). View file
|
|
|
src/core/env_server/__pycache__/base_transforms.cpython-313.pyc
ADDED
|
Binary file (1.57 kB). View file
|
|
|
src/core/env_server/__pycache__/http_server.cpython-311.pyc
ADDED
|
Binary file (9.2 kB). View file
|
|
|
src/core/env_server/__pycache__/http_server.cpython-313.pyc
ADDED
|
Binary file (8.33 kB). View file
|
|
|
src/core/env_server/__pycache__/interfaces.cpython-311.pyc
ADDED
|
Binary file (5.22 kB). View file
|
|
|
src/core/env_server/__pycache__/interfaces.cpython-313.pyc
ADDED
|
Binary file (4.68 kB). View file
|
|
|
src/core/env_server/__pycache__/types.cpython-311.pyc
ADDED
|
Binary file (2.39 kB). View file
|
|
|
src/core/env_server/__pycache__/types.cpython-313.pyc
ADDED
|
Binary file (2.66 kB). View file
|
|
|
src/core/env_server/__pycache__/web_interface.cpython-311.pyc
ADDED
|
Binary file (29.9 kB). View file
|
|
|
src/core/env_server/__pycache__/web_interface.cpython-313.pyc
ADDED
|
Binary file (59.3 kB). View file
|
|
|
src/core/env_server/base_transforms.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Base transform implementations for composing environment-specific transforms."""
|
| 8 |
+
|
| 9 |
+
from .interfaces import Transform
|
| 10 |
+
from .types import Observation
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class CompositeTransform(Transform):
    """Chains several transforms so they can be used as a single transform.

    The wrapped transforms run in list order; each one receives the
    observation returned by the previous one.
    """

    def __init__(self, transforms: list[Transform]):
        # The list is stored as given (no copy is made).
        self.transforms = transforms

    def __call__(self, observation: Observation) -> Observation:
        current = observation
        for apply_step in self.transforms:
            current = apply_step(current)
        return current
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class NullTransform(Transform):
    """Identity transform: returns the observation it was given."""

    def __call__(self, observation: Observation) -> Observation:
        # Nothing is modified; the same object is handed back.
        return observation
|
src/core/env_server/http_server.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
HTTP server wrapper for Environment instances.
|
| 9 |
+
|
| 10 |
+
This module provides utilities to wrap any Environment subclass and expose it
|
| 11 |
+
over HTTP endpoints that HTTPEnvClient can consume.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
from dataclasses import asdict
|
| 18 |
+
from typing import Any, Dict, Type
|
| 19 |
+
|
| 20 |
+
from .interfaces import Environment
|
| 21 |
+
from .types import Action, Observation
|
| 22 |
+
from fastapi import Body, FastAPI
|
| 23 |
+
|
| 24 |
+
class HTTPEnvServer:
    """
    HTTP server wrapper for Environment instances.

    This class wraps an Environment and exposes its reset(), step(), and state
    members as HTTP endpoints compatible with HTTPEnvClient. A /health
    endpoint is registered as well.

    The server performs:
    - Action deserialization: converts a JSON dict to the Action subclass
    - Observation serialization: converts an Observation subclass to a JSON dict

    Example:
        >>> from core.env_server import HTTPEnvServer
        >>> from envs.coding_env.server import CodeExecutionEnvironment
        >>> from envs.coding_env.models import CodeAction, CodeObservation
        >>>
        >>> env = CodeExecutionEnvironment()
        >>> server = HTTPEnvServer(env, CodeAction, CodeObservation)
        >>>
        >>> # Register routes with FastAPI
        >>> from fastapi import FastAPI
        >>> app = FastAPI()
        >>> server.register_routes(app)
    """

    def __init__(
        self,
        env: Environment,
        action_cls: Type[Action],
        observation_cls: Type[Observation],
    ):
        """
        Initialize HTTP server wrapper.

        Args:
            env: The Environment instance to wrap
            action_cls: The Action subclass this environment expects
            observation_cls: The Observation subclass this environment returns
        """
        self.env = env
        self.action_cls = action_cls
        self.observation_cls = observation_cls

    def register_routes(self, app: Any) -> None:
        """
        Register HTTP routes on a FastAPI application.

        Registers POST /reset, POST /step, GET /state and GET /health.

        Args:
            app: FastAPI application instance

        Raises:
            TypeError: If ``app`` is not a FastAPI instance
        """

        if not isinstance(app, FastAPI):
            raise TypeError("app must be a FastAPI instance")

        @app.post("/reset")
        async def reset(request: Dict[str, Any] = Body(default={})) -> Dict[str, Any]:
            """Reset endpoint - returns initial observation."""
            # TODO: Handle seed, episode_id from request if provided
            observation = self.env.reset()
            return self._serialize_observation(observation)

        @app.post("/step")
        async def step(request: Dict[str, Any]) -> Dict[str, Any]:
            """Step endpoint - executes action and returns observation."""
            action_data = request.get("action", {})
            # TODO: Handle timeout_s, request_id, episode_id from request if provided

            # Deserialize action
            action = self._deserialize_action(action_data)

            # Execute step
            observation = self.env.step(action)

            # Return serialized observation
            return self._serialize_observation(observation)

        @app.get("/state")
        async def get_state() -> Dict[str, Any]:
            """State endpoint - returns current environment state."""
            state = self.env.state
            return asdict(state)

        @app.get("/health")
        async def health() -> Dict[str, str]:
            """Health check endpoint."""
            return {"status": "healthy"}

    def _deserialize_action(self, action_data: Dict[str, Any]) -> Action:
        """
        Convert JSON dict to Action instance.

        The input dict is not modified; a shallow copy is used internally.

        Args:
            action_data: Dictionary containing action data. ``None`` is
                treated like an empty dict.

        Returns:
            Action instance built from ``action_cls`` with the remaining keys
            as keyword arguments and ``metadata`` assigned afterwards

        Note:
            This is a simple implementation. Subclasses may need to override
            for more complex deserialization logic.
        """
        # Copy first: popping from the caller's dict would silently strip
        # "metadata" from the request payload the caller still holds.
        fields = dict(action_data or {})
        # Remove metadata if present (it is set on the instance afterwards)
        metadata = fields.pop("metadata", {})
        action = self.action_cls(**fields)
        action.metadata = metadata
        return action

    def _serialize_observation(self, observation: Observation) -> Dict[str, Any]:
        """
        Convert Observation instance to JSON-compatible dict.

        Args:
            observation: Observation instance (must be a dataclass instance,
                since it is converted with ``dataclasses.asdict``)

        Returns:
            Dictionary compatible with HTTPEnvClient._parse_result()

        The format matches what HTTPEnvClient expects:
        {
            "observation": {...},  # Observation fields
            "reward": float | None,
            "done": bool,
        }
        """
        obs_dict = asdict(observation)

        # Extract reward and done (these are part of StepResult on client side)
        reward = obs_dict.pop("reward", None)
        done = obs_dict.pop("done", False)
        obs_dict.pop("metadata", None)  # Remove metadata from observation

        # Return in HTTPEnvClient expected format
        return {
            "observation": obs_dict,
            "reward": reward,
            "done": done,
        }
|
| 161 |
+
|
| 162 |
+
def create_app(
    env: Environment,
    action_cls: Type[Action],
    observation_cls: Type[Observation],
    env_name: str | None = None,
) -> Any:
    """
    Create a FastAPI application with or without web interface.

    The web interface is opt-in: it is used only when the environment variable
    ``ENABLE_WEB_INTERFACE`` is set to "true", "1" or "yes" (case-insensitive).
    When the variable is unset or has any other value, the plain FastAPI app
    from ``create_fastapi_app`` is returned.

    Args:
        env: The Environment instance to serve
        action_cls: The Action subclass this environment expects
        observation_cls: The Observation subclass this environment returns
        env_name: Optional environment name, forwarded to the web interface
            only (it is not used by the plain app)

    Returns:
        FastAPI application instance, with or without the web interface
    """
    # ``str | None`` is used in the signature because this module does not
    # import ``Optional``; annotations are lazy here (``from __future__ import
    # annotations``), so the union syntax is never evaluated at runtime.
    enable_web = os.getenv("ENABLE_WEB_INTERFACE", "false").lower() in (
        "true",
        "1",
        "yes",
    )

    if enable_web:
        # Import web interface only when needed
        from .web_interface import create_web_interface_app

        return create_web_interface_app(env, action_cls, observation_cls, env_name)

    # Use standard FastAPI app without web interface
    return create_fastapi_app(env, action_cls, observation_cls)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def create_fastapi_app(
    env: Environment,
    action_cls: Type[Action],
    observation_cls: Type[Observation],
) -> Any:
    """
    Build a FastAPI application that serves the given environment.

    A new app titled "Environment HTTP Server" is created and an
    HTTPEnvServer wrapping ``env`` registers its routes on it.

    Args:
        env: The Environment instance to serve
        action_cls: The Action subclass this environment expects
        observation_cls: The Observation subclass this environment returns

    Returns:
        FastAPI application instance with routes registered

    Raises:
        ImportError: If FastAPI cannot be imported

    Example:
        >>> from envs.coding_env.server import CodeExecutionEnvironment
        >>> from envs.coding_env.models import CodeAction, CodeObservation
        >>>
        >>> env = CodeExecutionEnvironment()
        >>> app = create_fastapi_app(env, CodeAction, CodeObservation)
        >>>
        >>> # Run with: uvicorn module:app --host 0.0.0.0 --port 8000
    """
    try:
        from fastapi import FastAPI
    except ImportError:
        raise ImportError(
            "FastAPI is required. Install with: pip install fastapi uvicorn"
        )

    application = FastAPI(title="Environment HTTP Server")
    HTTPEnvServer(env, action_cls, observation_cls).register_routes(application)
    return application
|
src/core/env_server/interfaces.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from abc import ABC, abstractmethod
|
| 8 |
+
from typing import Any, Protocol, TypedDict
|
| 9 |
+
|
| 10 |
+
from .types import Action, Observation, State
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Message(TypedDict):
    """One message of a conversation.

    The ``role`` / ``content`` layout is compatible with the Huggingface
    chat template format.
    """

    role: str
    content: str
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class ModelTokenizer(Protocol):
    """Structural interface for tokenizers with chat-template support.

    Any object providing ``apply_chat_template`` and ``decode`` with these
    signatures satisfies the protocol; it is compatible with Huggingface
    transformers tokenizers and is what chat-based environments expect.
    """

    def apply_chat_template(
        self,
        conversation: list[Message],
        tokenize: bool = True,
        return_tensors: str | None = None,
        **kwargs: Any,
    ) -> Any:
        """Format a conversation with a chat template, optionally tokenizing it.

        Args:
            conversation: List of message dictionaries with 'role' and 'content'
            tokenize: Whether to tokenize the output
            return_tensors: Format for returned tensors ('pt' for PyTorch)
            **kwargs: Additional arguments

        Returns:
            Formatted and optionally tokenized conversation
        """
        ...

    def decode(
        self,
        token_ids: Any,
        skip_special_tokens: bool = False,
        **kwargs: Any,
    ) -> str:
        """Turn token IDs back into text.

        Args:
            token_ids: Token IDs to decode
            skip_special_tokens: Whether to skip special tokens in output
            **kwargs: Additional arguments

        Returns:
            Decoded text string
        """
        ...
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class Transform(ABC):
    """Abstract observation-to-observation transform.

    Following the TorchRL pattern, a transform takes an observation and
    returns a (potentially modified) observation, which allows rewards,
    metrics, or other modifications to be added flexibly.
    """

    @abstractmethod
    def __call__(self, observation: Observation) -> Observation:
        """Apply the transform to one observation.

        Args:
            observation: The input observation

        Returns:
            The transformed observation
        """
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class Environment(ABC):
    """Abstract base for environment servers, following the Gym/Gymnasium API.

    Subclasses must implement ``reset``, ``step`` and the ``state`` property.

    Args:
        transform: Optional transform to apply to observations
    """

    def __init__(self, transform: Transform | None = None):
        self.transform = transform

    @abstractmethod
    def reset(self) -> Observation:
        """Reset the environment and return the initial observation."""

    @abstractmethod
    def step(self, action: Action) -> Observation:
        """Advance the environment by one step using ``action``."""

    @property
    @abstractmethod
    def state(self) -> State:
        """The current environment state."""

    def _apply_transform(self, observation: Observation) -> Observation:
        """Run the configured transform, or return the observation as-is if none is set."""
        if self.transform is None:
            return observation
        return self.transform(observation)
|
src/core/env_server/types.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from dataclasses import dataclass, field
|
| 8 |
+
from typing import Any, Dict, List, Optional, Union
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# Type aliases
|
| 12 |
+
Scalar = Union[int, float, bool]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass(kw_only=True)
|
| 16 |
+
class Action:
|
| 17 |
+
"""Base class for all environment actions."""
|
| 18 |
+
|
| 19 |
+
metadata: Dict[str, Any] = field(default_factory=dict)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass(kw_only=True)
|
| 23 |
+
class Observation:
|
| 24 |
+
"""Base class for all environment observations."""
|
| 25 |
+
|
| 26 |
+
done: bool = False
|
| 27 |
+
reward: Union[bool, int, float, None] = None
|
| 28 |
+
metadata: Dict[str, Any] = field(default_factory=dict)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@dataclass
class State:
    """Common base for the state an environment reports about itself."""

    # Identifier of the current episode; None until one is assigned.
    episode_id: Optional[str] = None
    # Step counter; starts at 0.
    step_count: int = 0
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@dataclass
class CodeExecResult:
    """Outcome of a code execution: captured stdout, stderr and exit code.

    All three fields are required.
    """

    # Text captured from standard output.
    stdout: str
    # Text captured from standard error.
    stderr: str
    # Exit code of the execution.
    exit_code: int
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@dataclass
class EnvironmentMetadata:
    """Descriptive information about an environment, used for documentation and UI.

    ``name`` and ``description`` are required; every other field is
    optional and defaults to ``None``.
    """

    # Name of the environment.
    name: str
    # Description of the environment.
    description: str
    # README text, when one is available.
    readme_content: Optional[str] = None
    # Version string.
    version: Optional[str] = None
    # Author of the environment.
    author: Optional[str] = None
    # Link to further documentation.
    documentation_url: Optional[str] = None
|
src/core/env_server/web_interface.py
ADDED
|
@@ -0,0 +1,1613 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Web interface for OpenEnv environments.
|
| 9 |
+
|
| 10 |
+
This module provides a web-based interface for interacting with OpenEnv environments,
|
| 11 |
+
including a two-pane layout for HumanAgent interaction and state observation.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import json
|
| 17 |
+
import time
|
| 18 |
+
from dataclasses import asdict, dataclass
|
| 19 |
+
from typing import Any, Dict, List, Optional, Type
|
| 20 |
+
from datetime import datetime
|
| 21 |
+
|
| 22 |
+
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request
|
| 23 |
+
from fastapi.responses import HTMLResponse, FileResponse
|
| 24 |
+
from fastapi.staticfiles import StaticFiles
|
| 25 |
+
from pydantic import BaseModel
|
| 26 |
+
|
| 27 |
+
from .interfaces import Environment
|
| 28 |
+
from .types import Action, Observation, State, EnvironmentMetadata
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def load_environment_metadata(env: Environment, env_name: Optional[str] = None) -> EnvironmentMetadata:
    """
    Build the metadata record describing an environment.

    Args:
        env: The environment instance
        env_name: Optional environment name, used as the display name and
            for the README file lookup

    Returns:
        The result of ``env.get_metadata()`` when the environment has such
        an attribute; otherwise a default EnvironmentMetadata (version
        "1.0.0") with README content attached when one is found.
    """
    # An environment that describes itself takes precedence over defaults.
    if hasattr(env, 'get_metadata'):
        return env.get_metadata()

    class_name = env.__class__.__name__
    metadata = EnvironmentMetadata(
        name=env_name if env_name else class_name,
        description=f"{class_name} environment",
        version="1.0.0",
    )

    # Attach the README only when the lookup produced non-empty text.
    readme = _load_readme_from_filesystem(env_name)
    if readme:
        metadata.readme_content = readme

    return metadata
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _load_readme_from_filesystem(env_name: Optional[str]) -> Optional[str]:
|
| 62 |
+
"""
|
| 63 |
+
Load README content from the filesystem.
|
| 64 |
+
|
| 65 |
+
Tries multiple locations:
|
| 66 |
+
1. Container filesystem: /app/README.md
|
| 67 |
+
2. Local development: src/envs/{env_name}/README.md
|
| 68 |
+
3. Environment variable: ENV_README_PATH
|
| 69 |
+
"""
|
| 70 |
+
import os
|
| 71 |
+
from pathlib import Path
|
| 72 |
+
|
| 73 |
+
# Try container filesystem first
|
| 74 |
+
container_readme = Path("/app/README.md")
|
| 75 |
+
if container_readme.exists():
|
| 76 |
+
try:
|
| 77 |
+
return container_readme.read_text(encoding='utf-8')
|
| 78 |
+
except Exception:
|
| 79 |
+
pass
|
| 80 |
+
|
| 81 |
+
# Try environment variable path
|
| 82 |
+
custom_path = os.environ.get("ENV_README_PATH")
|
| 83 |
+
if custom_path and Path(custom_path).exists():
|
| 84 |
+
try:
|
| 85 |
+
return Path(custom_path).read_text(encoding='utf-8')
|
| 86 |
+
except Exception:
|
| 87 |
+
pass
|
| 88 |
+
|
| 89 |
+
# Try local development path
|
| 90 |
+
if env_name:
|
| 91 |
+
local_readme = Path(f"src/envs/{env_name}/README.md")
|
| 92 |
+
if local_readme.exists():
|
| 93 |
+
try:
|
| 94 |
+
return local_readme.read_text(encoding='utf-8')
|
| 95 |
+
except Exception:
|
| 96 |
+
pass
|
| 97 |
+
|
| 98 |
+
return None
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
@dataclass
class ActionLog:
    """One logged step: the action taken and what came back.

    All fields are required.
    """

    # When the entry was recorded, as a string.
    timestamp: str
    # The action, as a plain dict.
    action: Dict[str, Any]
    # The resulting observation, as a plain dict.
    observation: Dict[str, Any]
    # Reward for the step, or None.
    reward: Optional[float]
    # Done flag for the step.
    done: bool
    # Step counter value for this entry.
    step_count: int
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
@dataclass
class EpisodeState:
    """Snapshot of the running episode as shown by the web interface."""

    # Identifier of the episode, or None.
    episode_id: Optional[str]
    # Step counter for the episode.
    step_count: int
    # Most recent observation as a plain dict, or None.
    current_observation: Optional[Dict[str, Any]]
    # Log entries for the episode.
    action_logs: List[ActionLog]
    # Reset flag; defaults to True.
    is_reset: bool = True
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class WebInterfaceManager:
    """Manages the web interface for an environment.

    Holds the wrapped environment, the episode snapshot shown in the UI and
    the list of connected WebSocket clients, and pushes a ``state_update``
    message to every client after each reset or step.
    """

    def __init__(
        self,
        env: Environment,
        action_cls: Type[Action],
        observation_cls: Type[Observation],
        metadata: Optional[EnvironmentMetadata] = None,
    ):
        """Store the environment and start with an empty episode snapshot.

        Args:
            env: The environment instance to drive.
            action_cls: Class used to build actions from incoming JSON.
            observation_cls: Observation class of the environment.
            metadata: Environment metadata; when omitted a default record
                is built from the environment's class name.
        """
        self.env = env
        self.action_cls = action_cls
        self.observation_cls = observation_cls
        self.metadata = metadata or EnvironmentMetadata(
            name=env.__class__.__name__,
            description=f"{env.__class__.__name__} environment"
        )
        self.episode_state = EpisodeState(
            episode_id=None,
            step_count=0,
            current_observation=None,
            action_logs=[]
        )
        self.connected_clients: List[WebSocket] = []

    async def connect_websocket(self, websocket: WebSocket):
        """Accept a new WebSocket client, register it and push the current state."""
        await websocket.accept()
        self.connected_clients.append(websocket)

        # Send current state to the new client
        await self._send_state_update()

    async def disconnect_websocket(self, websocket: WebSocket):
        """Unregister a WebSocket client; a no-op when it is not registered."""
        if websocket in self.connected_clients:
            self.connected_clients.remove(websocket)

    async def _send_state_update(self):
        """Send the current episode state to all connected clients.

        Clients whose send fails are removed from ``connected_clients``.
        """
        if not self.connected_clients:
            return

        # Serialise once, outside the per-client try block, so that a
        # serialisation problem is never mistaken for a dead client.
        # default=str renders values json cannot encode natively as their
        # string form instead of raising.
        payload = json.dumps(
            {
                "type": "state_update",
                "episode_state": asdict(self.episode_state),
            },
            default=str,
        )

        # Iterate over a snapshot: the list may change while we await.
        failed_clients = []
        for client in list(self.connected_clients):
            try:
                await client.send_text(payload)
            except Exception:
                # Exception only, so task cancellation and interpreter
                # exit signals still propagate.
                failed_clients.append(client)

        # Remove failed clients; they may already have been unregistered
        # by disconnect_websocket while we were awaiting.
        for client in failed_clients:
            if client in self.connected_clients:
                self.connected_clients.remove(client)

    async def reset_environment(self) -> Dict[str, Any]:
        """Reset the environment, clear the episode snapshot and notify clients.

        Returns:
            Dict with the observation (as a plain dict), its reward and its
            done flag.
        """
        observation = self.env.reset()
        state = self.env.state

        # Update episode state
        self.episode_state.episode_id = state.episode_id
        self.episode_state.step_count = 0
        self.episode_state.current_observation = asdict(observation)
        self.episode_state.action_logs = []
        self.episode_state.is_reset = True

        # Send state update
        await self._send_state_update()

        return {
            "observation": asdict(observation),
            "reward": observation.reward,
            "done": observation.done,
        }

    async def step_environment(self, action_data: Dict[str, Any]) -> Dict[str, Any]:
        """Execute one step, record it in the episode log and notify clients.

        Args:
            action_data: JSON-style dict describing the action.

        Returns:
            Dict with the observation (as a plain dict), its reward and its
            done flag.
        """
        # Deserialize action
        action = self._deserialize_action(action_data)

        # Execute step
        observation = self.env.step(action)
        state = self.env.state

        # Create action log
        action_log = ActionLog(
            timestamp=datetime.now().isoformat(),
            action=asdict(action),
            observation=asdict(observation),
            reward=observation.reward,
            done=observation.done,
            step_count=state.step_count
        )

        # Update episode state
        self.episode_state.episode_id = state.episode_id
        self.episode_state.step_count = state.step_count
        self.episode_state.current_observation = asdict(observation)
        self.episode_state.action_logs.append(action_log)
        self.episode_state.is_reset = False

        # Send state update
        await self._send_state_update()

        return {
            "observation": asdict(observation),
            "reward": observation.reward,
            "done": observation.done,
        }

    def get_state(self) -> Dict[str, Any]:
        """Return the environment's current state as a plain dict."""
        return asdict(self.env.state)

    def _deserialize_action(self, action_data: Dict[str, Any]) -> Action:
        """Convert a JSON dict to an Action instance.

        The input dict is not modified.

        Args:
            action_data: Field values for ``self.action_cls``. An optional
                ``"metadata"`` entry is assigned to the action after
                construction instead of being passed to the constructor.

        Returns:
            The constructed action with ``metadata`` set (an empty dict
            when none was supplied).
        """
        # Shallow copy so that popping "metadata" does not alter the
        # caller's dict.
        remaining = dict(action_data)
        metadata = remaining.pop("metadata", {})

        # Handle tensor fields that come from JSON as lists
        processed_data = {}
        for key, value in remaining.items():
            if key == "tokens" and isinstance(value, (list, str)):
                if isinstance(value, str):
                    # A string is expected to hold a JSON list of numbers.
                    try:
                        value = json.loads(value)
                    except ValueError:
                        # If parsing fails, treat as empty list
                        value = []
                if isinstance(value, list):
                    # Imported lazily: only token-based actions need torch.
                    import torch
                    processed_data[key] = torch.tensor(value, dtype=torch.long)
                else:
                    processed_data[key] = value
            elif key == "action_id" and isinstance(value, str):
                # Convert action_id from string to int
                try:
                    processed_data[key] = int(value)
                except ValueError:
                    # If conversion fails, keep original value
                    processed_data[key] = value
            else:
                processed_data[key] = value

        action = self.action_cls(**processed_data)
        action.metadata = metadata
        return action
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def create_web_interface_app(
    env: Environment,
    action_cls: Type[Action],
    observation_cls: Type[Observation],
    env_name: Optional[str] = None,
) -> FastAPI:
    """
    Create a FastAPI application with web interface for the given environment.

    The app returned by ``create_fastapi_app`` is extended with these routes:
    ``GET /web`` (HTML page), ``GET /web/metadata``, ``POST /web/reset``,
    ``POST /web/step``, ``GET /web/state`` and the ``/ws`` WebSocket.

    Args:
        env: The Environment instance to serve
        action_cls: The Action subclass this environment expects
        observation_cls: The Observation subclass this environment returns
        env_name: Optional environment name for README loading

    Returns:
        FastAPI application instance with web interface
    """
    from .http_server import create_fastapi_app

    # Create the base environment app
    app = create_fastapi_app(env, action_cls, observation_cls)

    # Load environment metadata
    metadata = load_environment_metadata(env, env_name)

    # Create web interface manager
    web_manager = WebInterfaceManager(env, action_cls, observation_cls, metadata)

    # Add web interface routes
    @app.get("/web", response_class=HTMLResponse)
    async def web_interface():
        """Serve the web interface."""
        return get_web_interface_html(action_cls, web_manager.metadata)

    @app.get("/web/metadata")
    async def web_metadata():
        """Get environment metadata."""
        return asdict(web_manager.metadata)

    @app.websocket("/ws")
    async def websocket_endpoint(websocket: WebSocket):
        """WebSocket endpoint for real-time updates."""
        try:
            await web_manager.connect_websocket(websocket)
            while True:
                # Keep connection alive
                await websocket.receive_text()
        except WebSocketDisconnect:
            # Normal client departure; cleanup happens below.
            pass
        finally:
            # Always unregister, whatever ended the loop, so a dead socket
            # never lingers in the client list. Unregistering a client that
            # was never registered is a no-op.
            await web_manager.disconnect_websocket(websocket)

    @app.post("/web/reset")
    async def web_reset():
        """Reset endpoint for web interface."""
        return await web_manager.reset_environment()

    @app.post("/web/step")
    async def web_step(request: Dict[str, Any]):
        """Step endpoint for web interface.

        Accepts either ``{"message": ...}`` (converted through the
        environment's ``message_to_action``) or ``{"action": {...}}``.
        """
        # Check if this is a message-based request (chat environment)
        if "message" in request:
            message = request["message"]
            # Convert message to action using the environment's message_to_action method
            action = web_manager.env.message_to_action(message)
            action_data = {"tokens": action.tokens.tolist()}
        else:
            action_data = request.get("action", {})

        return await web_manager.step_environment(action_data)

    @app.get("/web/state")
    async def web_state():
        """State endpoint for web interface."""
        return web_manager.get_state()

    return app
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def get_web_interface_html(action_cls: Type[Action], metadata: Optional[EnvironmentMetadata] = None) -> str:
|
| 359 |
+
"""Generate the HTML for the web interface."""
|
| 360 |
+
|
| 361 |
+
# Check if this is a chat environment by looking for tokens field
|
| 362 |
+
is_chat_env = False
|
| 363 |
+
if hasattr(action_cls, '__dataclass_fields__'):
|
| 364 |
+
for field_name, field_info in action_cls.__dataclass_fields__.items():
|
| 365 |
+
if field_name == 'tokens' and hasattr(field_info.type, '__name__') and 'Tensor' in field_info.type.__name__:
|
| 366 |
+
is_chat_env = True
|
| 367 |
+
break
|
| 368 |
+
|
| 369 |
+
# Get action fields for dynamic form generation with enhanced metadata
|
| 370 |
+
action_fields = _extract_action_fields(action_cls)
|
| 371 |
+
|
| 372 |
+
return f"""
|
| 373 |
+
<!DOCTYPE html>
|
| 374 |
+
<html lang="en">
|
| 375 |
+
<head>
|
| 376 |
+
<meta charset="UTF-8">
|
| 377 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 378 |
+
<title>OpenEnv Web Interface</title>
|
| 379 |
+
<style>
|
| 380 |
+
* {{
|
| 381 |
+
margin: 0;
|
| 382 |
+
padding: 0;
|
| 383 |
+
box-sizing: border-box;
|
| 384 |
+
}}
|
| 385 |
+
|
| 386 |
+
body {{
|
| 387 |
+
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
| 388 |
+
background-color: #f5f5f5;
|
| 389 |
+
height: 100vh;
|
| 390 |
+
overflow: hidden;
|
| 391 |
+
}}
|
| 392 |
+
|
| 393 |
+
.container {{
|
| 394 |
+
display: flex;
|
| 395 |
+
height: 100vh;
|
| 396 |
+
}}
|
| 397 |
+
|
| 398 |
+
.left-pane {{
|
| 399 |
+
width: 50%;
|
| 400 |
+
background: white;
|
| 401 |
+
border-right: 1px solid #e0e0e0;
|
| 402 |
+
display: flex;
|
| 403 |
+
flex-direction: column;
|
| 404 |
+
}}
|
| 405 |
+
|
| 406 |
+
.right-pane {{
|
| 407 |
+
width: 50%;
|
| 408 |
+
background: #fafafa;
|
| 409 |
+
display: flex;
|
| 410 |
+
flex-direction: column;
|
| 411 |
+
}}
|
| 412 |
+
|
| 413 |
+
.pane-header {{
|
| 414 |
+
padding: 20px;
|
| 415 |
+
border-bottom: 1px solid #e0e0e0;
|
| 416 |
+
background: #f8f9fa;
|
| 417 |
+
font-weight: 600;
|
| 418 |
+
font-size: 16px;
|
| 419 |
+
}}
|
| 420 |
+
|
| 421 |
+
.pane-content {{
|
| 422 |
+
flex: 1;
|
| 423 |
+
padding: 20px;
|
| 424 |
+
overflow-y: auto;
|
| 425 |
+
}}
|
| 426 |
+
|
| 427 |
+
.action-form {{
|
| 428 |
+
background: white;
|
| 429 |
+
border: 1px solid #e0e0e0;
|
| 430 |
+
border-radius: 8px;
|
| 431 |
+
padding: 20px;
|
| 432 |
+
margin-bottom: 20px;
|
| 433 |
+
}}
|
| 434 |
+
|
| 435 |
+
.form-group {{
|
| 436 |
+
margin-bottom: 15px;
|
| 437 |
+
}}
|
| 438 |
+
|
| 439 |
+
.form-group label {{
|
| 440 |
+
display: block;
|
| 441 |
+
margin-bottom: 5px;
|
| 442 |
+
font-weight: 500;
|
| 443 |
+
color: #333;
|
| 444 |
+
}}
|
| 445 |
+
|
| 446 |
+
.form-group input, .form-group textarea {{
|
| 447 |
+
width: 100%;
|
| 448 |
+
padding: 8px 12px;
|
| 449 |
+
border: 1px solid #ddd;
|
| 450 |
+
border-radius: 4px;
|
| 451 |
+
font-size: 14px;
|
| 452 |
+
}}
|
| 453 |
+
|
| 454 |
+
.form-group input:focus, .form-group textarea:focus {{
|
| 455 |
+
outline: none;
|
| 456 |
+
border-color: #007bff;
|
| 457 |
+
box-shadow: 0 0 0 2px rgba(0, 123, 255, 0.25);
|
| 458 |
+
}}
|
| 459 |
+
|
| 460 |
+
.btn {{
|
| 461 |
+
background: #007bff;
|
| 462 |
+
color: white;
|
| 463 |
+
border: none;
|
| 464 |
+
padding: 10px 20px;
|
| 465 |
+
border-radius: 4px;
|
| 466 |
+
cursor: pointer;
|
| 467 |
+
font-size: 14px;
|
| 468 |
+
margin-right: 10px;
|
| 469 |
+
margin-bottom: 10px;
|
| 470 |
+
}}
|
| 471 |
+
|
| 472 |
+
.btn:hover {{
|
| 473 |
+
background: #0056b3;
|
| 474 |
+
}}
|
| 475 |
+
|
| 476 |
+
.btn:disabled {{
|
| 477 |
+
background: #6c757d;
|
| 478 |
+
cursor: not-allowed;
|
| 479 |
+
}}
|
| 480 |
+
|
| 481 |
+
.btn-secondary {{
|
| 482 |
+
background: #6c757d;
|
| 483 |
+
}}
|
| 484 |
+
|
| 485 |
+
.btn-secondary:hover {{
|
| 486 |
+
background: #545b62;
|
| 487 |
+
}}
|
| 488 |
+
|
| 489 |
+
.state-display {{
|
| 490 |
+
background: white;
|
| 491 |
+
border: 1px solid #e0e0e0;
|
| 492 |
+
border-radius: 8px;
|
| 493 |
+
padding: 15px;
|
| 494 |
+
margin-bottom: 20px;
|
| 495 |
+
}}
|
| 496 |
+
|
| 497 |
+
.state-item {{
|
| 498 |
+
margin-bottom: 8px;
|
| 499 |
+
}}
|
| 500 |
+
|
| 501 |
+
.state-label {{
|
| 502 |
+
font-weight: 500;
|
| 503 |
+
color: #666;
|
| 504 |
+
}}
|
| 505 |
+
|
| 506 |
+
.state-value {{
|
| 507 |
+
color: #333;
|
| 508 |
+
font-family: monospace;
|
| 509 |
+
}}
|
| 510 |
+
|
| 511 |
+
.logs-container {{
|
| 512 |
+
background: white;
|
| 513 |
+
border: 1px solid #e0e0e0;
|
| 514 |
+
border-radius: 8px;
|
| 515 |
+
padding: 15px;
|
| 516 |
+
max-height: 400px;
|
| 517 |
+
overflow-y: auto;
|
| 518 |
+
}}
|
| 519 |
+
|
| 520 |
+
.log-entry {{
|
| 521 |
+
border-bottom: 1px solid #f0f0f0;
|
| 522 |
+
padding: 10px 0;
|
| 523 |
+
}}
|
| 524 |
+
|
| 525 |
+
.log-entry:last-child {{
|
| 526 |
+
border-bottom: none;
|
| 527 |
+
}}
|
| 528 |
+
|
| 529 |
+
.log-timestamp {{
|
| 530 |
+
font-size: 12px;
|
| 531 |
+
color: #666;
|
| 532 |
+
margin-bottom: 5px;
|
| 533 |
+
}}
|
| 534 |
+
|
| 535 |
+
.log-action {{
|
| 536 |
+
background: #e3f2fd;
|
| 537 |
+
padding: 8px;
|
| 538 |
+
border-radius: 4px;
|
| 539 |
+
margin-bottom: 5px;
|
| 540 |
+
font-family: monospace;
|
| 541 |
+
font-size: 12px;
|
| 542 |
+
}}
|
| 543 |
+
|
| 544 |
+
.log-observation {{
|
| 545 |
+
background: #f3e5f5;
|
| 546 |
+
padding: 8px;
|
| 547 |
+
border-radius: 4px;
|
| 548 |
+
font-family: monospace;
|
| 549 |
+
font-size: 12px;
|
| 550 |
+
}}
|
| 551 |
+
|
| 552 |
+
.log-reward {{
|
| 553 |
+
font-weight: 600;
|
| 554 |
+
color: #28a745;
|
| 555 |
+
}}
|
| 556 |
+
|
| 557 |
+
.log-done {{
|
| 558 |
+
font-weight: 600;
|
| 559 |
+
color: #dc3545;
|
| 560 |
+
}}
|
| 561 |
+
|
| 562 |
+
.status-indicator {{
|
| 563 |
+
display: inline-block;
|
| 564 |
+
width: 8px;
|
| 565 |
+
height: 8px;
|
| 566 |
+
border-radius: 50%;
|
| 567 |
+
margin-right: 8px;
|
| 568 |
+
}}
|
| 569 |
+
|
| 570 |
+
.status-connected {{
|
| 571 |
+
background: #28a745;
|
| 572 |
+
}}
|
| 573 |
+
|
| 574 |
+
.status-disconnected {{
|
| 575 |
+
background: #dc3545;
|
| 576 |
+
}}
|
| 577 |
+
|
| 578 |
+
.json-display {{
|
| 579 |
+
background: #f8f9fa;
|
| 580 |
+
border: 1px solid #e9ecef;
|
| 581 |
+
border-radius: 4px;
|
| 582 |
+
padding: 10px;
|
| 583 |
+
font-family: monospace;
|
| 584 |
+
font-size: 12px;
|
| 585 |
+
white-space: pre-wrap;
|
| 586 |
+
max-height: 200px;
|
| 587 |
+
overflow-y: auto;
|
| 588 |
+
}}
|
| 589 |
+
|
| 590 |
+
/* Chat Interface Styles */
|
| 591 |
+
.chat-interface {{
|
| 592 |
+
background: white;
|
| 593 |
+
border: 1px solid #e0e0e0;
|
| 594 |
+
border-radius: 8px;
|
| 595 |
+
padding: 20px;
|
| 596 |
+
margin-bottom: 20px;
|
| 597 |
+
}}
|
| 598 |
+
|
| 599 |
+
.chat-messages {{
|
| 600 |
+
background: #f8f9fa;
|
| 601 |
+
border: 1px solid #e0e0e0;
|
| 602 |
+
border-radius: 8px;
|
| 603 |
+
padding: 15px;
|
| 604 |
+
margin-bottom: 15px;
|
| 605 |
+
max-height: 400px;
|
| 606 |
+
overflow-y: auto;
|
| 607 |
+
}}
|
| 608 |
+
|
| 609 |
+
.chat-message {{
|
| 610 |
+
margin-bottom: 15px;
|
| 611 |
+
padding: 10px;
|
| 612 |
+
border-radius: 8px;
|
| 613 |
+
}}
|
| 614 |
+
|
| 615 |
+
.chat-message:last-child {{
|
| 616 |
+
margin-bottom: 0;
|
| 617 |
+
}}
|
| 618 |
+
|
| 619 |
+
.chat-message.user {{
|
| 620 |
+
background: #e3f2fd;
|
| 621 |
+
margin-left: 20px;
|
| 622 |
+
}}
|
| 623 |
+
|
| 624 |
+
.chat-message.assistant {{
|
| 625 |
+
background: #f3e5f5;
|
| 626 |
+
margin-right: 20px;
|
| 627 |
+
}}
|
| 628 |
+
|
| 629 |
+
.chat-message.system {{
|
| 630 |
+
background: #e8f5e8;
|
| 631 |
+
font-style: italic;
|
| 632 |
+
}}
|
| 633 |
+
|
| 634 |
+
.message-role {{
|
| 635 |
+
font-weight: 600;
|
| 636 |
+
font-size: 12px;
|
| 637 |
+
color: #666;
|
| 638 |
+
margin-bottom: 5px;
|
| 639 |
+
}}
|
| 640 |
+
|
| 641 |
+
.message-content {{
|
| 642 |
+
font-size: 14px;
|
| 643 |
+
line-height: 1.4;
|
| 644 |
+
}}
|
| 645 |
+
|
| 646 |
+
.chat-input-container {{
|
| 647 |
+
border-top: 1px solid #e0e0e0;
|
| 648 |
+
padding-top: 15px;
|
| 649 |
+
}}
|
| 650 |
+
|
| 651 |
+
.role-selector {{
|
| 652 |
+
margin-bottom: 10px;
|
| 653 |
+
}}
|
| 654 |
+
|
| 655 |
+
.role-selector label {{
|
| 656 |
+
font-weight: 500;
|
| 657 |
+
margin-right: 10px;
|
| 658 |
+
}}
|
| 659 |
+
|
| 660 |
+
.role-selector select {{
|
| 661 |
+
padding: 5px 10px;
|
| 662 |
+
border: 1px solid #ddd;
|
| 663 |
+
border-radius: 4px;
|
| 664 |
+
}}
|
| 665 |
+
|
| 666 |
+
.message-input {{
|
| 667 |
+
display: flex;
|
| 668 |
+
gap: 10px;
|
| 669 |
+
align-items: flex-end;
|
| 670 |
+
}}
|
| 671 |
+
|
| 672 |
+
.message-input textarea {{
|
| 673 |
+
flex: 1;
|
| 674 |
+
padding: 10px;
|
| 675 |
+
border: 1px solid #ddd;
|
| 676 |
+
border-radius: 4px;
|
| 677 |
+
resize: vertical;
|
| 678 |
+
font-family: inherit;
|
| 679 |
+
}}
|
| 680 |
+
|
| 681 |
+
.message-input textarea:focus {{
|
| 682 |
+
outline: none;
|
| 683 |
+
border-color: #007bff;
|
| 684 |
+
box-shadow: 0 0 0 2px rgba(0, 123, 255, 0.25);
|
| 685 |
+
}}
|
| 686 |
+
|
| 687 |
+
/* Instructions Section Styles */
|
| 688 |
+
.instructions-section {{
|
| 689 |
+
background: white;
|
| 690 |
+
border: 1px solid #e0e0e0;
|
| 691 |
+
border-radius: 8px;
|
| 692 |
+
padding: 20px;
|
| 693 |
+
margin-bottom: 20px;
|
| 694 |
+
}}
|
| 695 |
+
|
| 696 |
+
.instructions-header {{
|
| 697 |
+
display: flex;
|
| 698 |
+
justify-content: space-between;
|
| 699 |
+
align-items: center;
|
| 700 |
+
margin-bottom: 15px;
|
| 701 |
+
}}
|
| 702 |
+
|
| 703 |
+
.instructions-title {{
|
| 704 |
+
font-size: 18px;
|
| 705 |
+
font-weight: 600;
|
| 706 |
+
color: #333;
|
| 707 |
+
margin: 0;
|
| 708 |
+
}}
|
| 709 |
+
|
| 710 |
+
.instructions-toggle {{
|
| 711 |
+
background: #f8f9fa;
|
| 712 |
+
border: 1px solid #dee2e6;
|
| 713 |
+
border-radius: 4px;
|
| 714 |
+
padding: 5px 10px;
|
| 715 |
+
cursor: pointer;
|
| 716 |
+
font-size: 12px;
|
| 717 |
+
color: #6c757d;
|
| 718 |
+
}}
|
| 719 |
+
|
| 720 |
+
.instructions-toggle:hover {{
|
| 721 |
+
background: #e9ecef;
|
| 722 |
+
}}
|
| 723 |
+
|
| 724 |
+
.instructions-content {{
|
| 725 |
+
display: none;
|
| 726 |
+
max-height: 400px;
|
| 727 |
+
overflow-y: auto;
|
| 728 |
+
border-top: 1px solid #e0e0e0;
|
| 729 |
+
padding-top: 15px;
|
| 730 |
+
}}
|
| 731 |
+
|
| 732 |
+
.instructions-content.expanded {{
|
| 733 |
+
display: block;
|
| 734 |
+
}}
|
| 735 |
+
|
| 736 |
+
.instructions-content h1,
|
| 737 |
+
.instructions-content h2,
|
| 738 |
+
.instructions-content h3 {{
|
| 739 |
+
color: #333;
|
| 740 |
+
margin-top: 20px;
|
| 741 |
+
margin-bottom: 10px;
|
| 742 |
+
}}
|
| 743 |
+
|
| 744 |
+
.instructions-content h1 {{
|
| 745 |
+
font-size: 24px;
|
| 746 |
+
border-bottom: 2px solid #007bff;
|
| 747 |
+
padding-bottom: 10px;
|
| 748 |
+
}}
|
| 749 |
+
|
| 750 |
+
.instructions-content h2 {{
|
| 751 |
+
font-size: 20px;
|
| 752 |
+
}}
|
| 753 |
+
|
| 754 |
+
.instructions-content h3 {{
|
| 755 |
+
font-size: 16px;
|
| 756 |
+
}}
|
| 757 |
+
|
| 758 |
+
.instructions-content p {{
|
| 759 |
+
margin-bottom: 10px;
|
| 760 |
+
line-height: 1.6;
|
| 761 |
+
}}
|
| 762 |
+
|
| 763 |
+
.instructions-content code {{
|
| 764 |
+
background: #f8f9fa;
|
| 765 |
+
padding: 2px 4px;
|
| 766 |
+
border-radius: 3px;
|
| 767 |
+
font-family: monospace;
|
| 768 |
+
font-size: 14px;
|
| 769 |
+
}}
|
| 770 |
+
|
| 771 |
+
.instructions-content pre {{
|
| 772 |
+
background: #f8f9fa;
|
| 773 |
+
border: 1px solid #e9ecef;
|
| 774 |
+
border-radius: 4px;
|
| 775 |
+
padding: 15px;
|
| 776 |
+
overflow-x: auto;
|
| 777 |
+
margin: 10px 0;
|
| 778 |
+
}}
|
| 779 |
+
|
| 780 |
+
.instructions-content pre code {{
|
| 781 |
+
background: none;
|
| 782 |
+
padding: 0;
|
| 783 |
+
}}
|
| 784 |
+
|
| 785 |
+
.instructions-content ul,
|
| 786 |
+
.instructions-content ol {{
|
| 787 |
+
margin: 10px 0;
|
| 788 |
+
padding-left: 20px;
|
| 789 |
+
}}
|
| 790 |
+
|
| 791 |
+
.instructions-content li {{
|
| 792 |
+
margin-bottom: 5px;
|
| 793 |
+
}}
|
| 794 |
+
|
| 795 |
+
.instructions-content table {{
|
| 796 |
+
border-collapse: collapse;
|
| 797 |
+
width: 100%;
|
| 798 |
+
margin: 15px 0;
|
| 799 |
+
}}
|
| 800 |
+
|
| 801 |
+
.instructions-content th,
|
| 802 |
+
.instructions-content td {{
|
| 803 |
+
border: 1px solid #dee2e6;
|
| 804 |
+
padding: 8px 12px;
|
| 805 |
+
text-align: left;
|
| 806 |
+
}}
|
| 807 |
+
|
| 808 |
+
.instructions-content th {{
|
| 809 |
+
background: #f8f9fa;
|
| 810 |
+
font-weight: 600;
|
| 811 |
+
}}
|
| 812 |
+
|
| 813 |
+
/* Enhanced Form Styles */
|
| 814 |
+
.help-text {{
|
| 815 |
+
display: block;
|
| 816 |
+
margin-top: 5px;
|
| 817 |
+
font-size: 12px;
|
| 818 |
+
color: #6c757d;
|
| 819 |
+
font-style: italic;
|
| 820 |
+
}}
|
| 821 |
+
|
| 822 |
+
.form-group label {{
|
| 823 |
+
font-weight: 500;
|
| 824 |
+
color: #333;
|
| 825 |
+
margin-bottom: 5px;
|
| 826 |
+
}}
|
| 827 |
+
|
| 828 |
+
.form-group select {{
|
| 829 |
+
width: 100%;
|
| 830 |
+
padding: 8px 12px;
|
| 831 |
+
border: 1px solid #ddd;
|
| 832 |
+
border-radius: 4px;
|
| 833 |
+
font-size: 14px;
|
| 834 |
+
background-color: white;
|
| 835 |
+
}}
|
| 836 |
+
|
| 837 |
+
.form-group select:focus {{
|
| 838 |
+
outline: none;
|
| 839 |
+
border-color: #007bff;
|
| 840 |
+
box-shadow: 0 0 0 2px rgba(0, 123, 255, 0.25);
|
| 841 |
+
}}
|
| 842 |
+
|
| 843 |
+
.form-group textarea {{
|
| 844 |
+
width: 100%;
|
| 845 |
+
padding: 8px 12px;
|
| 846 |
+
border: 1px solid #ddd;
|
| 847 |
+
border-radius: 4px;
|
| 848 |
+
font-size: 14px;
|
| 849 |
+
font-family: inherit;
|
| 850 |
+
resize: vertical;
|
| 851 |
+
}}
|
| 852 |
+
|
| 853 |
+
.form-group textarea:focus {{
|
| 854 |
+
outline: none;
|
| 855 |
+
border-color: #007bff;
|
| 856 |
+
box-shadow: 0 0 0 2px rgba(0, 123, 255, 0.25);
|
| 857 |
+
}}
|
| 858 |
+
|
| 859 |
+
.form-group input[type="number"] {{
|
| 860 |
+
width: 100%;
|
| 861 |
+
padding: 8px 12px;
|
| 862 |
+
border: 1px solid #ddd;
|
| 863 |
+
border-radius: 4px;
|
| 864 |
+
font-size: 14px;
|
| 865 |
+
}}
|
| 866 |
+
|
| 867 |
+
.form-group input[type="number"]:focus {{
|
| 868 |
+
outline: none;
|
| 869 |
+
border-color: #007bff;
|
| 870 |
+
box-shadow: 0 0 0 2px rgba(0, 123, 255, 0.25);
|
| 871 |
+
}}
|
| 872 |
+
|
| 873 |
+
.form-group input[type="text"]:focus {{
|
| 874 |
+
outline: none;
|
| 875 |
+
border-color: #007bff;
|
| 876 |
+
box-shadow: 0 0 0 2px rgba(0, 123, 255, 0.25);
|
| 877 |
+
}}
|
| 878 |
+
|
| 879 |
+
.required-indicator {{
|
| 880 |
+
color: #dc3545;
|
| 881 |
+
font-weight: bold;
|
| 882 |
+
}}
|
| 883 |
+
|
| 884 |
+
.form-group .field-description {{
|
| 885 |
+
font-size: 11px;
|
| 886 |
+
color: #666;
|
| 887 |
+
margin-top: 2px;
|
| 888 |
+
font-style: italic;
|
| 889 |
+
}}
|
| 890 |
+
</style>
|
| 891 |
+
</head>
|
| 892 |
+
<body>
|
| 893 |
+
<div class="container">
|
| 894 |
+
<!-- Left Pane: HumanAgent Interface -->
|
| 895 |
+
<div class="left-pane">
|
| 896 |
+
<div class="pane-header">
|
| 897 |
+
<span class="status-indicator status-disconnected" id="connection-status"></span>
|
| 898 |
+
HumanAgent Interface
|
| 899 |
+
</div>
|
| 900 |
+
<div class="pane-content">
|
| 901 |
+
<!-- Instructions Section -->
|
| 902 |
+
{_generate_instructions_section(metadata)}
|
| 903 |
+
|
| 904 |
+
<!-- Action Form or Chat Interface -->
|
| 905 |
+
{_generate_action_interface(action_fields, is_chat_env)}
|
| 906 |
+
|
| 907 |
+
<!-- Control Buttons -->
|
| 908 |
+
<div style="margin-bottom: 20px;">
|
| 909 |
+
<button class="btn btn-secondary" id="reset-btn">Reset Environment</button>
|
| 910 |
+
<button class="btn btn-secondary" id="state-btn">Get State</button>
|
| 911 |
+
</div>
|
| 912 |
+
|
| 913 |
+
<!-- Current State Display -->
|
| 914 |
+
<div class="state-display">
|
| 915 |
+
<h3>Current State</h3>
|
| 916 |
+
<div id="current-state">
|
| 917 |
+
<div class="state-item">
|
| 918 |
+
<span class="state-label">Status:</span>
|
| 919 |
+
<span class="state-value" id="env-status">Not initialized</span>
|
| 920 |
+
</div>
|
| 921 |
+
<div class="state-item">
|
| 922 |
+
<span class="state-label">Episode ID:</span>
|
| 923 |
+
<span class="state-value" id="episode-id">-</span>
|
| 924 |
+
</div>
|
| 925 |
+
<div class="state-item">
|
| 926 |
+
<span class="state-label">Step Count:</span>
|
| 927 |
+
<span class="state-value" id="step-count">0</span>
|
| 928 |
+
</div>
|
| 929 |
+
</div>
|
| 930 |
+
</div>
|
| 931 |
+
</div>
|
| 932 |
+
</div>
|
| 933 |
+
|
| 934 |
+
<!-- Right Pane: State Observer -->
|
| 935 |
+
<div class="right-pane">
|
| 936 |
+
<div class="pane-header">
|
| 937 |
+
State Observer
|
| 938 |
+
</div>
|
| 939 |
+
<div class="pane-content">
|
| 940 |
+
<!-- Current Observation -->
|
| 941 |
+
<div class="state-display">
|
| 942 |
+
<h3>Current Observation</h3>
|
| 943 |
+
<div id="current-observation" class="json-display">
|
| 944 |
+
No observation yet
|
| 945 |
+
</div>
|
| 946 |
+
</div>
|
| 947 |
+
|
| 948 |
+
<!-- Action Logs -->
|
| 949 |
+
<div class="logs-container">
|
| 950 |
+
<h3>Action History</h3>
|
| 951 |
+
<div id="action-logs">
|
| 952 |
+
No actions taken yet
|
| 953 |
+
</div>
|
| 954 |
+
</div>
|
| 955 |
+
</div>
|
| 956 |
+
</div>
|
| 957 |
+
</div>
|
| 958 |
+
|
| 959 |
+
<script>
|
| 960 |
+
class OpenEnvWebInterface {{
|
| 961 |
+
constructor() {{
|
| 962 |
+
this.ws = null;
|
| 963 |
+
this.isConnected = false;
|
| 964 |
+
this.init();
|
| 965 |
+
}}
|
| 966 |
+
|
| 967 |
+
init() {{
|
| 968 |
+
this.connectWebSocket();
|
| 969 |
+
this.setupEventListeners();
|
| 970 |
+
}}
|
| 971 |
+
|
| 972 |
+
connectWebSocket() {{
|
| 973 |
+
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
|
| 974 |
+
const wsUrl = `${{protocol}}//${{window.location.host}}/ws`;
|
| 975 |
+
|
| 976 |
+
this.ws = new WebSocket(wsUrl);
|
| 977 |
+
|
| 978 |
+
this.ws.onopen = () => {{
|
| 979 |
+
this.isConnected = true;
|
| 980 |
+
this.updateConnectionStatus(true);
|
| 981 |
+
console.log('WebSocket connected');
|
| 982 |
+
}};
|
| 983 |
+
|
| 984 |
+
this.ws.onmessage = (event) => {{
|
| 985 |
+
const data = JSON.parse(event.data);
|
| 986 |
+
if (data.type === 'state_update') {{
|
| 987 |
+
this.updateUI(data.episode_state);
|
| 988 |
+
}}
|
| 989 |
+
}};
|
| 990 |
+
|
| 991 |
+
this.ws.onclose = () => {{
|
| 992 |
+
this.isConnected = false;
|
| 993 |
+
this.updateConnectionStatus(false);
|
| 994 |
+
console.log('WebSocket disconnected');
|
| 995 |
+
// Attempt to reconnect after 3 seconds
|
| 996 |
+
setTimeout(() => this.connectWebSocket(), 3000);
|
| 997 |
+
}};
|
| 998 |
+
|
| 999 |
+
this.ws.onerror = (error) => {{
|
| 1000 |
+
console.error('WebSocket error:', error);
|
| 1001 |
+
}};
|
| 1002 |
+
}}
|
| 1003 |
+
|
| 1004 |
+
setupEventListeners() {{
|
| 1005 |
+
// Instructions toggle
|
| 1006 |
+
const instructionsToggle = document.getElementById('instructions-toggle');
|
| 1007 |
+
const instructionsContent = document.getElementById('instructions-content');
|
| 1008 |
+
if (instructionsToggle && instructionsContent) {{
|
| 1009 |
+
instructionsToggle.addEventListener('click', () => {{
|
| 1010 |
+
instructionsContent.classList.toggle('expanded');
|
| 1011 |
+
instructionsToggle.textContent = instructionsContent.classList.contains('expanded')
|
| 1012 |
+
? 'Hide Instructions' : 'Show Instructions';
|
| 1013 |
+
}});
|
| 1014 |
+
}}
|
| 1015 |
+
|
| 1016 |
+
// Check if this is a chat environment
|
| 1017 |
+
const isChatEnv = document.getElementById('chat-messages') !== null;
|
| 1018 |
+
|
| 1019 |
+
if (isChatEnv) {{
|
| 1020 |
+
// Chat environment event listeners
|
| 1021 |
+
document.getElementById('send-message-btn').addEventListener('click', () => {{
|
| 1022 |
+
this.sendMessage();
|
| 1023 |
+
}});
|
| 1024 |
+
|
| 1025 |
+
// Send message on Enter (but allow Shift+Enter for new lines)
|
| 1026 |
+
document.getElementById('message-input').addEventListener('keydown', (e) => {{
|
| 1027 |
+
if (e.key === 'Enter' && !e.shiftKey) {{
|
| 1028 |
+
e.preventDefault();
|
| 1029 |
+
this.sendMessage();
|
| 1030 |
+
}}
|
| 1031 |
+
}});
|
| 1032 |
+
}} else {{
|
| 1033 |
+
// Traditional action form submission
|
| 1034 |
+
const actionForm = document.getElementById('action-form');
|
| 1035 |
+
if (actionForm) {{
|
| 1036 |
+
actionForm.addEventListener('submit', (e) => {{
|
| 1037 |
+
e.preventDefault();
|
| 1038 |
+
this.submitAction();
|
| 1039 |
+
}});
|
| 1040 |
+
}}
|
| 1041 |
+
}}
|
| 1042 |
+
|
| 1043 |
+
// Reset button
|
| 1044 |
+
document.getElementById('reset-btn').addEventListener('click', () => {{
|
| 1045 |
+
this.resetEnvironment();
|
| 1046 |
+
}});
|
| 1047 |
+
|
| 1048 |
+
// State button
|
| 1049 |
+
document.getElementById('state-btn').addEventListener('click', () => {{
|
| 1050 |
+
this.getState();
|
| 1051 |
+
}});
|
| 1052 |
+
}}
|
| 1053 |
+
|
| 1054 |
+
async sendMessage() {{
|
| 1055 |
+
const messageInput = document.getElementById('message-input');
|
| 1056 |
+
const roleSelect = document.getElementById('message-role');
|
| 1057 |
+
const message = messageInput.value.trim();
|
| 1058 |
+
const role = roleSelect.value;
|
| 1059 |
+
|
| 1060 |
+
if (!message) {{
|
| 1061 |
+
return;
|
| 1062 |
+
}}
|
| 1063 |
+
|
| 1064 |
+
// Add message to chat display immediately
|
| 1065 |
+
this.addMessageToChat(role, message);
|
| 1066 |
+
|
| 1067 |
+
// Clear input
|
| 1068 |
+
messageInput.value = '';
|
| 1069 |
+
|
| 1070 |
+
try {{
|
| 1071 |
+
// Send message to server to convert to action and step
|
| 1072 |
+
const response = await fetch('/web/step', {{
|
| 1073 |
+
method: 'POST',
|
| 1074 |
+
headers: {{ 'Content-Type': 'application/json' }},
|
| 1075 |
+
body: JSON.stringify({{
|
| 1076 |
+
message: {{
|
| 1077 |
+
role: role,
|
| 1078 |
+
content: message
|
| 1079 |
+
}}
|
| 1080 |
+
}})
|
| 1081 |
+
}});
|
| 1082 |
+
|
| 1083 |
+
if (!response.ok) {{
|
| 1084 |
+
throw new Error(`HTTP error! status: ${{response.status}}`);
|
| 1085 |
+
}}
|
| 1086 |
+
|
| 1087 |
+
const result = await response.json();
|
| 1088 |
+
console.log('Message sent:', result);
|
| 1089 |
+
}} catch (error) {{
|
| 1090 |
+
console.error('Error sending message:', error);
|
| 1091 |
+
alert('Error sending message: ' + error.message);
|
| 1092 |
+
}}
|
| 1093 |
+
}}
|
| 1094 |
+
|
| 1095 |
+
addMessageToChat(role, content) {{
|
| 1096 |
+
const chatMessages = document.getElementById('chat-messages');
|
| 1097 |
+
const messageDiv = document.createElement('div');
|
| 1098 |
+
messageDiv.className = `chat-message ${{role}}`;
|
| 1099 |
+
|
| 1100 |
+
messageDiv.innerHTML = `
|
| 1101 |
+
<div class="message-role">${{role.charAt(0).toUpperCase() + role.slice(1)}}</div>
|
| 1102 |
+
<div class="message-content">${{content}}</div>
|
| 1103 |
+
`;
|
| 1104 |
+
|
| 1105 |
+
chatMessages.appendChild(messageDiv);
|
| 1106 |
+
chatMessages.scrollTop = chatMessages.scrollHeight;
|
| 1107 |
+
}}
|
| 1108 |
+
|
| 1109 |
+
async submitAction() {{
|
| 1110 |
+
const formData = new FormData(document.getElementById('action-form'));
|
| 1111 |
+
const action = {{}};
|
| 1112 |
+
|
| 1113 |
+
// Collect form data
|
| 1114 |
+
for (const [key, value] of formData.entries()) {{
|
| 1115 |
+
if (value !== '') {{
|
| 1116 |
+
// Handle tensor fields (tokens) - convert comma-separated string to array
|
| 1117 |
+
if (key === 'tokens') {{
|
| 1118 |
+
try {{
|
| 1119 |
+
action[key] = value.split(',').map(x => parseInt(x.trim())).filter(x => !isNaN(x));
|
| 1120 |
+
}} catch (e) {{
|
| 1121 |
+
console.error('Error parsing tokens:', e);
|
| 1122 |
+
action[key] = [];
|
| 1123 |
+
}}
|
| 1124 |
+
}} else {{
|
| 1125 |
+
action[key] = value;
|
| 1126 |
+
}}
|
| 1127 |
+
}}
|
| 1128 |
+
}}
|
| 1129 |
+
|
| 1130 |
+
try {{
|
| 1131 |
+
const response = await fetch('/web/step', {{
|
| 1132 |
+
method: 'POST',
|
| 1133 |
+
headers: {{ 'Content-Type': 'application/json' }},
|
| 1134 |
+
body: JSON.stringify({{ action }})
|
| 1135 |
+
}});
|
| 1136 |
+
|
| 1137 |
+
if (!response.ok) {{
|
| 1138 |
+
throw new Error(`HTTP error! status: ${{response.status}}`);
|
| 1139 |
+
}}
|
| 1140 |
+
|
| 1141 |
+
const result = await response.json();
|
| 1142 |
+
console.log('Step result:', result);
|
| 1143 |
+
}} catch (error) {{
|
| 1144 |
+
console.error('Error submitting action:', error);
|
| 1145 |
+
alert('Error submitting action: ' + error.message);
|
| 1146 |
+
}}
|
| 1147 |
+
}}
|
| 1148 |
+
|
| 1149 |
+
async resetEnvironment() {{
|
| 1150 |
+
try {{
|
| 1151 |
+
const response = await fetch('/web/reset', {{
|
| 1152 |
+
method: 'POST',
|
| 1153 |
+
headers: {{ 'Content-Type': 'application/json' }}
|
| 1154 |
+
}});
|
| 1155 |
+
|
| 1156 |
+
if (!response.ok) {{
|
| 1157 |
+
throw new Error(`HTTP error! status: ${{response.status}}`);
|
| 1158 |
+
}}
|
| 1159 |
+
|
| 1160 |
+
const result = await response.json();
|
| 1161 |
+
console.log('Reset result:', result);
|
| 1162 |
+
}} catch (error) {{
|
| 1163 |
+
console.error('Error resetting environment:', error);
|
| 1164 |
+
alert('Error resetting environment: ' + error.message);
|
| 1165 |
+
}}
|
| 1166 |
+
}}
|
| 1167 |
+
|
| 1168 |
+
async getState() {{
|
| 1169 |
+
try {{
|
| 1170 |
+
const response = await fetch('/web/state');
|
| 1171 |
+
const state = await response.json();
|
| 1172 |
+
console.log('Current state:', state);
|
| 1173 |
+
alert('Current state: ' + JSON.stringify(state, null, 2));
|
| 1174 |
+
}} catch (error) {{
|
| 1175 |
+
console.error('Error getting state:', error);
|
| 1176 |
+
alert('Error getting state: ' + error.message);
|
| 1177 |
+
}}
|
| 1178 |
+
}}
|
| 1179 |
+
|
| 1180 |
+
updateConnectionStatus(connected) {{
|
| 1181 |
+
const indicator = document.getElementById('connection-status');
|
| 1182 |
+
if (connected) {{
|
| 1183 |
+
indicator.className = 'status-indicator status-connected';
|
| 1184 |
+
}} else {{
|
| 1185 |
+
indicator.className = 'status-indicator status-disconnected';
|
| 1186 |
+
}}
|
| 1187 |
+
}}
|
| 1188 |
+
|
| 1189 |
+
updateUI(episodeState) {{
|
| 1190 |
+
// Check if this is a chat environment
|
| 1191 |
+
const isChatEnv = document.getElementById('chat-messages') !== null;
|
| 1192 |
+
|
| 1193 |
+
// Update current state
|
| 1194 |
+
document.getElementById('env-status').textContent =
|
| 1195 |
+
episodeState.is_reset ? 'Reset' : 'Running';
|
| 1196 |
+
document.getElementById('episode-id').textContent =
|
| 1197 |
+
episodeState.episode_id || '-';
|
| 1198 |
+
document.getElementById('step-count').textContent =
|
| 1199 |
+
episodeState.step_count.toString();
|
| 1200 |
+
|
| 1201 |
+
if (isChatEnv) {{
|
| 1202 |
+
// Update chat interface
|
| 1203 |
+
this.updateChatInterface(episodeState);
|
| 1204 |
+
}} else {{
|
| 1205 |
+
// Update traditional observation display
|
| 1206 |
+
const observationDiv = document.getElementById('current-observation');
|
| 1207 |
+
if (episodeState.current_observation) {{
|
| 1208 |
+
observationDiv.textContent = JSON.stringify(
|
| 1209 |
+
episodeState.current_observation, null, 2
|
| 1210 |
+
);
|
| 1211 |
+
}} else {{
|
| 1212 |
+
observationDiv.textContent = 'No observation yet';
|
| 1213 |
+
}}
|
| 1214 |
+
}}
|
| 1215 |
+
|
| 1216 |
+
// Update action logs
|
| 1217 |
+
const logsDiv = document.getElementById('action-logs');
|
| 1218 |
+
if (episodeState.action_logs.length === 0) {{
|
| 1219 |
+
logsDiv.innerHTML = 'No actions taken yet';
|
| 1220 |
+
}} else {{
|
| 1221 |
+
logsDiv.innerHTML = episodeState.action_logs.map(log => `
|
| 1222 |
+
<div class="log-entry">
|
| 1223 |
+
<div class="log-timestamp">${{log.timestamp}} (Step ${{log.step_count}})</div>
|
| 1224 |
+
<div class="log-action">Action: ${{JSON.stringify(log.action, null, 2)}}</div>
|
| 1225 |
+
<div class="log-observation">Observation: ${{JSON.stringify(log.observation, null, 2)}}</div>
|
| 1226 |
+
<div>
|
| 1227 |
+
<span class="log-reward">Reward: ${{log.reward !== null ? log.reward : 'None'}}</span>
|
| 1228 |
+
${{log.done ? '<span class="log-done">DONE</span>' : ''}}
|
| 1229 |
+
</div>
|
| 1230 |
+
</div>
|
| 1231 |
+
`).join('');
|
| 1232 |
+
}}
|
| 1233 |
+
}}
|
| 1234 |
+
|
| 1235 |
+
updateChatInterface(episodeState) {{
|
| 1236 |
+
const chatMessages = document.getElementById('chat-messages');
|
| 1237 |
+
if (!chatMessages) return;
|
| 1238 |
+
|
| 1239 |
+
// Clear existing messages (except system message)
|
| 1240 |
+
const systemMessage = chatMessages.querySelector('.chat-message.system');
|
| 1241 |
+
chatMessages.innerHTML = '';
|
| 1242 |
+
if (systemMessage) {{
|
| 1243 |
+
chatMessages.appendChild(systemMessage);
|
| 1244 |
+
}}
|
| 1245 |
+
|
| 1246 |
+
// Add messages from current observation
|
| 1247 |
+
if (episodeState.current_observation && episodeState.current_observation.messages) {{
|
| 1248 |
+
episodeState.current_observation.messages.forEach(msg => {{
|
| 1249 |
+
this.addMessageToChat(msg.role, msg.content);
|
| 1250 |
+
}});
|
| 1251 |
+
}}
|
| 1252 |
+
}}
|
| 1253 |
+
}}
|
| 1254 |
+
|
| 1255 |
+
// Initialize the web interface when the page loads
|
| 1256 |
+
document.addEventListener('DOMContentLoaded', () => {{
|
| 1257 |
+
new OpenEnvWebInterface();
|
| 1258 |
+
}});
|
| 1259 |
+
</script>
|
| 1260 |
+
</body>
|
| 1261 |
+
</html>
|
| 1262 |
+
""".replace('{_generate_action_form_fields(action_fields)}', _generate_action_form_fields(action_fields))
|
| 1263 |
+
|
| 1264 |
+
|
| 1265 |
+
def _generate_instructions_section(metadata: Optional[EnvironmentMetadata]) -> str:
    """Render the collapsible instructions panel for the web interface.

    Args:
        metadata: Environment metadata, or ``None``. Only ``name`` and
            ``readme_content`` are read.

    Returns:
        An HTML fragment containing the environment name as the panel title
        and the README converted by ``_markdown_to_html``. Returns an empty
        string when ``metadata`` is falsy or has no README content.
    """
    import html

    if not metadata or not metadata.readme_content:
        return ''

    # The README body is HTML-escaped by _markdown_to_html; escape the title
    # too so a name containing '<' or '&' cannot break or inject markup.
    title = html.escape(str(metadata.name))

    # Convert markdown to HTML (basic conversion)
    html_content = _markdown_to_html(metadata.readme_content)

    return f'''
                <!-- Instructions Section -->
                <div class="instructions-section">
                    <div class="instructions-header">
                        <h3 class="instructions-title">{title}</h3>
                        <button class="instructions-toggle" id="instructions-toggle">Show Instructions</button>
                    </div>
                    <div class="instructions-content" id="instructions-content">
                        <div class="instructions-readme">
                            {html_content}
                        </div>
                    </div>
                </div>
    '''
|
| 1288 |
+
|
| 1289 |
+
|
| 1290 |
+
def _extract_action_fields(action_cls: Type[Action]) -> List[Dict[str, Any]]:
    """Describe each dataclass field of an Action class for form rendering.

    Args:
        action_cls: The Action class to inspect.

    Returns:
        One dict per dataclass field (the ``metadata`` field is skipped) with
        the keys ``name``, ``type``, ``required``, ``description``,
        ``default_value``, ``choices``, ``min_value``, ``max_value``,
        ``placeholder`` and ``help_text``. An empty list is returned when
        ``action_cls`` has no ``__dataclass_fields__`` attribute.
    """
    if not hasattr(action_cls, '__dataclass_fields__'):
        return []

    described = []
    for name, spec in action_cls.__dataclass_fields__.items():
        if name == 'metadata':
            continue

        extras = _extract_field_metadata(name, spec)

        described.append({
            'name': name,
            # HTML input kind derived from the annotated type.
            'type': _determine_input_type(spec.type),
            # When neither a default nor a default_factory is given, both
            # attributes hold the same dataclasses MISSING sentinel, so the
            # identity test is true exactly for required fields.
            'required': spec.default is spec.default_factory,
            'description': extras.get('description', ''),
            'default_value': extras.get('default_value'),
            'choices': extras.get('choices', []),
            'min_value': extras.get('min_value'),
            'max_value': extras.get('max_value'),
            'placeholder': extras.get('placeholder', ''),
            'help_text': extras.get('help_text', ''),
        })

    return described
|
| 1326 |
+
|
| 1327 |
+
|
| 1328 |
+
def _extract_field_metadata(field_name: str, field_info) -> Dict[str, Any]:
|
| 1329 |
+
"""Extract metadata from dataclass field including docstring and type hints."""
|
| 1330 |
+
import typing
|
| 1331 |
+
from typing import get_origin, get_args, Literal, Union, Optional
|
| 1332 |
+
|
| 1333 |
+
metadata = {}
|
| 1334 |
+
|
| 1335 |
+
# Extract description from field docstring or annotation
|
| 1336 |
+
if hasattr(field_info, 'metadata') and field_info.metadata:
|
| 1337 |
+
# Check for custom metadata
|
| 1338 |
+
for meta in field_info.metadata:
|
| 1339 |
+
if isinstance(meta, dict):
|
| 1340 |
+
metadata.update(meta)
|
| 1341 |
+
|
| 1342 |
+
# Extract type information
|
| 1343 |
+
field_type = field_info.type
|
| 1344 |
+
origin = get_origin(field_type)
|
| 1345 |
+
|
| 1346 |
+
# Handle Literal types for dropdown choices
|
| 1347 |
+
if origin is Literal:
|
| 1348 |
+
args = get_args(field_type)
|
| 1349 |
+
metadata['choices'] = list(args)
|
| 1350 |
+
|
| 1351 |
+
# Handle Optional types
|
| 1352 |
+
if origin is Union:
|
| 1353 |
+
args = get_args(field_type)
|
| 1354 |
+
if len(args) == 2 and type(None) in args:
|
| 1355 |
+
# This is Optional[SomeType]
|
| 1356 |
+
non_none_type = args[0] if args[1] is type(None) else args[1]
|
| 1357 |
+
metadata['optional'] = True
|
| 1358 |
+
# Recursively check the non-None type for choices
|
| 1359 |
+
if get_origin(non_none_type) is Literal:
|
| 1360 |
+
metadata['choices'] = list(get_args(non_none_type))
|
| 1361 |
+
else:
|
| 1362 |
+
# Regular Union type
|
| 1363 |
+
metadata['choices'] = [str(arg) for arg in args if arg is not type(None)]
|
| 1364 |
+
|
| 1365 |
+
# Handle numeric constraints
|
| 1366 |
+
if field_type in (int, float):
|
| 1367 |
+
# Check for common constraint patterns in field name
|
| 1368 |
+
if 'count' in field_name.lower() or 'num' in field_name.lower():
|
| 1369 |
+
metadata['min_value'] = 0
|
| 1370 |
+
if 'id' in field_name.lower():
|
| 1371 |
+
metadata['min_value'] = 0
|
| 1372 |
+
|
| 1373 |
+
# Generate placeholder text
|
| 1374 |
+
if 'message' in field_name.lower():
|
| 1375 |
+
metadata['placeholder'] = f'Enter {field_name.replace("_", " ")}...'
|
| 1376 |
+
elif 'code' in field_name.lower():
|
| 1377 |
+
metadata['placeholder'] = 'Enter Python code here...'
|
| 1378 |
+
elif 'tokens' in field_name.lower():
|
| 1379 |
+
metadata['placeholder'] = 'Enter comma-separated token IDs (e.g., 1,2,3,4,5)'
|
| 1380 |
+
else:
|
| 1381 |
+
metadata['placeholder'] = f'Enter {field_name.replace("_", " ")}...'
|
| 1382 |
+
|
| 1383 |
+
# Generate help text based on field name and type
|
| 1384 |
+
if 'action_id' in field_name.lower():
|
| 1385 |
+
metadata['help_text'] = 'The action ID to execute in the environment'
|
| 1386 |
+
elif 'game_name' in field_name.lower():
|
| 1387 |
+
metadata['help_text'] = 'Name of the game or environment'
|
| 1388 |
+
elif 'tokens' in field_name.lower():
|
| 1389 |
+
metadata['help_text'] = 'Token IDs as a comma-separated list of integers'
|
| 1390 |
+
elif 'code' in field_name.lower():
|
| 1391 |
+
metadata['help_text'] = 'Python code to execute in the environment'
|
| 1392 |
+
elif 'message' in field_name.lower():
|
| 1393 |
+
metadata['help_text'] = 'Text message to send'
|
| 1394 |
+
|
| 1395 |
+
return metadata
|
| 1396 |
+
|
| 1397 |
+
|
| 1398 |
+
def _determine_input_type(field_type) -> str:
|
| 1399 |
+
"""Determine the appropriate HTML input type for a field type."""
|
| 1400 |
+
import typing
|
| 1401 |
+
from typing import get_origin, get_args, Literal, Union
|
| 1402 |
+
|
| 1403 |
+
# Handle direct types
|
| 1404 |
+
if field_type == str:
|
| 1405 |
+
return "text"
|
| 1406 |
+
elif field_type == int:
|
| 1407 |
+
return "number"
|
| 1408 |
+
elif field_type == float:
|
| 1409 |
+
return "number"
|
| 1410 |
+
elif field_type == bool:
|
| 1411 |
+
return "checkbox"
|
| 1412 |
+
|
| 1413 |
+
# Handle complex types
|
| 1414 |
+
origin = get_origin(field_type)
|
| 1415 |
+
|
| 1416 |
+
if origin is Literal:
|
| 1417 |
+
return "select"
|
| 1418 |
+
elif origin is Union:
|
| 1419 |
+
args = get_args(field_type)
|
| 1420 |
+
if len(args) == 2 and type(None) in args:
|
| 1421 |
+
# Optional type - use the non-None type
|
| 1422 |
+
non_none_type = args[0] if args[1] is type(None) else args[1]
|
| 1423 |
+
return _determine_input_type(non_none_type)
|
| 1424 |
+
elif all(isinstance(arg, str) for arg in args if arg is not type(None)):
|
| 1425 |
+
return "select"
|
| 1426 |
+
else:
|
| 1427 |
+
return "text"
|
| 1428 |
+
elif hasattr(field_type, '__name__') and 'Tensor' in field_type.__name__:
|
| 1429 |
+
return "tensor"
|
| 1430 |
+
else:
|
| 1431 |
+
return "text"
|
| 1432 |
+
|
| 1433 |
+
|
| 1434 |
+
def _markdown_to_html(markdown: str) -> str:
|
| 1435 |
+
"""Convert basic markdown to HTML for README display."""
|
| 1436 |
+
import html
|
| 1437 |
+
import re
|
| 1438 |
+
|
| 1439 |
+
# Escape HTML first
|
| 1440 |
+
html_content = html.escape(markdown)
|
| 1441 |
+
|
| 1442 |
+
# Convert headers
|
| 1443 |
+
html_content = re.sub(r'^# (.*?)$', r'<h1>\1</h1>', html_content, flags=re.MULTILINE)
|
| 1444 |
+
html_content = re.sub(r'^## (.*?)$', r'<h2>\1</h2>', html_content, flags=re.MULTILINE)
|
| 1445 |
+
html_content = re.sub(r'^### (.*?)$', r'<h3>\1</h3>', html_content, flags=re.MULTILINE)
|
| 1446 |
+
|
| 1447 |
+
# Convert code blocks
|
| 1448 |
+
html_content = re.sub(r'```(.*?)\n(.*?)\n```', r'<pre><code>\2</code></pre>', html_content, flags=re.DOTALL)
|
| 1449 |
+
html_content = re.sub(r'`([^`]+)`', r'<code>\1</code>', html_content)
|
| 1450 |
+
|
| 1451 |
+
# Convert bold and italic
|
| 1452 |
+
html_content = re.sub(r'\*\*(.*?)\*\*', r'<strong>\1</strong>', html_content)
|
| 1453 |
+
html_content = re.sub(r'\*(.*?)\*', r'<em>\1</em>', html_content)
|
| 1454 |
+
|
| 1455 |
+
# Convert lists
|
| 1456 |
+
html_content = re.sub(r'^- (.*?)$', r'<li>\1</li>', html_content, flags=re.MULTILINE)
|
| 1457 |
+
html_content = re.sub(r'(<li>.*</li>)', r'<ul>\1</ul>', html_content, flags=re.DOTALL)
|
| 1458 |
+
|
| 1459 |
+
# Convert line breaks
|
| 1460 |
+
html_content = html_content.replace('\n', '<br>')
|
| 1461 |
+
|
| 1462 |
+
return html_content
|
| 1463 |
+
|
| 1464 |
+
|
| 1465 |
+
def _generate_action_interface(action_fields: List[Dict[str, Any]], is_chat_env: bool) -> str:
    """Return the input panel markup: chat UI for chat environments, otherwise the action form.

    Args:
        action_fields: Field descriptions passed to the action form builder.
        is_chat_env: When truthy the chat interface is produced and
            ``action_fields`` is not used.
    """
    # Only the selected builder is invoked.
    return _generate_chat_interface() if is_chat_env else _generate_action_form(action_fields)
|
| 1471 |
+
|
| 1472 |
+
def _generate_chat_interface() -> str:
|
| 1473 |
+
"""Generate a chat-style interface for chat environments."""
|
| 1474 |
+
return '''
|
| 1475 |
+
<!-- Chat Interface -->
|
| 1476 |
+
<div class="chat-interface">
|
| 1477 |
+
<h3>Chat Interface</h3>
|
| 1478 |
+
<div class="chat-messages" id="chat-messages">
|
| 1479 |
+
<div class="chat-message system">
|
| 1480 |
+
<div class="message-role">System</div>
|
| 1481 |
+
<div class="message-content">Chat environment ready. Send a message to start the conversation.</div>
|
| 1482 |
+
</div>
|
| 1483 |
+
</div>
|
| 1484 |
+
<div class="chat-input-container">
|
| 1485 |
+
<div class="role-selector">
|
| 1486 |
+
<label for="message-role">Role:</label>
|
| 1487 |
+
<select id="message-role">
|
| 1488 |
+
<option value="user">User</option>
|
| 1489 |
+
<option value="assistant">Assistant</option>
|
| 1490 |
+
</select>
|
| 1491 |
+
</div>
|
| 1492 |
+
<div class="message-input">
|
| 1493 |
+
<textarea id="message-input" placeholder="Type your message here..." rows="3"></textarea>
|
| 1494 |
+
<button class="btn" id="send-message-btn">Send Message</button>
|
| 1495 |
+
</div>
|
| 1496 |
+
</div>
|
| 1497 |
+
</div>
|
| 1498 |
+
'''
|
| 1499 |
+
|
| 1500 |
+
def _generate_action_form(action_fields: List[Dict[str, Any]]) -> str:
    """Return the HTML action form used by non-chat environments.

    Args:
        action_fields: Field descriptions; each is rendered into a form
            group by ``_generate_action_form_fields``.

    Returns:
        Markup for a ``<form id="action-form">`` holding the rendered fields
        followed by a "Step" submit button.
    """
    # Render the inputs first so the template below stays purely structural.
    rendered_fields = _generate_action_form_fields(action_fields)
    return f'''
    <!-- Action Form -->
    <div class="action-form">
        <h3>Take Action</h3>
        <form id="action-form">
            {rendered_fields}
            <button type="submit" class="btn" id="step-btn">Step</button>
        </form>
    </div>
    '''
|
| 1512 |
+
|
| 1513 |
+
def _generate_action_form_fields(action_fields: List[Dict[str, Any]]) -> str:
|
| 1514 |
+
"""Generate HTML form fields for action input with enhanced metadata."""
|
| 1515 |
+
if not action_fields:
|
| 1516 |
+
return '<p>No action fields available</p>'
|
| 1517 |
+
|
| 1518 |
+
fields_html = []
|
| 1519 |
+
for field in action_fields:
|
| 1520 |
+
field_html = _generate_single_field(field)
|
| 1521 |
+
fields_html.append(field_html)
|
| 1522 |
+
|
| 1523 |
+
return '\n'.join(fields_html)
|
| 1524 |
+
|
| 1525 |
+
|
| 1526 |
+
def _generate_single_field(field: Dict[str, Any]) -> str:
|
| 1527 |
+
"""Generate HTML for a single form field with enhanced metadata."""
|
| 1528 |
+
field_name = field['name']
|
| 1529 |
+
field_type = field['type']
|
| 1530 |
+
required = field['required']
|
| 1531 |
+
placeholder = field.get('placeholder', '')
|
| 1532 |
+
help_text = field.get('help_text', '')
|
| 1533 |
+
choices = field.get('choices', [])
|
| 1534 |
+
min_value = field.get('min_value')
|
| 1535 |
+
max_value = field.get('max_value')
|
| 1536 |
+
default_value = field.get('default_value')
|
| 1537 |
+
|
| 1538 |
+
# Build label with required indicator
|
| 1539 |
+
label_text = field_name.replace('_', ' ').title()
|
| 1540 |
+
if required:
|
| 1541 |
+
label_text += ' <span style="color: red;">*</span>'
|
| 1542 |
+
|
| 1543 |
+
# Build input attributes
|
| 1544 |
+
input_attrs = []
|
| 1545 |
+
if required:
|
| 1546 |
+
input_attrs.append('required')
|
| 1547 |
+
if placeholder:
|
| 1548 |
+
input_attrs.append(f'placeholder="{placeholder}"')
|
| 1549 |
+
if min_value is not None:
|
| 1550 |
+
input_attrs.append(f'min="{min_value}"')
|
| 1551 |
+
if max_value is not None:
|
| 1552 |
+
input_attrs.append(f'max="{max_value}"')
|
| 1553 |
+
if default_value is not None:
|
| 1554 |
+
input_attrs.append(f'value="{default_value}"')
|
| 1555 |
+
|
| 1556 |
+
attrs_str = ' '.join(input_attrs)
|
| 1557 |
+
|
| 1558 |
+
if field_type == 'checkbox':
|
| 1559 |
+
return f'''
|
| 1560 |
+
<div class="form-group">
|
| 1561 |
+
<label>
|
| 1562 |
+
<input type="checkbox" name="{field_name}" value="true" {attrs_str}>
|
| 1563 |
+
{label_text}
|
| 1564 |
+
</label>
|
| 1565 |
+
{f'<small class="help-text">{help_text}</small>' if help_text else ''}
|
| 1566 |
+
</div>
|
| 1567 |
+
'''
|
| 1568 |
+
|
| 1569 |
+
elif field_type == 'select':
|
| 1570 |
+
options_html = []
|
| 1571 |
+
if not required:
|
| 1572 |
+
options_html.append(f'<option value="">-- Select {label_text} --</option>')
|
| 1573 |
+
|
| 1574 |
+
for choice in choices:
|
| 1575 |
+
selected = 'selected' if str(choice) == str(default_value) else ''
|
| 1576 |
+
options_html.append(f'<option value="{choice}" {selected}>{choice}</option>')
|
| 1577 |
+
|
| 1578 |
+
return f'''
|
| 1579 |
+
<div class="form-group">
|
| 1580 |
+
<label for="{field_name}">{label_text}:</label>
|
| 1581 |
+
<select name="{field_name}" id="{field_name}" {attrs_str}>
|
| 1582 |
+
{''.join(options_html)}
|
| 1583 |
+
</select>
|
| 1584 |
+
{f'<small class="help-text">{help_text}</small>' if help_text else ''}
|
| 1585 |
+
</div>
|
| 1586 |
+
'''
|
| 1587 |
+
|
| 1588 |
+
elif field_type == 'tensor':
|
| 1589 |
+
return f'''
|
| 1590 |
+
<div class="form-group">
|
| 1591 |
+
<label for="{field_name}">{label_text} (comma-separated integers):</label>
|
| 1592 |
+
<input type="text" name="{field_name}" id="{field_name}" {attrs_str}>
|
| 1593 |
+
<small class="help-text">{help_text or 'Enter token IDs as comma-separated integers (e.g., 1,2,3,4,5)'}</small>
|
| 1594 |
+
</div>
|
| 1595 |
+
'''
|
| 1596 |
+
|
| 1597 |
+
elif field_type == 'text' and ('message' in field_name.lower() or 'code' in field_name.lower()):
|
| 1598 |
+
return f'''
|
| 1599 |
+
<div class="form-group">
|
| 1600 |
+
<label for="{field_name}">{label_text}:</label>
|
| 1601 |
+
<textarea name="{field_name}" id="{field_name}" rows="3" {attrs_str}></textarea>
|
| 1602 |
+
{f'<small class="help-text">{help_text}</small>' if help_text else ''}
|
| 1603 |
+
</div>
|
| 1604 |
+
'''
|
| 1605 |
+
|
| 1606 |
+
else:
|
| 1607 |
+
return f'''
|
| 1608 |
+
<div class="form-group">
|
| 1609 |
+
<label for="{field_name}">{label_text}:</label>
|
| 1610 |
+
<input type="{field_type}" name="{field_name}" id="{field_name}" {attrs_str}>
|
| 1611 |
+
{f'<small class="help-text">{help_text}</small>' if help_text else ''}
|
| 1612 |
+
</div>
|
| 1613 |
+
'''
|
src/core/http_env_client.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
core/http_env_client.py
|
| 3 |
+
Minimal HTTP-based environment client.
|
| 4 |
+
- Talks to a single env worker exposing: POST /reset, POST /step, GET /state
|
| 5 |
+
|
| 6 |
+
Future hooks (commented below) for:
|
| 7 |
+
- episode_id, seed on reset
|
| 8 |
+
- request_id on step
|
| 9 |
+
- custom headers (auth/trace)
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
from abc import ABC, abstractmethod
|
| 15 |
+
from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, Type, TypeVar
|
| 16 |
+
from .containers.runtime import LocalDockerProvider
|
| 17 |
+
import requests
|
| 18 |
+
|
| 19 |
+
from .types import StepResult
|
| 20 |
+
|
| 21 |
+
if TYPE_CHECKING:
|
| 22 |
+
from .containers.runtime import ContainerProvider
|
| 23 |
+
|
| 24 |
+
# Type variables: the action type sent to the server, the observation type
# returned by it, and the concrete client subclass (for from_docker_image).
ActT = TypeVar("ActT")
ObsT = TypeVar("ObsT")
EnvClientT = TypeVar("EnvClientT", bound="HTTPEnvClient")


class HTTPEnvClient(ABC, Generic[ActT, ObsT]):
    """Abstract HTTP client for a single environment server.

    Subclasses supply the (de)serialization hooks ``_step_payload``,
    ``_parse_result`` and ``_parse_state``; this base class performs the
    HTTP calls against ``/reset``, ``/step`` and ``/state``.
    """

    def __init__(
        self,
        base_url: str,
        request_timeout_s: float = 15.0,
        default_headers: Optional[Dict[str, str]] = None,
        provider: Optional["ContainerProvider"] = None,
    ):
        """
        Args:
            base_url: Root URL of the environment server; a trailing slash
                is stripped.
            request_timeout_s: Timeout applied to every HTTP request.
            default_headers: Headers sent with every request.
            provider: Container provider that owns the server's container,
                if any. When set, ``close()`` asks it to stop the container.
        """
        self._base = base_url.rstrip("/")
        self._timeout = float(request_timeout_s)
        self._http = requests.Session()
        self._headers = default_headers or {}
        self._provider = provider

    @classmethod
    def from_docker_image(
        cls: Type[EnvClientT],
        image: str,
        provider: Optional["ContainerProvider"] = None,
        **kwargs: Any,
    ) -> EnvClientT:
        """
        Create an environment client by spinning up a Docker container locally.

        This is a development utility that:
        1. Starts a Docker container from the specified image
        2. Waits for the server to be ready
        3. Creates and returns a client instance connected to the container

        Note: The container keeps running until ``close()`` is called on the
        returned client (or the provider is stopped by other means). If the
        server never becomes ready, the container started here is stopped
        before the error is re-raised.

        Args:
            image: Docker image name to run (e.g., "echo-env:latest")
            provider: Container provider to use (defaults to LocalDockerProvider)
            **kwargs: Extra options forwarded unchanged to
                ``provider.start_container``; which options are accepted
                depends on the provider.

        Returns:
            An instance of the client class connected to the running container

        Example:
            >>> from envs.coding_env.client import CodingEnv
            >>> from envs.coding_env.models import CodeAction
            >>>
            >>> # Create environment from image
            >>> env = CodingEnv.from_docker_image("coding-env:latest")
            >>>
            >>> # Use the environment
            >>> result = env.reset()
            >>> print(result.observation)
            >>>
            >>> step_result = env.step(CodeAction(code="print('hello')"))
            >>> print(step_result.observation.stdout)
            >>>
            >>> # Cleanup (stops the container)
            >>> env.close()
        """

        # Use default provider if none provided
        if provider is None:
            provider = LocalDockerProvider()

        # 1. Start container
        base_url = provider.start_container(image, **kwargs)

        # 2. Wait for server to be ready; do not leave an orphaned container
        #    behind when the server fails to come up.
        try:
            provider.wait_for_ready(base_url)
        except BaseException:
            provider.stop_container()
            raise

        # 3. Create and return client instance with provider reference
        return cls(base_url=base_url, provider=provider)

    @abstractmethod
    def _step_payload(self, action: ActT) -> dict:
        """Convert an Action object to the JSON body expected by the env server."""
        raise NotImplementedError

    @abstractmethod
    def _parse_result(self, payload: dict) -> "StepResult[ObsT]":
        """Convert a JSON response from the env server to StepResult[ObsT]."""
        raise NotImplementedError

    @abstractmethod
    def _parse_state(self, payload: dict) -> Any:
        """Convert a JSON response from the state endpoint to a State object."""
        raise NotImplementedError

    # ---------- Environment Server Interface Methods ----------
    def reset(self) -> "StepResult[ObsT]":
        """POST to ``/reset`` and return the parsed result.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        body: Dict[str, Any] = {}
        # TODO: later:
        # body["seed"] = seed
        # body["episode_id"] = episode_id
        r = self._http.post(
            f"{self._base}/reset",
            json=body,
            headers=self._headers,
            timeout=self._timeout,
        )
        r.raise_for_status()
        return self._parse_result(r.json())

    def step(self, action: ActT) -> "StepResult[ObsT]":
        """POST the serialized action to ``/step`` and return the parsed result.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        body: Dict[str, Any] = {
            "action": self._step_payload(action),
            # The request timeout is also sent to the server, truncated to
            # whole seconds.
            "timeout_s": int(self._timeout),
        }
        # TODO: later:
        # body["request_id"] = str(uuid.uuid4())
        # body["episode_id"] = current_episode_id
        r = self._http.post(
            f"{self._base}/step",
            json=body,
            headers=self._headers,
            timeout=self._timeout,
        )
        r.raise_for_status()
        return self._parse_result(r.json())

    def state(self) -> Any:
        """
        Get the current environment state from the server.

        Returns:
            State object with environment state information (e.g., episode_id, step_count)

        Raises:
            requests.HTTPError: If the server answers with an error status.

        Example:
            >>> client = EchoEnv.from_docker_image("echo-env:latest")
            >>> result = client.reset()
            >>> state = client.state()
            >>> print(state.episode_id)
            >>> print(state.step_count)
        """
        r = self._http.get(
            f"{self._base}/state",
            headers=self._headers,
            timeout=self._timeout,
        )
        r.raise_for_status()
        return self._parse_state(r.json())

    def close(self) -> None:
        """
        Close the environment and clean up resources.

        If this client has a container provider (e.g. it was created via
        from_docker_image()), the provider is asked to stop its container.
        The underlying HTTP session is closed in every case, even when
        stopping the container raises.
        """
        try:
            if self._provider is not None:
                self._provider.stop_container()
        finally:
            self._http.close()
|
src/core/tools/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Core tools for code execution and other utilities."""
|
| 8 |
+
|
| 9 |
+
from .local_python_executor import PyExecutor
|
| 10 |
+
|
| 11 |
+
__all__ = ["PyExecutor"]
|
src/core/tools/local_python_executor.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Local Python Executor.
|
| 9 |
+
|
| 10 |
+
This module provides functionality for executing Python code locally by wrapping
|
| 11 |
+
the smolagents LocalPythonExecutor.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from smolagents import LocalPythonExecutor
|
| 15 |
+
|
| 16 |
+
from core.env_server.types import CodeExecResult
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class PyExecutor:
    """
    Thin adapter over the smolagents ``LocalPythonExecutor``.

    Code is handed to the wrapped executor and the outcome is reported as a
    ``CodeExecResult`` carrying ``stdout``, ``stderr`` and ``exit_code``.

    Args:
        additional_imports: Extra module names the executed code may import,
            e.g. ["numpy", "pandas", "matplotlib"]. They are passed to the
            wrapped executor as its additional authorized imports.

    Example:
        >>> # Basic usage with default imports
        >>> executor = PyExecutor()
        >>> result = executor.run("print('Hello, World!')")
        >>> print(result.stdout)  # "Hello, World!\n"
        >>> print(result.exit_code)  # 0
        >>>
        >>> # Usage with additional imports
        >>> executor = PyExecutor(additional_imports=["numpy", "pandas"])
        >>> result = executor.run("import numpy as np\\nprint(np.array([1, 2, 3]))")
        >>> print(result.stdout)  # "[1 2 3]\n"
    """

    def __init__(self, additional_imports: list[str] | None = None):
        """
        Build the wrapped ``LocalPythonExecutor``.

        Args:
            additional_imports: Extra module names to authorize for import.
                ``None`` is treated as an empty list.
        """
        authorized = [] if additional_imports is None else additional_imports
        self._executor = LocalPythonExecutor(additional_authorized_imports=authorized)
        # Sending an empty tool set initializes the executor's base Python
        # tools (including print), which are otherwise unavailable.
        self._executor.send_tools({})

    def run(self, code: str) -> CodeExecResult:
        """
        Run a snippet of Python code and report the outcome.

        Args:
            code: Python source to execute.

        Returns:
            On success, a CodeExecResult whose stdout is the executor's
            captured logs, with empty stderr and exit_code 0. If anything
            raises, a CodeExecResult with empty stdout, the exception text
            as stderr, and exit_code 1.

        Example:
            >>> executor = PyExecutor()
            >>> result = executor.run("x = 5 + 3\\nprint(x)")
            >>> print(result.stdout)  # "8\n"
            >>> print(result.exit_code)  # 0
            >>>
            >>> # Error handling
            >>> result = executor.run("1 / 0")
            >>> print(result.exit_code)  # 1
            >>> print(result.stderr)  # Contains error message
        """
        try:
            # The executor returns an object exposing output, logs and
            # is_final_answer; the logs hold whatever the code printed.
            outcome = self._executor(code)
            return CodeExecResult(stdout=outcome.logs, stderr="", exit_code=0)
        except Exception as exc:
            # Syntax errors, forbidden operations and runtime errors all
            # surface as exceptions; report them as a failed run.
            return CodeExecResult(stdout="", stderr=str(exc), exit_code=1)
|
src/core/types.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Type definitions for EnvTorch
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Any, Generic, Optional, TypeVar
|
| 4 |
+
|
| 5 |
+
# Type variable standing in for an environment's concrete observation type
# (lets IDEs and type checkers track it through StepResult).
ObsT = TypeVar("ObsT")


@dataclass
class StepResult(Generic[ObsT]):
    """
    Outcome of a single environment step (or reset).

    Attributes:
        observation: What the environment reports after the action.
        reward: Scalar reward for the step; None when no reward is given.
        done: True once the episode has finished.
    """

    observation: ObsT
    reward: Optional[float] = None
    done: bool = False
|
src/envs/sumo_rl_env/README.md
ADDED
|
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SUMO-RL Environment
|
| 2 |
+
|
| 3 |
+
Integration of traffic signal control with the OpenEnv framework via SUMO (Simulation of Urban MObility) and SUMO-RL.
|
| 4 |
+
|
| 5 |
+
## Overview
|
| 6 |
+
|
| 7 |
+
This environment enables reinforcement learning for **traffic signal control** using SUMO, a microscopic traffic simulation package. Train RL agents to optimize traffic light timing and minimize vehicle delays.
|
| 8 |
+
|
| 9 |
+
**Key Features**:
|
| 10 |
+
- **Realistic traffic simulation** via SUMO
|
| 11 |
+
- **Single-agent mode** for single intersection control
|
| 12 |
+
- **Configurable rewards** (waiting time, queue, pressure, speed)
|
| 13 |
+
- **Multiple networks** supported (custom .net.xml and .rou.xml files)
|
| 14 |
+
- **Docker-ready** with pre-bundled example network
|
| 15 |
+
|
| 16 |
+
## Quick Start
|
| 17 |
+
|
| 18 |
+
### Using Docker (Recommended)
|
| 19 |
+
|
| 20 |
+
```python
|
| 21 |
+
from envs.sumo_rl_env import SumoRLEnv, SumoAction
|
| 22 |
+
|
| 23 |
+
# Automatically starts container
|
| 24 |
+
env = SumoRLEnv.from_docker_image("sumo-rl-env:latest")
|
| 25 |
+
|
| 26 |
+
# Reset environment
|
| 27 |
+
result = env.reset()
|
| 28 |
+
print(f"Observation shape: {result.observation.observation_shape}")
|
| 29 |
+
print(f"Available actions: {result.observation.action_mask}")
|
| 30 |
+
|
| 31 |
+
# Take action (select next green phase)
|
| 32 |
+
result = env.step(SumoAction(phase_id=1))
|
| 33 |
+
print(f"Reward: {result.reward}, Done: {result.done}")
|
| 34 |
+
|
| 35 |
+
# Get state
|
| 36 |
+
state = env.state()
|
| 37 |
+
print(f"Simulation time: {state.sim_time}")
|
| 38 |
+
print(f"Total vehicles: {state.total_vehicles}")
|
| 39 |
+
print(f"Mean waiting time: {state.mean_waiting_time}")
|
| 40 |
+
|
| 41 |
+
# Cleanup
|
| 42 |
+
env.close()
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
### Building the Docker Image
|
| 46 |
+
|
| 47 |
+
```bash
|
| 48 |
+
cd OpenEnv
|
| 49 |
+
|
| 50 |
+
# Build base image first (if not already built)
|
| 51 |
+
docker build -t envtorch-base:latest -f src/core/containers/images/Dockerfile .
|
| 52 |
+
|
| 53 |
+
# Build SUMO-RL environment
|
| 54 |
+
docker build -f src/envs/sumo_rl_env/server/Dockerfile -t sumo-rl-env:latest .
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
### Running with Different Configurations
|
| 58 |
+
|
| 59 |
+
```bash
|
| 60 |
+
# Default: single-intersection
|
| 61 |
+
docker run -p 8000:8000 sumo-rl-env:latest
|
| 62 |
+
|
| 63 |
+
# Longer simulation
|
| 64 |
+
docker run -p 8000:8000 \
|
| 65 |
+
-e SUMO_NUM_SECONDS=50000 \
|
| 66 |
+
sumo-rl-env:latest
|
| 67 |
+
|
| 68 |
+
# Different reward function
|
| 69 |
+
docker run -p 8000:8000 \
|
| 70 |
+
-e SUMO_REWARD_FN=queue \
|
| 71 |
+
sumo-rl-env:latest
|
| 72 |
+
|
| 73 |
+
# Custom seed for reproducibility
|
| 74 |
+
docker run -p 8000:8000 \
|
| 75 |
+
-e SUMO_SEED=123 \
|
| 76 |
+
sumo-rl-env:latest
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
## Observation
|
| 80 |
+
|
| 81 |
+
The observation is a vector containing:
|
| 82 |
+
- **Phase one-hot**: Current active green phase (one-hot encoded)
|
| 83 |
+
- **Min green flag**: Binary indicator if minimum green time has passed
|
| 84 |
+
- **Lane densities**: Number of vehicles / lane capacity for each incoming lane
|
| 85 |
+
- **Lane queues**: Number of queued vehicles / lane capacity for each incoming lane
|
| 86 |
+
|
| 87 |
+
Observation size varies by network topology (depends on number of phases and lanes).
|
| 88 |
+
|
| 89 |
+
**Default (single-intersection)**:
|
| 90 |
+
- 4 green phases
|
| 91 |
+
- 8 incoming lanes
|
| 92 |
+
- Observation size: 21 elements (4 phase one-hot + 1 min green flag + 8 lane densities + 8 lane queues)
|
| 93 |
+
|
| 94 |
+
## Action Space
|
| 95 |
+
|
| 96 |
+
The action space is discrete and represents selecting the next green phase to activate.
|
| 97 |
+
|
| 98 |
+
- **Action type**: Discrete
|
| 99 |
+
- **Action range**: `[0, num_green_phases - 1]`
|
| 100 |
+
- **Default (single-intersection)**: 4 actions (one per green phase)
|
| 101 |
+
|
| 102 |
+
When a phase change is requested, SUMO automatically inserts a yellow phase before switching.
|
| 103 |
+
|
| 104 |
+
## Rewards
|
| 105 |
+
|
| 106 |
+
Default reward function is **change in cumulative waiting time**:
|
| 107 |
+
```
|
| 108 |
+
reward = -(total_waiting_time_now - total_waiting_time_previous)
|
| 109 |
+
```
|
| 110 |
+
|
| 111 |
+
Positive rewards indicate waiting time decreased (good).
|
| 112 |
+
|
| 113 |
+
### Available Reward Functions
|
| 114 |
+
|
| 115 |
+
Set via `SUMO_REWARD_FN` environment variable:
|
| 116 |
+
|
| 117 |
+
- **`diff-waiting-time`** (default): Change in cumulative waiting time
|
| 118 |
+
- **`average-speed`**: Average speed of all vehicles
|
| 119 |
+
- **`queue`**: Negative total queue length
|
| 120 |
+
- **`pressure`**: Pressure metric (incoming - outgoing vehicles)
|
| 121 |
+
|
| 122 |
+
## Configuration
|
| 123 |
+
|
| 124 |
+
### Environment Variables
|
| 125 |
+
|
| 126 |
+
| Variable | Default | Description |
|
| 127 |
+
|----------|---------|-------------|
|
| 128 |
+
| `SUMO_NET_FILE` | `/app/nets/single-intersection.net.xml` | Network topology file |
|
| 129 |
+
| `SUMO_ROUTE_FILE` | `/app/nets/single-intersection.rou.xml` | Vehicle routes file |
|
| 130 |
+
| `SUMO_NUM_SECONDS` | `20000` | Simulation duration (seconds) |
|
| 131 |
+
| `SUMO_DELTA_TIME` | `5` | Seconds between agent actions |
|
| 132 |
+
| `SUMO_YELLOW_TIME` | `2` | Yellow phase duration (seconds) |
|
| 133 |
+
| `SUMO_MIN_GREEN` | `5` | Minimum green time (seconds) |
|
| 134 |
+
| `SUMO_MAX_GREEN` | `50` | Maximum green time (seconds) |
|
| 135 |
+
| `SUMO_REWARD_FN` | `diff-waiting-time` | Reward function name |
|
| 136 |
+
| `SUMO_SEED` | `42` | Random seed (use for reproducibility) |
|
| 137 |
+
|
| 138 |
+
### Using Custom Networks
|
| 139 |
+
|
| 140 |
+
To use your own SUMO network:
|
| 141 |
+
|
| 142 |
+
```python
|
| 143 |
+
from envs.sumo_rl_env import SumoRLEnv
|
| 144 |
+
|
| 145 |
+
env = SumoRLEnv.from_docker_image(
|
| 146 |
+
"sumo-rl-env:latest",
|
| 147 |
+
volumes={
|
| 148 |
+
"/path/to/your/nets": {"bind": "/nets", "mode": "ro"}
|
| 149 |
+
},
|
| 150 |
+
environment={
|
| 151 |
+
"SUMO_NET_FILE": "/nets/my-network.net.xml",
|
| 152 |
+
"SUMO_ROUTE_FILE": "/nets/my-routes.rou.xml",
|
| 153 |
+
}
|
| 154 |
+
)
|
| 155 |
+
```
|
| 156 |
+
|
| 157 |
+
Your network directory should contain:
|
| 158 |
+
- `.net.xml` - Network topology (roads, junctions, traffic lights)
|
| 159 |
+
- `.rou.xml` - Vehicle routes (trip definitions, flow rates)
|
| 160 |
+
|
| 161 |
+
## API Reference
|
| 162 |
+
|
| 163 |
+
### SumoAction
|
| 164 |
+
|
| 165 |
+
```python
|
| 166 |
+
@dataclass
|
| 167 |
+
class SumoAction(Action):
|
| 168 |
+
phase_id: int # Green phase to activate (0 to num_phases-1)
|
| 169 |
+
ts_id: str = "0" # Traffic signal ID (for multi-agent)
|
| 170 |
+
```
|
| 171 |
+
|
| 172 |
+
### SumoObservation
|
| 173 |
+
|
| 174 |
+
```python
|
| 175 |
+
@dataclass
|
| 176 |
+
class SumoObservation(Observation):
|
| 177 |
+
observation: List[float] # Observation vector
|
| 178 |
+
observation_shape: List[int] # Shape for reshaping
|
| 179 |
+
action_mask: List[int] # Valid action indices
|
| 180 |
+
sim_time: float # Current simulation time
|
| 181 |
+
done: bool # Episode finished
|
| 182 |
+
reward: Optional[float] # Reward from last action
|
| 183 |
+
metadata: Dict # System metrics
|
| 184 |
+
```
|
| 185 |
+
|
| 186 |
+
### SumoState
|
| 187 |
+
|
| 188 |
+
```python
|
| 189 |
+
@dataclass
|
| 190 |
+
class SumoState(State):
|
| 191 |
+
episode_id: str # Unique episode ID
|
| 192 |
+
step_count: int # Steps taken
|
| 193 |
+
net_file: str # Network file path
|
| 194 |
+
route_file: str # Route file path
|
| 195 |
+
sim_time: float # Current simulation time
|
| 196 |
+
total_vehicles: int # Total vehicles in simulation
|
| 197 |
+
total_waiting_time: float # Cumulative waiting time
|
| 198 |
+
mean_waiting_time: float # Mean waiting time
|
| 199 |
+
mean_speed: float # Mean vehicle speed
|
| 200 |
+
# ... configuration parameters
|
| 201 |
+
```
|
| 202 |
+
|
| 203 |
+
## Example Training Loop
|
| 204 |
+
|
| 205 |
+
```python
|
| 206 |
+
from envs.sumo_rl_env import SumoRLEnv, SumoAction
|
| 207 |
+
import numpy as np
|
| 208 |
+
|
| 209 |
+
# Start environment
|
| 210 |
+
env = SumoRLEnv.from_docker_image("sumo-rl-env:latest")
|
| 211 |
+
|
| 212 |
+
# Training loop
|
| 213 |
+
for episode in range(10):
|
| 214 |
+
result = env.reset()
|
| 215 |
+
episode_reward = 0
|
| 216 |
+
steps = 0
|
| 217 |
+
|
| 218 |
+
while not result.done and steps < 1000:
|
| 219 |
+
# Random policy (replace with your RL agent)
|
| 220 |
+
action_id = np.random.choice(result.observation.action_mask)
|
| 221 |
+
|
| 222 |
+
# Take action
|
| 223 |
+
result = env.step(SumoAction(phase_id=action_id))
|
| 224 |
+
|
| 225 |
+
episode_reward += result.reward or 0
|
| 226 |
+
steps += 1
|
| 227 |
+
|
| 228 |
+
# Print progress every 100 steps
|
| 229 |
+
if steps % 100 == 0:
|
| 230 |
+
state = env.state()
|
| 231 |
+
print(f"Step {steps}: "
|
| 232 |
+
f"reward={result.reward:.2f}, "
|
| 233 |
+
f"vehicles={state.total_vehicles}, "
|
| 234 |
+
f"waiting={state.mean_waiting_time:.2f}")
|
| 235 |
+
|
| 236 |
+
print(f"Episode {episode}: total_reward={episode_reward:.2f}, steps={steps}")
|
| 237 |
+
|
| 238 |
+
env.close()
|
| 239 |
+
```
|
| 240 |
+
|
| 241 |
+
## Performance Notes
|
| 242 |
+
|
| 243 |
+
### Simulation Speed
|
| 244 |
+
|
| 245 |
+
- **Reset time**: 1-5 seconds (starts new SUMO simulation)
|
| 246 |
+
- **Step time**: ~50-200ms per step (depends on network size)
|
| 247 |
+
- **Episode duration**: Minutes (20,000 sim seconds with delta_time=5 → ~4,000 steps)
|
| 248 |
+
|
| 249 |
+
### Optimization
|
| 250 |
+
|
| 251 |
+
For faster simulation:
|
| 252 |
+
1. Reduce `SUMO_NUM_SECONDS` for shorter episodes
|
| 253 |
+
2. Increase `SUMO_DELTA_TIME` for fewer decisions
|
| 254 |
+
3. Use simpler networks with fewer vehicles
|
| 255 |
+
|
| 256 |
+
## Architecture
|
| 257 |
+
|
| 258 |
+
```
|
| 259 |
+
┌─────────────────────────────────┐
|
| 260 |
+
│ Client: SumoRLEnv │
|
| 261 |
+
│ .step(phase_id=1) │
|
| 262 |
+
└──────────────┬──────────────────┘
|
| 263 |
+
│ HTTP
|
| 264 |
+
┌──────────────▼──────────────────┐
|
| 265 |
+
│ FastAPI Server (Docker) │
|
| 266 |
+
│ SumoEnvironment │
|
| 267 |
+
│ ├─ Wraps sumo_rl │
|
| 268 |
+
│ ├─ Single-agent mode │
|
| 269 |
+
│ └─ No GUI │
|
| 270 |
+
└──────────────┬──────────────────┘
|
| 271 |
+
│
|
| 272 |
+
┌──────────────▼──────────────────┐
|
| 273 |
+
│ SUMO Simulator │
|
| 274 |
+
│ - Reads .net.xml (network) │
|
| 275 |
+
│ - Reads .rou.xml (routes) │
|
| 276 |
+
│ - Simulates traffic flow │
|
| 277 |
+
│ - Provides observations │
|
| 278 |
+
└─────────────────────────────────┘
|
| 279 |
+
```
|
| 280 |
+
|
| 281 |
+
## Bundled Network
|
| 282 |
+
|
| 283 |
+
The default `single-intersection` network is a simple 4-way intersection with:
|
| 284 |
+
- **4 incoming roads** (North, South, East, West)
|
| 285 |
+
- **4 green phases** (NS straight, NS left, EW straight, EW left)
|
| 286 |
+
- **Vehicle flow**: Continuous stream with varying rates
|
| 287 |
+
|
| 288 |
+
## Limitations
|
| 289 |
+
|
| 290 |
+
- **No GUI in Docker**: SUMO GUI requires X server (not available in containers)
|
| 291 |
+
- **Single-agent only**: Multi-agent support (multiple intersections) is planned for a future version
|
| 292 |
+
- **Fixed network per container**: Each container uses one network topology
|
| 293 |
+
- **Memory usage**: ~500MB for small networks, 2-4GB for large city networks
|
| 294 |
+
|
| 295 |
+
## Troubleshooting
|
| 296 |
+
|
| 297 |
+
### Container won't start
|
| 298 |
+
```bash
|
| 299 |
+
# Check logs
|
| 300 |
+
docker logs <container-id>
|
| 301 |
+
|
| 302 |
+
# Verify network files exist
|
| 303 |
+
docker run sumo-rl-env:latest ls -la /app/nets/
|
| 304 |
+
```
|
| 305 |
+
|
| 306 |
+
### "SUMO_HOME not set" error
|
| 307 |
+
This should be automatic in Docker. If running locally:
|
| 308 |
+
```bash
|
| 309 |
+
export SUMO_HOME=/usr/share/sumo
|
| 310 |
+
```
|
| 311 |
+
|
| 312 |
+
### Slow performance
|
| 313 |
+
- Reduce simulation duration: `SUMO_NUM_SECONDS=5000`
|
| 314 |
+
- Increase action interval: `SUMO_DELTA_TIME=10`
|
| 315 |
+
- Use smaller networks with fewer vehicles
|
| 316 |
+
|
| 317 |
+
## References
|
| 318 |
+
|
| 319 |
+
- [SUMO Documentation](https://sumo.dlr.de/docs/)
|
| 320 |
+
- [SUMO-RL GitHub](https://github.com/LucasAlegre/sumo-rl)
|
| 321 |
+
- [SUMO-RL Paper](https://peerj.com/articles/cs-575/)
|
| 322 |
+
- [RESCO Benchmarks](https://github.com/jault/RESCO)
|
| 323 |
+
|
| 324 |
+
## Citation
|
| 325 |
+
|
| 326 |
+
If you use SUMO-RL in your research, please cite:
|
| 327 |
+
|
| 328 |
+
```bibtex
|
| 329 |
+
@misc{sumorl,
|
| 330 |
+
author = {Lucas N. Alegre},
|
| 331 |
+
title = {{SUMO-RL}},
|
| 332 |
+
year = {2019},
|
| 333 |
+
publisher = {GitHub},
|
| 334 |
+
journal = {GitHub repository},
|
| 335 |
+
howpublished = {\url{https://github.com/LucasAlegre/sumo-rl}},
|
| 336 |
+
}
|
| 337 |
+
```
|
| 338 |
+
|
| 339 |
+
## License
|
| 340 |
+
|
| 341 |
+
This integration is licensed under a BSD-style license (see the LICENSE file in the root directory of the source tree). SUMO-RL and SUMO have their own licenses.
|
src/envs/sumo_rl_env/__init__.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
SUMO-RL Environment for OpenEnv.
|
| 9 |
+
|
| 10 |
+
This module provides OpenEnv integration for traffic signal control using
|
| 11 |
+
SUMO (Simulation of Urban MObility) via the SUMO-RL library.
|
| 12 |
+
|
| 13 |
+
Example:
|
| 14 |
+
>>> from envs.sumo_rl_env import SumoRLEnv, SumoAction
|
| 15 |
+
>>>
|
| 16 |
+
>>> # Connect to a running server or start via Docker
|
| 17 |
+
>>> env = SumoRLEnv.from_docker_image("sumo-rl-env:latest")
|
| 18 |
+
>>>
|
| 19 |
+
>>> # Reset and interact
|
| 20 |
+
>>> result = env.reset()
|
| 21 |
+
>>> result = env.step(SumoAction(phase_id=1))
|
| 22 |
+
>>> print(result.reward, result.done)
|
| 23 |
+
>>>
|
| 24 |
+
>>> # Cleanup
|
| 25 |
+
>>> env.close()
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
from .client import SumoRLEnv
|
| 29 |
+
from .models import SumoAction, SumoObservation, SumoState
|
| 30 |
+
|
| 31 |
+
__all__ = ["SumoRLEnv", "SumoAction", "SumoObservation", "SumoState"]
|
src/envs/sumo_rl_env/client.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
HTTP client for SUMO-RL environment.
|
| 9 |
+
|
| 10 |
+
This module provides a client to interact with the SUMO traffic signal
|
| 11 |
+
control environment over HTTP.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from typing import Any, Dict
|
| 15 |
+
|
| 16 |
+
from core.http_env_client import HTTPEnvClient
|
| 17 |
+
from core.types import StepResult
|
| 18 |
+
|
| 19 |
+
from .models import SumoAction, SumoObservation, SumoState
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class SumoRLEnv(HTTPEnvClient[SumoAction, SumoObservation]):
    """
    HTTP client for SUMO-RL traffic signal control environment.

    This client communicates with a SUMO environment server to control
    traffic signals using reinforcement learning.

    Example:
        >>> # Start container and connect
        >>> env = SumoRLEnv.from_docker_image("sumo-rl-env:latest")
        >>>
        >>> # Reset environment
        >>> result = env.reset()
        >>> print(f"Observation shape: {result.observation.observation_shape}")
        >>> print(f"Action space: {result.observation.action_mask}")
        >>>
        >>> # Take action
        >>> result = env.step(SumoAction(phase_id=1))
        >>> print(f"Reward: {result.reward}, Done: {result.done}")
        >>>
        >>> # Get state
        >>> state = env.state()
        >>> print(f"Sim time: {state.sim_time}, Total vehicles: {state.total_vehicles}")
        >>>
        >>> # Cleanup
        >>> env.close()

    Example with custom network:
        >>> # Use custom SUMO network via volume mount
        >>> env = SumoRLEnv.from_docker_image(
        ...     "sumo-rl-env:latest",
        ...     port=8000,
        ...     volumes={
        ...         "/path/to/my/nets": {"bind": "/nets", "mode": "ro"}
        ...     },
        ...     environment={
        ...         "SUMO_NET_FILE": "/nets/my-network.net.xml",
        ...         "SUMO_ROUTE_FILE": "/nets/my-routes.rou.xml",
        ...     }
        ... )

    Example with configuration:
        >>> # Adjust simulation parameters
        >>> env = SumoRLEnv.from_docker_image(
        ...     "sumo-rl-env:latest",
        ...     environment={
        ...         "SUMO_NUM_SECONDS": "10000",
        ...         "SUMO_DELTA_TIME": "10",
        ...         "SUMO_REWARD_FN": "queue",
        ...         "SUMO_SEED": "123",
        ...     }
        ... )
    """

    def _step_payload(self, action: SumoAction) -> Dict[str, Any]:
        """
        Convert SumoAction to JSON payload for HTTP request.

        Args:
            action: SumoAction containing phase_id to execute.

        Returns:
            Dictionary payload for step endpoint, with the keys
            ``"phase_id"`` and ``"ts_id"``.
        """
        return {
            "phase_id": action.phase_id,
            "ts_id": action.ts_id,
        }

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[SumoObservation]:
        """
        Parse step result from HTTP response JSON.

        ``done`` and ``reward`` on the observation are taken from the nested
        ``"observation"`` object when it carries them, and otherwise from the
        top level of the response, so that the observation and the returned
        StepResult do not disagree. Every other missing field falls back to
        an empty/zero default.

        Args:
            payload: JSON response from step endpoint.

        Returns:
            StepResult containing SumoObservation.
        """
        # The key may be absent or explicitly null; treat both as "no data"
        # instead of failing on None.get(...).
        obs_data = payload.get("observation") or {}

        # Top-level values feed the StepResult and act as the fallback for
        # the observation's own done/reward fields.
        reward = payload.get("reward")
        done = payload.get("done", False)

        observation = SumoObservation(
            observation=obs_data.get("observation", []),
            observation_shape=obs_data.get("observation_shape", []),
            action_mask=obs_data.get("action_mask", []),
            sim_time=obs_data.get("sim_time", 0.0),
            done=obs_data.get("done", done),
            reward=obs_data.get("reward", reward),
            metadata=obs_data.get("metadata", {}),
        )

        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict[str, Any]) -> SumoState:
        """
        Parse state from HTTP response JSON.

        Keys missing from the response are filled with the same defaults
        that SumoState declares for its fields.

        Args:
            payload: JSON response from state endpoint.

        Returns:
            SumoState object.
        """
        return SumoState(
            # Episode tracking
            episode_id=payload.get("episode_id", ""),
            step_count=payload.get("step_count", 0),
            # SUMO configuration
            net_file=payload.get("net_file", ""),
            route_file=payload.get("route_file", ""),
            num_seconds=payload.get("num_seconds", 20000),
            delta_time=payload.get("delta_time", 5),
            yellow_time=payload.get("yellow_time", 2),
            min_green=payload.get("min_green", 5),
            max_green=payload.get("max_green", 50),
            reward_fn=payload.get("reward_fn", "diff-waiting-time"),
            # Runtime metrics
            sim_time=payload.get("sim_time", 0.0),
            total_vehicles=payload.get("total_vehicles", 0),
            total_waiting_time=payload.get("total_waiting_time", 0.0),
            mean_waiting_time=payload.get("mean_waiting_time", 0.0),
            mean_speed=payload.get("mean_speed", 0.0),
        )
|
src/envs/sumo_rl_env/models.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Data models for SUMO-RL Environment.
|
| 9 |
+
|
| 10 |
+
This module defines the Action, Observation, and State types for traffic
|
| 11 |
+
signal control using SUMO (Simulation of Urban MObility).
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from dataclasses import dataclass, field
|
| 15 |
+
from typing import Dict, List, Optional
|
| 16 |
+
|
| 17 |
+
from core.env_server import Action, Observation, State
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
class SumoAction(Action):
    """
    Action for SUMO traffic signal control environment.

    Represents selecting which traffic light phase to activate next.

    Attributes:
        phase_id: Index of the green phase to activate (0 to num_phases-1)
        ts_id: Traffic signal ID (for multi-agent support, default "0")
    """

    # No default: must be supplied when the action is constructed.
    phase_id: int
    # Defaults to "0" when the caller does not name a traffic signal.
    ts_id: str = "0"
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
class SumoObservation(Observation):
    """
    Observation from SUMO traffic signal environment.

    Contains traffic metrics for decision-making.

    Attributes:
        observation: Flattened observation vector containing:
            - One-hot encoded current phase
            - Min green flag (binary)
            - Lane densities (normalized)
            - Lane queues (normalized)
        observation_shape: Shape of observation for reshaping
        action_mask: List of valid action indices
        sim_time: Current simulation time in seconds
        done: Whether episode is complete
        reward: Reward from last action (None on reset)
        metadata: Additional info (system metrics, etc.)
    """

    # Container fields use default_factory so every instance gets its own
    # fresh list/dict rather than sharing one mutable default.
    observation: List[float] = field(default_factory=list)
    observation_shape: List[int] = field(default_factory=list)
    action_mask: List[int] = field(default_factory=list)
    sim_time: float = 0.0
    done: bool = False
    # None until a reward is available.
    reward: Optional[float] = None
    metadata: Dict = field(default_factory=dict)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@dataclass
class SumoState(State):
    """
    State of SUMO traffic signal environment.

    Tracks both configuration and runtime state.

    Configuration attributes:
        net_file: Path to SUMO network file (.net.xml)
        route_file: Path to SUMO route file (.rou.xml)
        num_seconds: Total simulation duration in seconds
        delta_time: Seconds between agent actions
        yellow_time: Duration of yellow phase in seconds
        min_green: Minimum green time per phase in seconds
        max_green: Maximum green time per phase in seconds
        reward_fn: Name of reward function used

    Runtime attributes:
        episode_id: Unique episode identifier
        step_count: Number of steps taken in episode
        sim_time: Current simulation time in seconds
        total_vehicles: Total number of vehicles in simulation
        total_waiting_time: Cumulative waiting time across all vehicles
        mean_waiting_time: Presumably the average waiting time per vehicle
            (TODO confirm against the server implementation)
        mean_speed: Presumably the average vehicle speed
            (TODO confirm units against the server implementation)
    """

    # Episode tracking
    episode_id: str = ""
    step_count: int = 0

    # SUMO configuration
    net_file: str = ""
    route_file: str = ""
    num_seconds: int = 20000
    delta_time: int = 5
    yellow_time: int = 2
    min_green: int = 5
    max_green: int = 50
    reward_fn: str = "diff-waiting-time"

    # Runtime metrics
    sim_time: float = 0.0
    total_vehicles: int = 0
    total_waiting_time: float = 0.0
    mean_waiting_time: float = 0.0
    mean_speed: float = 0.0
|
src/envs/sumo_rl_env/nets/single-intersection/single-intersection.edg.xml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<edges>
|
| 2 |
+
<edge from="n" id="n_t" to="t" numLanes="2"/>
|
| 3 |
+
<edge from="w" id="w_t" to="t" numLanes="2"/>
|
| 4 |
+
<edge from="t" id="t_s" to="s" numLanes="2"/>
|
| 5 |
+
<edge from="t" id="t_e" to="e" numLanes="2"/>
|
| 6 |
+
</edges>
|
src/envs/sumo_rl_env/nets/single-intersection/single-intersection.net.xml
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
|
| 3 |
+
<!-- generated on seg 17 dez 2018 17:22:14 -02 by Netedit Version 0.32.0
|
| 4 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 5 |
+
|
| 6 |
+
<configuration xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:noNamespaceSchemaLocation="http://sumo.dlr.de/xsd/netconvertConfiguration.xsd">
|
| 7 |
+
|
| 8 |
+
<input>
|
| 9 |
+
<sumo-net-file value="nets/single-intersection/single-intersection.net.xml"/>
|
| 10 |
+
</input>
|
| 11 |
+
|
| 12 |
+
<output>
|
| 13 |
+
<output-file value="/home/lucas/Documents/sumo-rl/nets/single-intersection/single-intersection2.net.xml"/>
|
| 14 |
+
</output>
|
| 15 |
+
|
| 16 |
+
<processing>
|
| 17 |
+
<no-turnarounds value="true"/>
|
| 18 |
+
<offset.disable-normalization value="true"/>
|
| 19 |
+
<lefthand value="false"/>
|
| 20 |
+
<junctions.corner-detail value="0"/>
|
| 21 |
+
<rectangular-lane-cut value="false"/>
|
| 22 |
+
<walkingareas value="false"/>
|
| 23 |
+
</processing>
|
| 24 |
+
|
| 25 |
+
</configuration>
|
| 26 |
+
-->
|
| 27 |
+
|
| 28 |
+
<net version="0.27" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:noNamespaceSchemaLocation="http://sumo.dlr.de/xsd/net_file.xsd">
|
| 29 |
+
|
| 30 |
+
<location netOffset="150.00,150.00" convBoundary="0.00,0.00,300.00,300.00" origBoundary="-150.00,-150.00,150.00,150.00" projParameter="!"/>
|
| 31 |
+
|
| 32 |
+
<edge id=":t_0" function="internal">
|
| 33 |
+
<lane id=":t_0_0" index="0" speed="13.90" length="9.50" shape="145.05,151.45 145.05,141.95"/>
|
| 34 |
+
<lane id=":t_0_1" index="1" speed="13.90" length="9.50" shape="148.35,151.45 148.35,141.95"/>
|
| 35 |
+
</edge>
|
| 36 |
+
<edge id=":t_2" function="internal">
|
| 37 |
+
<lane id=":t_2_0" index="0" speed="13.90" length="9.50" shape="141.95,145.05 151.45,145.05"/>
|
| 38 |
+
<lane id=":t_2_1" index="1" speed="13.90" length="9.50" shape="141.95,148.35 151.45,148.35"/>
|
| 39 |
+
</edge>
|
| 40 |
+
|
| 41 |
+
<edge id="n_t" from="n" to="t" priority="-1">
|
| 42 |
+
<lane id="n_t_0" index="0" speed="13.90" length="148.55" shape="145.05,300.00 145.05,151.45"/>
|
| 43 |
+
<lane id="n_t_1" index="1" speed="13.90" length="148.55" shape="148.35,300.00 148.35,151.45"/>
|
| 44 |
+
</edge>
|
| 45 |
+
<edge id="t_e" from="t" to="e" priority="-1">
|
| 46 |
+
<lane id="t_e_0" index="0" speed="13.90" length="148.55" shape="151.45,145.05 300.00,145.05"/>
|
| 47 |
+
<lane id="t_e_1" index="1" speed="13.90" length="148.55" shape="151.45,148.35 300.00,148.35"/>
|
| 48 |
+
</edge>
|
| 49 |
+
<edge id="t_s" from="t" to="s" priority="-1">
|
| 50 |
+
<lane id="t_s_0" index="0" speed="13.90" length="141.95" shape="145.05,141.95 145.05,0.00"/>
|
| 51 |
+
<lane id="t_s_1" index="1" speed="13.90" length="141.95" shape="148.35,141.95 148.35,0.00"/>
|
| 52 |
+
</edge>
|
| 53 |
+
<edge id="w_t" from="w" to="t" priority="-1">
|
| 54 |
+
<lane id="w_t_0" index="0" speed="13.90" length="141.95" shape="0.00,145.05 141.95,145.05"/>
|
| 55 |
+
<lane id="w_t_1" index="1" speed="13.90" length="141.95" shape="0.00,148.35 141.95,148.35"/>
|
| 56 |
+
</edge>
|
| 57 |
+
|
| 58 |
+
<tlLogic id="t" type="static" programID="0" offset="0">
|
| 59 |
+
<phase duration="42" state="GGrr"/>
|
| 60 |
+
<phase duration="2" state="yyrr"/>
|
| 61 |
+
<phase duration="42" state="rrGG"/>
|
| 62 |
+
<phase duration="2" state="rryy"/>
|
| 63 |
+
</tlLogic>
|
| 64 |
+
|
| 65 |
+
<junction id="e" type="dead_end" x="300.00" y="150.00" incLanes="t_e_0 t_e_1" intLanes="" shape="300.00,143.45 300.00,149.95"/>
|
| 66 |
+
<junction id="n" type="dead_end" x="150.00" y="300.00" incLanes="" intLanes="" shape="149.95,300.00 143.45,300.00"/>
|
| 67 |
+
<junction id="s" type="dead_end" x="150.00" y="0.00" incLanes="t_s_0 t_s_1" intLanes="" shape="143.45,0.00 149.95,0.00"/>
|
| 68 |
+
<junction id="t" type="traffic_light" x="150.00" y="150.00" incLanes="n_t_0 n_t_1 w_t_0 w_t_1" intLanes=":t_0_0 :t_0_1 :t_2_0 :t_2_1" shape="143.45,151.45 149.95,151.45 151.45,149.95 151.45,143.45 149.95,141.95 143.45,141.95 141.95,143.45 141.95,149.95">
|
| 69 |
+
<request index="0" response="1100" foes="1100" cont="0"/>
|
| 70 |
+
<request index="1" response="1100" foes="1100" cont="0"/>
|
| 71 |
+
<request index="2" response="0000" foes="0011" cont="0"/>
|
| 72 |
+
<request index="3" response="0000" foes="0011" cont="0"/>
|
| 73 |
+
</junction>
|
| 74 |
+
<junction id="w" type="dead_end" x="0.00" y="150.00" incLanes="" intLanes="" shape="0.00,149.95 0.00,143.45"/>
|
| 75 |
+
|
| 76 |
+
<connection from="n_t" to="t_s" fromLane="0" toLane="0" via=":t_0_0" tl="t" linkIndex="0" dir="s" state="o"/>
|
| 77 |
+
<connection from="n_t" to="t_s" fromLane="1" toLane="1" via=":t_0_1" tl="t" linkIndex="1" dir="s" state="o"/>
|
| 78 |
+
<connection from="w_t" to="t_e" fromLane="0" toLane="0" via=":t_2_0" tl="t" linkIndex="2" dir="s" state="o"/>
|
| 79 |
+
<connection from="w_t" to="t_e" fromLane="1" toLane="1" via=":t_2_1" tl="t" linkIndex="3" dir="s" state="o"/>
|
| 80 |
+
|
| 81 |
+
<connection from=":t_0" to="t_s" fromLane="0" toLane="0" dir="s" state="M"/>
|
| 82 |
+
<connection from=":t_0" to="t_s" fromLane="1" toLane="1" dir="s" state="M"/>
|
| 83 |
+
<connection from=":t_2" to="t_e" fromLane="0" toLane="0" dir="s" state="M"/>
|
| 84 |
+
<connection from=":t_2" to="t_e" fromLane="1" toLane="1" dir="s" state="M"/>
|
| 85 |
+
|
| 86 |
+
</net>
|