Ev3Dev committed
Commit df98fca · verified · 1 Parent(s): fadba80

Upload folder using huggingface_hub
Dockerfile ADDED
@@ -0,0 +1,81 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # Multi-stage build using openenv-base
+ # This Dockerfile is flexible and works for both:
+ # - In-repo environments (with local OpenEnv sources)
+ # - Standalone environments (with openenv from PyPI/Git)
+ # The build script (openenv build) handles context detection and sets appropriate build args.
+
+ ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
+ FROM ${BASE_IMAGE} AS builder
+
+ WORKDIR /app
+
+ # Ensure git is available (required for installing dependencies from VCS)
+ RUN apt-get update && \
+     apt-get install -y --no-install-recommends git && \
+     rm -rf /var/lib/apt/lists/*
+
+ # Build argument to control whether we're building standalone or in-repo
+ ARG BUILD_MODE=in-repo
+ ARG ENV_NAME=hackathon
+
+ # Copy environment code (always at root of build context)
+ COPY . /app/env
+
+ # For in-repo builds, openenv is already vendored in the build context
+ # For standalone builds, openenv will be installed via pyproject.toml
+ WORKDIR /app/env
+
+ # Ensure uv is available (for local builds where the base image lacks it)
+ RUN if ! command -v uv >/dev/null 2>&1; then \
+         curl -LsSf https://astral.sh/uv/install.sh | sh && \
+         mv /root/.local/bin/uv /usr/local/bin/uv && \
+         mv /root/.local/bin/uvx /usr/local/bin/uvx; \
+     fi
+
+ # Install dependencies using uv sync
+ # If uv.lock exists, use it; otherwise resolve on the fly
+ RUN --mount=type=cache,target=/root/.cache/uv \
+     if [ -f uv.lock ]; then \
+         uv sync --frozen --no-install-project --no-editable; \
+     else \
+         uv sync --no-install-project --no-editable; \
+     fi
+
+ RUN --mount=type=cache,target=/root/.cache/uv \
+     if [ -f uv.lock ]; then \
+         uv sync --frozen --no-editable; \
+     else \
+         uv sync --no-editable; \
+     fi
+
+ # Final runtime stage
+ FROM ${BASE_IMAGE}
+
+ WORKDIR /app
+
+ # Copy the virtual environment from the builder
+ COPY --from=builder /app/env/.venv /app/.venv
+
+ # Copy the environment code
+ COPY --from=builder /app/env /app/env
+
+ # Set PATH to use the virtual environment
+ ENV PATH="/app/.venv/bin:$PATH"
+
+ # Set PYTHONPATH so imports work correctly
+ ENV PYTHONPATH="/app/env:$PYTHONPATH"
+
+ # Health check
+ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
+     CMD curl -f http://localhost:8000/health || exit 1
+
+ # Run the FastAPI server
+ # The module path is constructed to work with the /app/env structure
+ ENV ENABLE_WEB_INTERFACE=true
+ CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
README.md CHANGED
@@ -1,255 +1,255 @@
+ ---
+ title: Hackathon Environment Server
+ emoji: 🎢
+ colorFrom: purple
+ colorTo: gray
+ sdk: docker
+ pinned: false
+ app_port: 8000
+ base_path: /web
+ tags:
+ - openenv
+ ---
+
+ # Hackathon Environment
+
+ A simple test environment that echoes back messages. Perfect for testing the env APIs as well as demonstrating environment usage patterns.
+
+ ## Quick Start
+
+ The simplest way to use the Hackathon environment is through the `HackathonEnv` class:
+
+ ```python
+ from hackathon import HackathonAction, HackathonEnv
+
+ try:
+     # Create environment from Docker image
+     hackathonenv = HackathonEnv.from_docker_image("hackathon-env:latest")
+
+     # Reset
+     result = hackathonenv.reset()
+     print(f"Reset: {result.observation.echoed_message}")
+
+     # Send multiple messages
+     messages = ["Hello, World!", "Testing echo", "Final message"]
+
+     for msg in messages:
+         result = hackathonenv.step(HackathonAction(message=msg))
+         print(f"Sent: '{msg}'")
+         print(f"  → Echoed: '{result.observation.echoed_message}'")
+         print(f"  → Length: {result.observation.message_length}")
+         print(f"  → Reward: {result.reward}")
+
+ finally:
+     # Always clean up
+     hackathonenv.close()
+ ```
+
+ That's it! The `HackathonEnv.from_docker_image()` method handles:
+ - Starting the Docker container
+ - Waiting for the server to be ready
+ - Connecting to the environment
+ - Container cleanup when you call `close()`
+
+ ## Building the Docker Image
+
+ Before using the environment, you need to build the Docker image:
+
+ ```bash
+ # From project root
+ docker build -t hackathon-env:latest -f server/Dockerfile .
+ ```
+
+ ## Deploying to Hugging Face Spaces
+
+ You can easily deploy your OpenEnv environment to Hugging Face Spaces using the `openenv push` command:
+
+ ```bash
+ # From the environment directory (where openenv.yaml is located)
+ openenv push
+
+ # Or specify options
+ openenv push --namespace my-org --private
+ ```
+
+ The `openenv push` command will:
+ 1. Validate that the directory is an OpenEnv environment (checks for `openenv.yaml`)
+ 2. Prepare a custom build for a Hugging Face Docker Space (enables the web interface)
+ 3. Upload to Hugging Face (ensuring you're logged in)
+
+ ### Prerequisites
+
+ - Authenticate with Hugging Face: the command will prompt for login if not already authenticated
+
+ ### Options
+
+ - `--directory`, `-d`: Directory containing the OpenEnv environment (defaults to current directory)
+ - `--repo-id`, `-r`: Repository ID in the format 'username/repo-name' (defaults to 'username/env-name' from openenv.yaml)
+ - `--base-image`, `-b`: Base Docker image to use (overrides the Dockerfile FROM)
+ - `--private`: Deploy the space as private (default: public)
+
+ ### Examples
+
+ ```bash
+ # Push to your personal namespace (defaults to username/env-name from openenv.yaml)
+ openenv push
+
+ # Push to a specific repository
+ openenv push --repo-id my-org/my-env
+
+ # Push with a custom base image
+ openenv push --base-image ghcr.io/meta-pytorch/openenv-base:latest
+
+ # Push as a private space
+ openenv push --private
+
+ # Combine options
+ openenv push --repo-id my-org/my-env --base-image custom-base:latest --private
+ ```
+
+ After deployment, your space will be available at:
+ `https://huggingface.co/spaces/<repo-id>`
+
+ The deployed space includes:
+ - **Web Interface** at `/web` - Interactive UI for exploring the environment
+ - **API Documentation** at `/docs` - Full OpenAPI/Swagger interface
+ - **Health Check** at `/health` - Container health monitoring
+ - **WebSocket** at `/ws` - Persistent session endpoint for low-latency interactions
+
+ ## Environment Details
+
+ ### Action
+ **HackathonAction**: Contains a single field
+ - `message` (str) - The message to echo back
+
+ ### Observation
+ **HackathonObservation**: Contains the echo response and metadata
+ - `echoed_message` (str) - The message echoed back
+ - `message_length` (int) - Length of the message
+ - `reward` (float) - Reward based on message length (length × 0.1)
+ - `done` (bool) - Always False for the echo environment
+ - `metadata` (dict) - Additional info like step count
+
+ ### Reward
+ The reward is calculated as `message_length × 0.1`:
+ - "Hi" → reward: 0.2
+ - "Hello, World!" → reward: 1.3
+ - Empty message → reward: 0.0
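The reward rule is easy to sanity-check in isolation. A minimal sketch, where `reward_for` is a hypothetical helper written for illustration, not part of the environment's API:

```python
def reward_for(message: str) -> float:
    """Length-based echo reward: 0.1 per character."""
    return len(message) * 0.1

# Matches the examples above (compared with a tolerance for float rounding).
assert abs(reward_for("Hi") - 0.2) < 1e-9
assert abs(reward_for("Hello, World!") - 1.3) < 1e-9
assert reward_for("") == 0.0
```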
+
+ ## Advanced Usage
+
+ ### Connecting to an Existing Server
+
+ If you already have a Hackathon environment server running, you can connect directly:
+
+ ```python
+ from hackathon import HackathonAction, HackathonEnv
+
+ # Connect to existing server
+ hackathonenv = HackathonEnv(base_url="<ENV_HTTP_URL_HERE>")
+
+ # Use as normal
+ result = hackathonenv.reset()
+ result = hackathonenv.step(HackathonAction(message="Hello!"))
+ ```
+
+ Note: When connecting to an existing server, `hackathonenv.close()` will NOT stop the server.
+
+ ### Using the Context Manager
+
+ The client supports context manager usage for automatic connection management:
+
+ ```python
+ from hackathon import HackathonAction, HackathonEnv
+
+ # Connect with context manager (auto-connects and closes)
+ with HackathonEnv(base_url="http://localhost:8000") as env:
+     result = env.reset()
+     print(f"Reset: {result.observation.echoed_message}")
+     # Multiple steps with low latency
+     for msg in ["Hello", "World", "!"]:
+         result = env.step(HackathonAction(message=msg))
+         print(f"Echoed: {result.observation.echoed_message}")
+ ```
+
+ The client uses WebSocket connections for:
+ - **Lower latency**: No HTTP connection overhead per request
+ - **Persistent session**: The server maintains your environment state
+ - **Efficient episodes**: Better for many sequential steps
+
+ ### Concurrent WebSocket Sessions
+
+ The server supports multiple concurrent WebSocket connections. To enable this,
+ modify `server/app.py` to use factory mode:
+
+ ```python
+ # In server/app.py - use factory mode for concurrent sessions
+ app = create_app(
+     HackathonEnvironment,  # Pass the class, not an instance
+     HackathonAction,
+     HackathonObservation,
+     max_concurrent_envs=4,  # Allow 4 concurrent sessions
+ )
+ ```
+
+ Then multiple clients can connect simultaneously:
+
+ ```python
+ from concurrent.futures import ThreadPoolExecutor
+
+ from hackathon import HackathonAction, HackathonEnv
+
+ def run_episode(client_id: int):
+     with HackathonEnv(base_url="http://localhost:8000") as env:
+         result = env.reset()
+         for i in range(10):
+             result = env.step(HackathonAction(message=f"Client {client_id}, step {i}"))
+         return client_id, result.observation.message_length
+
+ # Run 4 episodes concurrently
+ with ThreadPoolExecutor(max_workers=4) as executor:
+     results = list(executor.map(run_episode, range(4)))
+ ```
+
+ ## Development & Testing
+
+ ### Direct Environment Testing
+
+ Test the environment logic directly without starting the HTTP server:
+
+ ```bash
+ # From the environment root directory
+ python3 server/hackathon_environment.py
+ ```
+
+ This verifies that:
+ - The environment resets correctly
+ - Step executes actions properly
+ - State tracking works
+ - Rewards are calculated correctly
+
+ ### Running Locally
+
+ Run the server locally for development:
+
+ ```bash
+ uvicorn server.app:app --reload
+ ```
+
+ ## Project Structure
+
+ ```
+ hackathon/
+ ├── .dockerignore                 # Docker build exclusions
+ ├── __init__.py                   # Module exports
+ ├── README.md                     # This file
+ ├── openenv.yaml                  # OpenEnv manifest
+ ├── pyproject.toml                # Project metadata and dependencies
+ ├── uv.lock                       # Locked dependencies (generated)
+ ├── client.py                     # HackathonEnv client
+ ├── models.py                     # Action and Observation models
+ └── server/
+     ├── __init__.py               # Server module exports
+     ├── hackathon_environment.py  # Core environment logic
+     ├── app.py                    # FastAPI application (HTTP + WebSocket endpoints)
+     └── Dockerfile                # Container image definition
+ ```
__init__.py CHANGED
@@ -1,16 +1,48 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the BSD-style license found in the
- # LICENSE file in the root directory of this source tree.
-
- """Hackathon Environment."""
-
- from .client import HackathonEnv
- from .models import HackathonAction, HackathonObservation
-
- __all__ = [
-     "HackathonAction",
-     "HackathonObservation",
-     "HackathonEnv",
- ]
+ try:  # pragma: no cover - package import path
+     from .client import BioExperimentEnv
+     from .models import (
+         ActionType,
+         ConclusionClaim,
+         ExpectedFinding,
+         ExperimentAction,
+         ExperimentObservation,
+         IntermediateOutput,
+         OutputType,
+         PaperReference,
+         PipelineStepRecord,
+         ResourceUsage,
+         SubagentType,
+         TaskSpec,
+     )
+ except ImportError:  # pragma: no cover - direct module import path
+     from client import BioExperimentEnv
+     from models import (
+         ActionType,
+         ConclusionClaim,
+         ExpectedFinding,
+         ExperimentAction,
+         ExperimentObservation,
+         IntermediateOutput,
+         OutputType,
+         PaperReference,
+         PipelineStepRecord,
+         ResourceUsage,
+         SubagentType,
+         TaskSpec,
+     )
+
+ __all__ = [
+     "ActionType",
+     "BioExperimentEnv",
+     "ConclusionClaim",
+     "ExpectedFinding",
+     "ExperimentAction",
+     "ExperimentObservation",
+     "IntermediateOutput",
+     "OutputType",
+     "PaperReference",
+     "PipelineStepRecord",
+     "ResourceUsage",
+     "SubagentType",
+     "TaskSpec",
+ ]
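The try/except import above lets the package work both when imported as a package and when its modules are run flat. A minimal, self-contained sketch of the same pattern; `mypkg.helpers` and `add` are hypothetical names, and the fallback here defines the symbol inline rather than re-importing it, so the sketch runs anywhere:

```python
try:
    from mypkg.helpers import add  # package import path (hypothetical module)
except ImportError:
    # Fallback when the package layout isn't importable.
    def add(a: int, b: int) -> int:
        return a + b

assert add(2, 3) == 5
```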
client.py CHANGED
@@ -1,99 +1,53 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the BSD-style license found in the
- # LICENSE file in the root directory of this source tree.
-
- """Hackathon Environment Client."""
-
- from typing import Dict
-
- from openenv.core.client_types import StepResult
- from openenv.core.env_server.types import State
- from openenv.core import EnvClient
-
- from .models import HackathonAction, HackathonObservation
-
-
- class HackathonEnv(
-     EnvClient[HackathonAction, HackathonObservation]
- ):
-     """
-     Client for the Hackathon Environment.
-
-     This client maintains a persistent WebSocket connection to the environment server,
-     enabling efficient multi-step interactions with lower latency.
-     Each client instance has its own dedicated environment session on the server.
-
-     Example:
-         >>> # Connect to a running server
-         >>> with HackathonEnv(base_url="http://localhost:8000") as client:
-         ...     result = client.reset()
-         ...     print(result.observation.echoed_message)
-         ...
-         ...     result = client.step(HackathonAction(message="Hello!"))
-         ...     print(result.observation.echoed_message)
-
-     Example with Docker:
-         >>> # Automatically start container and connect
-         >>> client = HackathonEnv.from_docker_image("hackathon-env:latest")
-         >>> try:
-         ...     result = client.reset()
-         ...     result = client.step(HackathonAction(message="Test"))
-         ... finally:
-         ...     client.close()
-     """
-
-     def _step_payload(self, action: HackathonAction) -> Dict:
-         """
-         Convert HackathonAction to JSON payload for step message.
-
-         Args:
-             action: HackathonAction instance
-
-         Returns:
-             Dictionary representation suitable for JSON encoding
-         """
-         return {
-             "message": action.message,
-         }
-
-     def _parse_result(self, payload: Dict) -> StepResult[HackathonObservation]:
-         """
-         Parse server response into StepResult[HackathonObservation].
-
-         Args:
-             payload: JSON response data from server
-
-         Returns:
-             StepResult with HackathonObservation
-         """
-         obs_data = payload.get("observation", {})
-         observation = HackathonObservation(
-             echoed_message=obs_data.get("echoed_message", ""),
-             message_length=obs_data.get("message_length", 0),
-             done=payload.get("done", False),
-             reward=payload.get("reward"),
-             metadata=obs_data.get("metadata", {}),
-         )
-
-         return StepResult(
-             observation=observation,
-             reward=payload.get("reward"),
-             done=payload.get("done", False),
-         )
-
-     def _parse_state(self, payload: Dict) -> State:
-         """
-         Parse server response into State object.
-
-         Args:
-             payload: JSON response from state request
-
-         Returns:
-             State object with episode_id and step_count
-         """
-         return State(
-             episode_id=payload.get("episode_id"),
-             step_count=payload.get("step_count", 0),
-         )
 
+ """Bio-Experiment Environment Client.
+
+ Provides the ``BioExperimentEnv`` class that communicates with the
+ environment server over WebSocket / HTTP using the OpenEnv protocol.
+ """
+
+ from typing import Dict
+
+ from openenv.core.client_types import StepResult
+ from openenv.core.env_server.types import State
+ from openenv.core import EnvClient
+
+ try:  # pragma: no cover - package import path
+     from .models import ExperimentAction, ExperimentObservation
+ except ImportError:  # pragma: no cover - direct module import path
+     from models import ExperimentAction, ExperimentObservation
+
+
+ class BioExperimentEnv(
+     EnvClient[ExperimentAction, ExperimentObservation, State]
+ ):
+     """Client for the Bio-Experiment Planning Environment.
+
+     Example:
+         >>> with BioExperimentEnv(base_url="http://localhost:8000") as env:
+         ...     result = env.reset()
+         ...     print(result.observation.task.problem_statement)
+         ...     result = env.step(ExperimentAction(
+         ...         action_type="collect_sample",
+         ...         parameters={"n_samples": 6},
+         ...     ))
+         ...     print(result.observation.latest_output.summary)
+     """
+
+     def _step_payload(self, action: ExperimentAction) -> Dict:
+         return action.model_dump()
+
+     def _parse_result(
+         self, payload: Dict
+     ) -> StepResult[ExperimentObservation]:
+         obs_data = payload.get("observation", {})
+         observation = ExperimentObservation(**obs_data)
+         return StepResult(
+             observation=observation,
+             reward=payload.get("reward"),
+             done=payload.get("done", False),
+         )
+
+     def _parse_state(self, payload: Dict) -> State:
+         return State(
+             episode_id=payload.get("episode_id"),
+             step_count=payload.get("step_count", 0),
+         )
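The `_parse_result` hook depends on openenv's types, but the payload handling itself can be sketched in isolation. `FakeObservation` and `StepResultSketch` below are illustrative dataclass stand-ins, not the real `ExperimentObservation` / `StepResult` classes:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class FakeObservation:  # stand-in for ExperimentObservation
    summary: str = ""
    metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class StepResultSketch:  # stand-in for openenv's StepResult
    observation: FakeObservation
    reward: Optional[float]
    done: bool


def parse_result(payload: Dict) -> StepResultSketch:
    """Mirror the client's logic: unpack 'observation', lift reward/done."""
    obs = FakeObservation(**payload.get("observation", {}))
    return StepResultSketch(
        observation=obs,
        reward=payload.get("reward"),
        done=payload.get("done", False),
    )


result = parse_result(
    {"observation": {"summary": "6 samples collected"}, "reward": 0.25, "done": False}
)
```

Missing keys fall back to defaults, so `parse_result({})` yields an empty observation with `reward=None` and `done=False`, matching the `.get(...)` fallbacks in the client above.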
models.py CHANGED
@@ -1,28 +1,268 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the BSD-style license found in the
- # LICENSE file in the root directory of this source tree.
-
- """
- Data models for the Hackathon Environment.
-
- The hackathon environment is a simple test environment that echoes back messages.
- """
-
- from pydantic import Field
-
- from openenv.core.env_server.types import Action, Observation
-
-
- class HackathonAction(Action):
-     """Action for the Hackathon environment - just a message to echo."""
-
-     message: str = Field(..., description="Message to echo back")
-
-
- class HackathonObservation(Observation):
-     """Observation from the Hackathon environment - the echoed message."""
-
-     echoed_message: str = Field(default="", description="The echoed message")
-     message_length: int = Field(default=0, description="Length of the echoed message")
+ """
+ Data models for the Bio-Experiment Planning RL Environment.
+
+ Defines the POMDP action and observation contracts for a scientific agent
+ that constructs biological experiment pipelines step-by-step.
+ """
+
+ from __future__ import annotations
+
+ from enum import Enum
+ from typing import Any, Dict, List, Optional
+
+ from pydantic import BaseModel, Field
+
+ from openenv.core.env_server.types import Action, Observation
+
+
+ # ── Action vocabulary ───────────────────────────────────────────────────────
+
+
+ class ActionType(str, Enum):
+     COLLECT_SAMPLE = "collect_sample"
+     SELECT_COHORT = "select_cohort"
+     PREPARE_LIBRARY = "prepare_library"
+     CULTURE_CELLS = "culture_cells"
+     PERTURB_GENE = "perturb_gene"
+     PERTURB_COMPOUND = "perturb_compound"
+     SEQUENCE_CELLS = "sequence_cells"
+     RUN_QC = "run_qc"
+     FILTER_DATA = "filter_data"
+     NORMALIZE_DATA = "normalize_data"
+     INTEGRATE_BATCHES = "integrate_batches"
+     CLUSTER_CELLS = "cluster_cells"
+     DIFFERENTIAL_EXPRESSION = "differential_expression"
+     TRAJECTORY_ANALYSIS = "trajectory_analysis"
+     PATHWAY_ENRICHMENT = "pathway_enrichment"
+     REGULATORY_NETWORK_INFERENCE = "regulatory_network_inference"
+     MARKER_SELECTION = "marker_selection"
+     VALIDATE_MARKER = "validate_marker"
+     DESIGN_FOLLOWUP = "design_followup_experiment"
+     REQUEST_SUBAGENT_REVIEW = "request_subagent_review"
+     SYNTHESIZE_CONCLUSION = "synthesize_conclusion"
+
+
+ WET_LAB_ACTIONS = frozenset({
+     ActionType.COLLECT_SAMPLE,
+     ActionType.SELECT_COHORT,
+     ActionType.PREPARE_LIBRARY,
+     ActionType.CULTURE_CELLS,
+     ActionType.PERTURB_GENE,
+     ActionType.PERTURB_COMPOUND,
+     ActionType.SEQUENCE_CELLS,
+     ActionType.VALIDATE_MARKER,
+ })
+
+ COMPUTATIONAL_ACTIONS = frozenset({
+     ActionType.RUN_QC,
+     ActionType.FILTER_DATA,
+     ActionType.NORMALIZE_DATA,
+     ActionType.INTEGRATE_BATCHES,
+     ActionType.CLUSTER_CELLS,
+     ActionType.DIFFERENTIAL_EXPRESSION,
+     ActionType.TRAJECTORY_ANALYSIS,
+     ActionType.PATHWAY_ENRICHMENT,
+     ActionType.REGULATORY_NETWORK_INFERENCE,
+     ActionType.MARKER_SELECTION,
+ })
+
+ META_ACTIONS = frozenset({
+     ActionType.DESIGN_FOLLOWUP,
+     ActionType.REQUEST_SUBAGENT_REVIEW,
+     ActionType.SYNTHESIZE_CONCLUSION,
+ })
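A typical use of these frozensets is routing an action type to the right cost model or validator. A minimal sketch with a reduced three-member stand-in for `ActionType` (the real enum has 21 members); the `category` helper is illustrative, not part of the module:

```python
from enum import Enum


# Reduced stand-in for models.ActionType (illustration only).
class ActionType(str, Enum):
    COLLECT_SAMPLE = "collect_sample"
    RUN_QC = "run_qc"
    SYNTHESIZE_CONCLUSION = "synthesize_conclusion"


WET_LAB_ACTIONS = frozenset({ActionType.COLLECT_SAMPLE})
COMPUTATIONAL_ACTIONS = frozenset({ActionType.RUN_QC})
META_ACTIONS = frozenset({ActionType.SYNTHESIZE_CONCLUSION})


def category(action_type: ActionType) -> str:
    """Route an action type to its cost/validation category."""
    if action_type in WET_LAB_ACTIONS:
        return "wet_lab"
    if action_type in COMPUTATIONAL_ACTIONS:
        return "computational"
    return "meta"


# The three sets partition the vocabulary: disjoint, and jointly exhaustive.
assert WET_LAB_ACTIONS | COMPUTATIONAL_ACTIONS | META_ACTIONS == frozenset(ActionType)
```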
+
+
+ class SubagentType(str, Enum):
+     WET_LAB_PLANNER = "wet_lab_planner"
+     COMPUTATIONAL_ANALYST = "computational_analyst"
+     OMICS_QC_AGENT = "omics_qc_agent"
+     CAUSAL_REASONING_AGENT = "causal_reasoning_agent"
+     BUDGET_SCHEDULER = "budget_scheduler"
+     BIOLOGICAL_RULE_CHECKER = "biological_rule_checker"
+     TOOL_EXECUTOR = "tool_executor"
+     RETROSPECTIVE_CRITIC = "retrospective_critic"
+     REPORT_SYNTHESIZER = "report_synthesizer"
+
+
+ # ── Action schema ───────────────────────────────────────────────────────────
+
+
+ class ExperimentAction(Action):
+     """Structured, compositional action for one experiment / analysis step.
+
+     Hybrid representation: discrete *action_type* plus typed arguments,
+     optional sub-agent / tool invocation, and calibration fields.
+     """
+
+     action_type: ActionType = Field(
+         ..., description="Discrete experiment or analysis step type"
+     )
+     input_targets: List[str] = Field(
+         default_factory=list,
+         description="References to prior outputs, samples, or artifacts",
+     )
+     method: Optional[str] = Field(
+         None, description="Specific method or tool (e.g. 'Seurat', 'CellRanger')"
+     )
+     parameters: Dict[str, Any] = Field(
+         default_factory=dict, description="Method-specific parameters"
+     )
+     expected_output_type: Optional[str] = Field(
+         None, description="What the agent expects this step to produce"
+     )
+     justification: Optional[str] = Field(
+         None, description="Scientific rationale for this step"
+     )
+     invoked_subagent: Optional[SubagentType] = Field(
+         None, description="Sub-agent to delegate to, if any"
+     )
+     tool_call_spec: Optional[Dict[str, Any]] = Field(
+         None, description="Structured tool invocation specification"
+     )
+     confidence: float = Field(
+         0.5, ge=0.0, le=1.0, description="Agent confidence in this step"
+     )
+
+
+ # ── Intermediate outputs ────────────────────────────────────────────────────
+
+
+ class OutputType(str, Enum):
+     QC_METRICS = "qc_metrics"
+     COUNT_MATRIX_SUMMARY = "count_matrix_summary"
+     EMBEDDING_SUMMARY = "embedding_summary"
+     CLUSTER_RESULT = "cluster_result"
+     DE_RESULT = "de_result"
+     PATHWAY_RESULT = "pathway_result"
+     TRAJECTORY_RESULT = "trajectory_result"
+     VALIDATION_RESULT = "validation_result"
+     NETWORK_RESULT = "network_result"
+     SAMPLE_COLLECTION_RESULT = "sample_collection_result"
+     LIBRARY_PREP_RESULT = "library_prep_result"
+     SEQUENCING_RESULT = "sequencing_result"
+     PERTURBATION_RESULT = "perturbation_result"
+     CULTURE_RESULT = "culture_result"
+     COHORT_RESULT = "cohort_result"
+     FOLLOWUP_DESIGN = "followup_design"
+ FOLLOWUP_DESIGN = "followup_design"
148
+ MARKER_RESULT = "marker_result"
149
+ FAILURE_REPORT = "failure_report"
150
+ SUBAGENT_REPORT = "subagent_report"
151
+ CONCLUSION = "conclusion"
152
+
153
+
154
+ class IntermediateOutput(BaseModel):
155
+ """A single simulated output from one pipeline step."""
156
+
157
+ output_type: OutputType
158
+ step_index: int
159
+ success: bool = True
160
+ quality_score: float = Field(1.0, ge=0.0, le=1.0)
161
+ summary: str = ""
162
+ data: Dict[str, Any] = Field(default_factory=dict)
163
+ uncertainty: float = Field(0.0, ge=0.0, le=1.0)
164
+ warnings: List[str] = Field(default_factory=list)
165
+ artifacts_available: List[str] = Field(default_factory=list)
166
+
167
+
168
+ # ── Observable state components ─────────────────────────────────────────────
169
+
170
+
171
+ class ResourceUsage(BaseModel):
172
+ budget_used: float = 0.0
173
+ budget_remaining: float = 100_000.0
174
+ time_used_days: float = 0.0
175
+ time_remaining_days: float = 180.0
176
+ samples_consumed: int = 0
177
+ compute_hours_used: float = 0.0
178
+
179
+
180
+ class PipelineStepRecord(BaseModel):
181
+ step_index: int
182
+ action_type: ActionType
183
+ method: Optional[str] = None
184
+ parameters: Dict[str, Any] = Field(default_factory=dict)
185
+ output_summary: str = ""
186
+ output_type: OutputType
187
+ success: bool = True
188
+ quality_score: float = 1.0
189
+ resource_cost: float = 0.0
190
+ time_cost_days: float = 0.0
191
+
192
+
193
+ class PaperReference(BaseModel):
194
+ """Metadata for a literature source used to ground a task."""
195
+
196
+ title: str
197
+ citation: Optional[str] = None
198
+ doi: Optional[str] = None
199
+ pmid: Optional[str] = None
200
+ url: Optional[str] = None
201
+
202
+
203
+ class ExpectedFinding(BaseModel):
204
+ """A paper-backed result that the agent should try to recover."""
205
+
206
+ finding: str
207
+ category: str = "claim"
208
+ keywords: List[str] = Field(default_factory=list)
209
+
210
+
211
+ class TaskSpec(BaseModel):
212
+ """Specification of the biological problem to solve."""
213
+
214
+ problem_statement: str = "Unspecified biological problem"
215
+ modality: str = "scRNA-seq"
216
+ organism: str = "human"
217
+ tissue: str = "blood"
218
+ conditions: List[str] = Field(default_factory=list)
219
+ available_assays: List[str] = Field(default_factory=lambda: [
220
+ "10x_chromium", "smart-seq2", "bulk_rna_seq",
221
+ "atac-seq", "cite-seq", "spatial_transcriptomics",
222
+ ])
223
+ available_tools: List[str] = Field(default_factory=lambda: [
224
+ "CellRanger", "Seurat", "Scanpy", "DESeq2", "GSEA",
225
+ "Monocle", "scVelo", "CellChat", "SCENIC",
226
+ ])
227
+ budget_limit: float = 100_000.0
228
+ time_limit_days: float = 180.0
229
+ prior_observations: List[str] = Field(default_factory=list)
230
+ success_criteria: List[str] = Field(default_factory=list)
231
+ dataset_metadata: Dict[str, Any] = Field(default_factory=dict)
232
+ paper_references: List[PaperReference] = Field(default_factory=list)
233
+ expected_findings: List[ExpectedFinding] = Field(default_factory=list)
234
+
235
+
236
+ class ConclusionClaim(BaseModel):
237
+ claim: str
238
+ evidence_steps: List[int] = Field(default_factory=list)
239
+ confidence: float = Field(0.5, ge=0.0, le=1.0)
240
+ claim_type: str = "correlational"
241
+ supporting_data: Dict[str, Any] = Field(default_factory=dict)
242
+
243
+
244
+ # ── Observation schema ──────────────────────────────────────────────────────
245
+
246
+
247
+ class ExperimentObservation(Observation):
248
+ """Full observable state returned to the agent at each timestep.
249
+
250
+ Deliberately excludes hidden latent biological truth, hidden failure
251
+ conditions, and ground-truth mechanisms.
252
+ """
253
+
254
+ task: TaskSpec = Field(default_factory=TaskSpec)
255
+ step_index: int = 0
256
+ pipeline_history: List[PipelineStepRecord] = Field(default_factory=list)
257
+ available_assays: List[str] = Field(default_factory=list)
258
+ available_tools: List[str] = Field(default_factory=list)
259
+ resource_usage: ResourceUsage = Field(default_factory=ResourceUsage)
260
+ latest_output: Optional[IntermediateOutput] = None
261
+ all_outputs: List[IntermediateOutput] = Field(default_factory=list)
262
+ discovered_markers: List[str] = Field(default_factory=list)
263
+ candidate_mechanisms: List[str] = Field(default_factory=list)
264
+ uncertainty_summary: Dict[str, float] = Field(default_factory=dict)
265
+ subagent_outputs: List[Dict[str, Any]] = Field(default_factory=list)
266
+ conclusions: List[ConclusionClaim] = Field(default_factory=list)
267
+ rule_violations: List[str] = Field(default_factory=list)
268
+ step_reward_breakdown: Dict[str, float] = Field(default_factory=dict)
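The `WET_LAB_ACTIONS` / `COMPUTATIONAL_ACTIONS` / `META_ACTIONS` frozensets above partition the action space, which lets rule and reward components branch on an action's category with O(1) membership tests. A minimal stdlib-only sketch of that pattern (the enum is trimmed to one member per bucket, and `action_category` is a hypothetical helper, not part of models.py):

```python
from enum import Enum

class ActionType(str, Enum):
    # trimmed to one representative member per bucket
    COLLECT_SAMPLE = "collect_sample"
    RUN_QC = "run_qc"
    SYNTHESIZE_CONCLUSION = "synthesize_conclusion"

WET_LAB_ACTIONS = frozenset({ActionType.COLLECT_SAMPLE})
COMPUTATIONAL_ACTIONS = frozenset({ActionType.RUN_QC})
META_ACTIONS = frozenset({ActionType.SYNTHESIZE_CONCLUSION})

def action_category(a: ActionType) -> str:
    # frozenset membership is an O(1) hash lookup
    if a in WET_LAB_ACTIONS:
        return "wet_lab"
    if a in COMPUTATIONAL_ACTIONS:
        return "computational"
    if a in META_ACTIONS:
        return "meta"
    return "unknown"

# str-valued enums round-trip from their raw string form
print(action_category(ActionType("run_qc")))  # computational
```

Because `ActionType` subclasses `str`, the raw wire value (e.g. from parsed JSON) converts directly via `ActionType("run_qc")`, which is what makes this lookup convenient on the server side.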
openenv.yaml CHANGED
@@ -1,7 +1,7 @@
-spec_version: 1
-name: hackathon
-type: space
-runtime: fastapi
-app: server.app:app
-port: 8000
-
+spec_version: 1
+name: hackathon
+type: space
+runtime: fastapi
+app: server.app:app
+port: 8000
+
pyproject.toml CHANGED
@@ -1,45 +1,63 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-[build-system]
-requires = ["setuptools>=45", "wheel"]
-build-backend = "setuptools.build_meta"
-
-[project]
-name = "openenv-hackathon"
-version = "0.1.0"
-description = "Hackathon environment for OpenEnv"
-requires-python = ">=3.10"
-dependencies = [
-    # Core OpenEnv runtime (provides FastAPI server + HTTP client types)
-    # install from github
-    # "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
-    "openenv-core[core]>=0.2.0",
-    # Environment-specific dependencies
-    # Add all dependencies needed for your environment here
-    # Examples:
-    # "numpy>=1.19.0",
-    # "torch>=2.0.0",
-    # "gymnasium>=0.29.0",
-    # "openspiel>=1.0.0",
-    # "smolagents>=1.22.0,<2",
-]
-
-[project.optional-dependencies]
-dev = [
-    "pytest>=8.0.0",
-    "pytest-cov>=4.0.0",
-]
-
-[project.scripts]
-# Server entry point - enables running via: uv run --project . server
-# or: python -m hackathon.server.app
-server = "hackathon.server.app:main"
-
-[tool.setuptools]
-include-package-data = true
-packages = ["hackathon", "hackathon.server"]
-package-dir = { "hackathon" = ".", "hackathon.server" = "server" }
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+[build-system]
+requires = ["setuptools>=45", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "openenv-bio-experiment"
+version = "0.1.0"
+description = "RL environment for biological experiment pipeline planning"
+requires-python = ">=3.10"
+dependencies = [
+    "openenv-core[core]>=0.2.0",
+    "numpy>=1.24.0",
+    "scipy>=1.10.0",
+    "pydantic>=2.0.0",
+]
+
+[project.optional-dependencies]
+train = [
+    "gymnasium>=0.29.0",
+]
+bio = [
+    "biopython>=1.84",
+    "gseapy>=1.1.3",
+    "scanpy>=1.10.0",
+]
+dev = [
+    "pytest>=8.0.0",
+    "pytest-cov>=4.0.0",
+    "gymnasium>=0.29.0",
+]
+
+[project.scripts]
+server = "hackathon.server.app:main"
+
+[tool.setuptools]
+include-package-data = true
+packages = [
+    "hackathon",
+    "hackathon.server",
+    "hackathon.server.simulator",
+    "hackathon.server.rules",
+    "hackathon.server.rewards",
+    "hackathon.server.tasks",
+    "hackathon.server.subagents",
+    "hackathon.training",
+    "hackathon.tests",
+]
+
+[tool.setuptools.package-dir]
+hackathon = "."
+"hackathon.server" = "server"
+"hackathon.server.simulator" = "server/simulator"
+"hackathon.server.rules" = "server/rules"
+"hackathon.server.rewards" = "server/rewards"
+"hackathon.server.tasks" = "server/tasks"
+"hackathon.server.subagents" = "server/subagents"
+"hackathon.training" = "training"
+"hackathon.tests" = "tests"
run_agent.py ADDED
@@ -0,0 +1,234 @@
+"""Run the bio-experiment environment with Qwen3.5-2B as the planning agent."""
+
+from __future__ import annotations
+
+import json
+import re
+import sys
+import time
+from typing import Any, Dict, List, Optional
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from models import ActionType, ExperimentAction, ExperimentObservation
+from server.hackathon_environment import BioExperimentEnvironment
+
+MODEL_ID = "Qwen/Qwen3.5-2B"
+MAX_EPISODE_STEPS = 12
+
+ACTION_TYPES = [a.value for a in ActionType]
+
+SYSTEM_PROMPT = """\
+You are an expert biologist planning a single-cell experiment pipeline.
+
+At each turn you see the experiment state and must pick the next step.
+
+Action types (in typical order):
+collect_sample, prepare_library, sequence_cells, run_qc, filter_data,
+normalize_data, cluster_cells, differential_expression,
+pathway_enrichment, marker_selection, validate_marker, synthesize_conclusion
+
+Other actions: select_cohort, culture_cells, perturb_gene, perturb_compound,
+integrate_batches, trajectory_analysis, regulatory_network_inference,
+design_followup_experiment, request_subagent_review
+
+Respond with ONLY valid JSON, nothing else:
+{"action_type": "...", "method": null, "parameters": {}, "justification": "...", "confidence": 0.8}
+"""
+
+
+def format_observation(obs: ExperimentObservation) -> str:
+    parts = [
+        f"TASK: {obs.task.problem_statement}",
+        f"Organism: {obs.task.organism} | Tissue: {obs.task.tissue}",
+        f"Conditions: {', '.join(obs.task.conditions) or 'N/A'}",
+        f"Step: {obs.step_index} | Budget: ${obs.resource_usage.budget_remaining:,.0f} | Time: {obs.resource_usage.time_remaining_days:.0f}d",
+    ]
+    if obs.pipeline_history:
+        last5 = obs.pipeline_history[-5:]
+        parts.append("History:")
+        for h in last5:
+            tag = "OK" if h.success else "FAIL"
+            parts.append(f"  [{tag}] {h.action_type.value}: {h.output_summary[:80]}")
+    if obs.rule_violations:
+        parts.append(f"VIOLATIONS: {obs.rule_violations}")
+    if obs.discovered_markers:
+        parts.append(f"Markers: {obs.discovered_markers[:5]}")
+    return "\n".join(parts)
+
+
+def parse_action(text: str) -> Optional[ExperimentAction]:
+    match = re.search(r"\{[^{}]*\}", text, re.DOTALL)
+    if not match:
+        return None
+    try:
+        d = json.loads(match.group())
+    except json.JSONDecodeError:
+        return None
+
+    action_type = d.get("action_type")
+    if action_type not in ACTION_TYPES:
+        return None
+
+    return ExperimentAction(
+        action_type=ActionType(action_type),
+        method=d.get("method"),
+        parameters=d.get("parameters") or {},
+        justification=d.get("justification"),
+        confidence=min(1.0, max(0.0, float(d.get("confidence", 0.5)))),
+    )
+
+
+FALLBACK_SEQUENCE = [
+    ActionType.COLLECT_SAMPLE,
+    ActionType.PREPARE_LIBRARY,
+    ActionType.SEQUENCE_CELLS,
+    ActionType.RUN_QC,
+    ActionType.FILTER_DATA,
+    ActionType.NORMALIZE_DATA,
+    ActionType.CLUSTER_CELLS,
+    ActionType.DIFFERENTIAL_EXPRESSION,
+    ActionType.PATHWAY_ENRICHMENT,
+    ActionType.MARKER_SELECTION,
+    ActionType.SYNTHESIZE_CONCLUSION,
+]
+
+
+def fallback_action(step: int) -> ExperimentAction:
+    idx = min(step, len(FALLBACK_SEQUENCE) - 1)
+    return ExperimentAction(
+        action_type=FALLBACK_SEQUENCE[idx],
+        justification="fallback",
+        confidence=0.3,
+    )
+
+
+def log(msg: str) -> None:
+    print(msg, flush=True)
+
+
+def main():
+    log(f"Loading tokenizer for {MODEL_ID} ...")
+    tokenizer = AutoTokenizer.from_pretrained(
+        MODEL_ID, trust_remote_code=True,
+    )
+    log("Tokenizer loaded. Loading model (this downloads ~4 GB on first run) ...")
+
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_ID,
+        torch_dtype=torch.bfloat16,
+        device_map="auto",
+        trust_remote_code=True,
+    )
+    log(f"Model loaded. Device: {model.device}")
+
+    eos_ids: List[int] = []
+    if tokenizer.eos_token_id is not None:
+        eos_ids.append(tokenizer.eos_token_id)
+    extra = tokenizer.convert_tokens_to_ids(["<|im_end|>", "<|endoftext|>"])
+    for tid in extra:
+        if isinstance(tid, int) and tid not in eos_ids:
+            eos_ids.append(tid)
+    log(f"EOS token ids: {eos_ids}")
+
+    env = BioExperimentEnvironment()
+    obs = env.reset()
+
+    log("\n" + "=" * 70)
+    log(f"TASK: {obs.task.problem_statement}")
+    log(f"Conditions: {obs.task.conditions}")
+    log(f"Budget: ${obs.task.budget_limit:,.0f} | Time: {obs.task.time_limit_days:.0f} days")
+    log("=" * 70)
+
+    cumulative_reward = 0.0
+
+    for step in range(MAX_EPISODE_STEPS):
+        user_msg = format_observation(obs)
+
+        messages = [
+            {"role": "system", "content": SYSTEM_PROMPT},
+            {"role": "user", "content": user_msg},
+        ]
+
+        try:
+            prompt = tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True,
+                enable_thinking=False,
+            )
+        except TypeError:
+            prompt = tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True,
+            )
+
+        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+        n_input = inputs["input_ids"].shape[1]
+
+        t0 = time.time()
+        with torch.no_grad():
+            output_ids = model.generate(
+                **inputs,
+                max_new_tokens=200,
+                do_sample=True,
+                temperature=0.7,
+                top_p=0.8,
+                top_k=20,
+                repetition_penalty=1.3,
+                eos_token_id=eos_ids if eos_ids else None,
+            )
+        gen_time = time.time() - t0
+
+        new_tokens = output_ids[0][n_input:]
+        response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
+
+        action = parse_action(response)
+        used_fallback = False
+        if action is None:
+            log(f"\n  [!] Parse failed, using fallback. Raw: {response[:150]}")
+            action = fallback_action(step)
+            used_fallback = True
+
+        tag = " [FALLBACK]" if used_fallback else ""
+        log(f"\nStep {step + 1}: {action.action_type.value}{tag} ({gen_time:.1f}s)")
+        if action.justification:
+            log(f"  Rationale: {action.justification}")
+
+        obs = env.step(action)
+
+        if obs.latest_output:
+            lo = obs.latest_output
+            status = "OK" if lo.success else "FAIL"
+            log(f"  [{status}] {lo.summary}")
+            if lo.warnings:
+                log(f"  Warnings: {lo.warnings}")
+
+        step_reward = obs.reward
+        cumulative_reward += step_reward
+        log(f"  Reward: {step_reward:+.3f} (cum: {cumulative_reward:+.3f})")
+        log(f"  Budget: ${obs.resource_usage.budget_remaining:,.0f} | Time: {obs.resource_usage.time_remaining_days:.0f}d")
+
+        if obs.rule_violations:
+            log(f"  Violations: {obs.rule_violations}")
+
+        if obs.done:
+            break
+
+    log(f"\n{'=' * 70}")
+    log("EPISODE COMPLETE" if obs.done else f"MAX STEPS ({MAX_EPISODE_STEPS})")
+    log(f"  Steps: {obs.step_index}")
+    log(f"  Total reward: {cumulative_reward:+.3f}")
+    log(f"  Budget used: ${obs.resource_usage.budget_used:,.0f}")
+    log(f"  Time used: {obs.resource_usage.time_used_days:.0f} days")
+    if obs.conclusions:
+        log("  Conclusions:")
+        for c in obs.conclusions:
+            log(f"    [{c.claim_type}, conf={c.confidence:.2f}] {c.claim}")
+    log("=" * 70)
+
+
+if __name__ == "__main__":
+    main()
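One caveat in `parse_action` above: the flat regex `r"\{[^{}]*\}"` cannot span nested braces, so a reply that actually follows the prompt's template and includes `"parameters": {}` makes the outer match fail and the search land on the inner empty object instead. A brace-matching scan handles nesting; this is a stdlib-only sketch of such an extractor (`extract_json_object` is a hypothetical helper, not part of run_agent.py):

```python
import json


def extract_json_object(text: str):
    """Return the first balanced {...} object in text as a dict, else None.

    Walks the string counting brace depth so nested objects such as
    "parameters": {...} are captured as part of the outer object.
    """
    start = text.find("{")
    while start != -1:
        depth = 0
        for i in range(start, len(text)):
            if text[i] == "{":
                depth += 1
            elif text[i] == "}":
                depth -= 1
                if depth == 0:
                    # balanced span found; try to parse it
                    try:
                        return json.loads(text[start:i + 1])
                    except json.JSONDecodeError:
                        break  # not valid JSON; try the next '{'
        start = text.find("{", start + 1)
    return None


reply = 'Plan: {"action_type": "run_qc", "parameters": {"min_genes": 200}, "confidence": 0.8}'
action = extract_json_object(reply)
print(action["action_type"])  # run_qc
```

Dropping this in as the first stage of `parse_action` (before the `action_type` whitelist check) would make the agent loop far less dependent on the fallback sequence whenever the model emits non-empty `parameters`.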
server/__init__.py CHANGED
@@ -1,11 +1,3 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""Hackathon environment server components."""
-
-from .hackathon_environment import HackathonEnvironment
-
-__all__ = ["HackathonEnvironment"]
+from .hackathon_environment import BioExperimentEnvironment
+
+__all__ = ["BioExperimentEnvironment"]
server/app.py CHANGED
@@ -1,81 +1,41 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""
-FastAPI application for the Hackathon Environment.
-
-This module creates an HTTP server that exposes the HackathonEnvironment
-over HTTP and WebSocket endpoints, compatible with EnvClient.
-
-Endpoints:
-    - POST /reset: Reset the environment
-    - POST /step: Execute an action
-    - GET /state: Get current environment state
-    - GET /schema: Get action/observation schemas
-    - WS /ws: WebSocket endpoint for persistent sessions
-
-Usage:
-    # Development (with auto-reload):
-    uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
-
-    # Production:
-    uvicorn server.app:app --host 0.0.0.0 --port 8000 --workers 4
-
-    # Or run directly:
-    python -m server.app
-"""
-
-try:
-    from openenv.core.env_server.http_server import create_app
-except Exception as e:  # pragma: no cover
-    raise ImportError(
-        "openenv is required for the web interface. Install dependencies with '\n  uv sync\n'"
-    ) from e
-
-# Import from local models.py (PYTHONPATH includes /app/env in Docker)
-from models import HackathonAction, HackathonObservation
-from .hackathon_environment import HackathonEnvironment
-
-
-# Create the app with web interface and README integration
-app = create_app(
-    HackathonEnvironment,
-    HackathonAction,
-    HackathonObservation,
-    env_name="hackathon",
-    max_concurrent_envs=1,  # increase this number to allow more concurrent WebSocket sessions
-)
-
-
-def main(host: str = "0.0.0.0", port: int = 8000):
-    """
-    Entry point for direct execution via uv run or python -m.
-
-    This function enables running the server without Docker:
-        uv run --project . server
-        uv run --project . server --port 8001
-        python -m hackathon.server.app
-
-    Args:
-        host: Host address to bind to (default: "0.0.0.0")
-        port: Port number to listen on (default: 8000)
-
-    For production deployments, consider using uvicorn directly with
-    multiple workers:
-        uvicorn hackathon.server.app:app --workers 4
-    """
-    import uvicorn
-
-    uvicorn.run(app, host=host, port=port)
-
-
-if __name__ == "__main__":
-    import argparse
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--port", type=int, default=8000)
-    args = parser.parse_args()
-    main(port=args.port)
+"""FastAPI application for the Bio-Experiment Planning Environment.
+
+Endpoints:
+    - POST /reset: Reset the environment
+    - POST /step: Execute an action
+    - GET /state: Get current environment state
+    - GET /schema: Get action/observation schemas
+    - WS /ws: WebSocket endpoint for persistent sessions
+"""
+
+try:
+    from openenv.core.env_server.http_server import create_app
+except Exception as e:  # pragma: no cover
+    raise ImportError(
+        "openenv is required for the web interface. "
+        "Install dependencies with 'uv sync'"
+    ) from e
+
+from models import ExperimentAction, ExperimentObservation
+from .hackathon_environment import BioExperimentEnvironment
+
+app = create_app(
+    BioExperimentEnvironment,
+    ExperimentAction,
+    ExperimentObservation,
+    env_name="bio_experiment",
+    max_concurrent_envs=1,
+)
+
+
+def main(host: str = "0.0.0.0", port: int = 8000):
+    import uvicorn
+    uvicorn.run(app, host=host, port=port)
+
+
+if __name__ == "__main__":
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--port", type=int, default=8000)
+    args = parser.parse_args()
+    main(port=args.port)
server/hackathon_environment.py CHANGED
@@ -1,101 +1,239 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the BSD-style license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- """
8
- Hackathon Environment Implementation.
9
-
10
- A simple test environment that echoes back messages sent to it.
11
- Perfect for testing HTTP server infrastructure.
12
- """
13
-
14
- from uuid import uuid4
15
-
16
- from openenv.core.env_server.interfaces import Environment
17
- from openenv.core.env_server.types import State
18
-
19
- from models import HackathonAction, HackathonObservation
20
-
21
-
22
- class HackathonEnvironment(Environment):
23
- """
24
- A simple echo environment that echoes back messages.
25
-
26
- This environment is designed for testing the HTTP server infrastructure.
27
- It maintains minimal state and simply echoes back whatever message it receives.
28
-
29
- Example:
30
- >>> env = HackathonEnvironment()
31
- >>> obs = env.reset()
32
- >>> print(obs.echoed_message) # "Hackathon environment ready!"
33
- >>>
34
- >>> obs = env.step(HackathonAction(message="Hello"))
35
- >>> print(obs.echoed_message) # "Hello"
36
- >>> print(obs.message_length) # 5
37
- """
38
-
39
- # Enable concurrent WebSocket sessions.
40
- # Set to True if your environment isolates state between instances.
41
- # When True, multiple WebSocket clients can connect simultaneously, each
42
- # getting their own environment instance (when using factory mode in app.py).
43
- SUPPORTS_CONCURRENT_SESSIONS: bool = True
44
-
45
- def __init__(self):
46
- """Initialize the hackathon environment."""
47
- self._state = State(episode_id=str(uuid4()), step_count=0)
48
- self._reset_count = 0
49
-
50
- def reset(self) -> HackathonObservation:
51
- """
52
- Reset the environment.
53
-
54
- Returns:
55
- HackathonObservation with a ready message
56
- """
57
- self._state = State(episode_id=str(uuid4()), step_count=0)
58
- self._reset_count += 1
59
-
60
- return HackathonObservation(
61
- echoed_message="Hackathon environment ready!",
62
- message_length=0,
63
- done=False,
64
- reward=0.0,
65
- )
66
-
67
- def step(self, action: HackathonAction) -> HackathonObservation: # type: ignore[override]
68
- """
69
- Execute a step in the environment by echoing the message.
70
-
71
- Args:
72
- action: HackathonAction containing the message to echo
73
-
74
- Returns:
75
- HackathonObservation with the echoed message and its length
76
- """
77
- self._state.step_count += 1
78
-
79
- message = action.message
80
- length = len(message)
81
-
82
- # Simple reward: longer messages get higher rewards
83
- reward = length * 0.1
84
-
85
- return HackathonObservation(
86
- echoed_message=message,
87
- message_length=length,
88
- done=False,
89
- reward=reward,
90
- metadata={"original_message": message, "step": self._state.step_count},
91
- )
92
-
93
- @property
94
- def state(self) -> State:
95
- """
96
- Get the current environment state.
97
-
98
- Returns:
99
- Current State with episode_id and step_count
100
- """
101
- return self._state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Bio-Experiment Planning Environment.
2
+
3
+ Implements the OpenEnv ``Environment`` interface as a POMDP where the
4
+ agent proposes one structured experiment / analysis step at a time and
5
+ receives simulated intermediate outputs from a latent biological world.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import Any, Dict, List, Optional
11
+ from uuid import uuid4
12
+
13
+ from openenv.core.env_server.interfaces import Environment
14
+ from openenv.core.env_server.types import State
15
+
16
+ from models import (
17
+ ActionType,
18
+ ConclusionClaim,
19
+ ExperimentAction,
20
+ ExperimentObservation,
21
+ IntermediateOutput,
22
+ PipelineStepRecord,
23
+ ResourceUsage,
24
+ TaskSpec,
25
+ )
26
+
27
+ from server.rules.engine import RuleEngine
28
+ from server.rewards.reward import RewardBreakdown, RewardComputer
29
+ from server.simulator.latent_state import FullLatentState
30
+ from server.simulator.noise import NoiseModel
31
+ from server.simulator.transition import ACTION_COSTS, TransitionEngine
32
+ from server.tasks.generator import TaskGenerator
33
+
34
+
35
+ MAX_STEPS = 30
36
+
37
+
38
+ class BioExperimentEnvironment(Environment):
39
+ """POMDP environment for iterative biological experiment planning.
40
+
41
+ The agent observes ``ExperimentObservation`` (partial view) while the
42
+ environment maintains a ``FullLatentState`` (hidden ground truth).
43
+ """
44
+
45
+ SUPPORTS_CONCURRENT_SESSIONS: bool = True
46
+
47
+ def __init__(
48
+ self,
49
+ scenario_name: Optional[str] = None,
50
+ *,
51
+ domain_randomise: bool = True,
52
+ ) -> None:
53
+ self._state = State(episode_id=str(uuid4()), step_count=0)
54
+ self._latent: Optional[FullLatentState] = None
55
+ self._task: Optional[TaskSpec] = None
56
+ self._scenario_name = scenario_name
57
+ self._noise = NoiseModel()
58
+ self._engine = TransitionEngine(self._noise)
59
+ self._rules = RuleEngine()
60
+ self._rewards = RewardComputer()
61
+ self._task_gen = TaskGenerator(domain_randomise=domain_randomise)
62
+
63
+ self._history: List[PipelineStepRecord] = []
64
+ self._outputs: List[IntermediateOutput] = []
65
+ self._conclusions: List[ConclusionClaim] = []
66
+ self._subagent_outputs: List[Dict[str, Any]] = []
67
+ self._discovered_markers: List[str] = []
68
+ self._candidate_mechanisms: List[str] = []
69
+ self._cumulative_reward: float = 0.0
70
+
71
+ # ── Environment interface ───────────────────────────────────────────
72
+
73
+ def reset(self) -> ExperimentObservation:
74
+ seed = hash(uuid4()) % (2**31)
75
+ self._noise.reseed(seed)
76
+ self._state = State(episode_id=str(uuid4()), step_count=0)
77
+
78
+ self._task, self._latent = self._task_gen.generate(
79
+ seed=seed,
80
+ scenario_name=self._scenario_name,
81
+ )
82
+ self._latent.rng_seed = seed
83
+
84
+ self._history.clear()
85
+ self._outputs.clear()
86
+ self._conclusions.clear()
87
+ self._subagent_outputs.clear()
88
+ self._discovered_markers.clear()
89
+ self._candidate_mechanisms.clear()
90
+ self._cumulative_reward = 0.0
91
+
92
+ return self._build_observation(reward=0.0, done=False)
93
+
94
+ def step( # type: ignore[override]
95
+ self, action: ExperimentAction
96
+ ) -> ExperimentObservation:
97
+ assert self._latent is not None, "Call reset() before step()"
98
+ assert self._task is not None
99
+
100
+ self._state.step_count += 1
101
+ prev_state = self._latent.model_copy(deep=True)
102
+
103
+ violations = self._rules.check(action, self._latent)
104
+ hard_v = self._rules.hard_violations(violations)
+ soft_v = self._rules.soft_violations(violations)
+
+ result = self._engine.step(
+ self._latent,
+ action,
+ hard_violations=hard_v,
+ soft_violations=soft_v,
+ )
+ self._latent = result.next_state
+
+ step_rb = self._rewards.step_reward(
+ action, prev_state, self._latent, result.output, hard_v, soft_v,
+ )
+
+ cost_budget, cost_time = ACTION_COSTS.get(action.action_type, (0, 0))
+ self._history.append(PipelineStepRecord(
+ step_index=self._state.step_count,
+ action_type=action.action_type,
+ method=action.method,
+ parameters=action.parameters,
+ output_summary=result.output.summary,
+ output_type=result.output.output_type,
+ success=result.output.success,
+ quality_score=result.output.quality_score,
+ resource_cost=cost_budget,
+ time_cost_days=cost_time,
+ ))
+ self._outputs.append(result.output)
+ self._update_discoveries(action, result.output)
+
+ if action.action_type == ActionType.SYNTHESIZE_CONCLUSION:
+ raw_claims = action.parameters.get("claims", [])
+ for c in raw_claims:
+ if isinstance(c, dict):
+ self._conclusions.append(ConclusionClaim(**c))
+
+ done = result.done or self._state.step_count >= MAX_STEPS
+
+ terminal_rb = RewardBreakdown()
+ if done:
+ terminal_rb = self._rewards.terminal_reward(
+ self._latent, self._conclusions, self._task.success_criteria,
+ )
+
+ total_reward = step_rb.total + terminal_rb.total
+ self._cumulative_reward += total_reward
+
+ breakdown = step_rb.to_dict()
+ breakdown.update({f"term_{k}": v for k, v in terminal_rb.to_dict().items()})
+
+ return self._build_observation(
+ reward=total_reward,
+ done=done,
+ latest_output=result.output,
+ rule_violations=hard_v + soft_v,
+ reward_breakdown=breakdown,
+ )
+
+ @property
+ def state(self) -> State:
+ return self._state
+
+ def set_scenario(self, scenario_name: Optional[str]) -> None:
+ """Set the scenario used on the next reset."""
+
+ self._scenario_name = scenario_name
+
+ # ── internal helpers ────────────────────────────────────────────────
+
+ def _build_observation(
+ self,
+ *,
+ reward: float,
+ done: bool,
+ latest_output: Optional[IntermediateOutput] = None,
+ rule_violations: Optional[List[str]] = None,
+ reward_breakdown: Optional[Dict[str, float]] = None,
+ ) -> ExperimentObservation:
+ assert self._task is not None
+ assert self._latent is not None
+ res = self._latent.resources
+ return ExperimentObservation(
+ task=self._task,
+ step_index=self._state.step_count,
+ pipeline_history=list(self._history),
+ available_assays=list(self._task.available_assays),
+ available_tools=list(self._task.available_tools),
+ resource_usage=ResourceUsage(
+ budget_used=res.budget_used,
+ budget_remaining=res.budget_remaining,
+ time_used_days=res.time_used_days,
+ time_remaining_days=res.time_remaining_days,
+ samples_consumed=res.samples_consumed,
+ compute_hours_used=res.compute_hours_used,
+ ),
+ latest_output=latest_output,
+ all_outputs=list(self._outputs),
+ discovered_markers=list(self._discovered_markers),
+ candidate_mechanisms=list(self._candidate_mechanisms),
+ uncertainty_summary=self._compute_uncertainty_summary(),
+ subagent_outputs=list(self._subagent_outputs),
+ conclusions=list(self._conclusions),
+ rule_violations=rule_violations or [],
+ step_reward_breakdown=reward_breakdown or {},
+ done=done,
+ reward=reward,
+ metadata={
+ "episode_id": self._state.episode_id,
+ "step": self._state.step_count,
+ "cumulative_reward": self._cumulative_reward,
+ },
+ )
+
+ def _compute_uncertainty_summary(self) -> Dict[str, float]:
+ if not self._outputs:
+ return {}
+ recent = self._outputs[-5:]
+ avg_unc = sum(o.uncertainty for o in recent) / len(recent)
+ avg_qual = sum(o.quality_score for o in recent) / len(recent)
+ return {"avg_uncertainty": avg_unc, "avg_quality": avg_qual}
+
+ def _update_discoveries(
+ self, action: ExperimentAction, output: IntermediateOutput
+ ) -> None:
+ if action.action_type == ActionType.MARKER_SELECTION:
+ markers = output.data.get("markers", [])
+ self._discovered_markers.extend(markers)
+ if action.action_type == ActionType.REGULATORY_NETWORK_INFERENCE:
+ regs = output.data.get("top_regulators", [])
+ self._candidate_mechanisms.extend(regs)
+ if action.action_type == ActionType.PATHWAY_ENRICHMENT:
+ pathways = output.data.get("top_pathways", [])
+ self._candidate_mechanisms.extend(
+ [p["pathway"] for p in pathways if isinstance(p, dict)]
+ )
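The `step()` logic above merges the per-step and the terminal reward breakdowns into one dictionary by prefixing terminal keys with `term_`, so the two never collide. A minimal standalone sketch of that merge, with made-up illustrative values (not the project's classes):

```python
# Step-level and terminal breakdowns both expose a "total" key;
# the "term_" prefix keeps them distinguishable after merging.
step_breakdown = {"validity": 0.3, "info_gain": 0.2, "total": 0.5}       # hypothetical values
terminal_breakdown = {"completeness": 1.0, "calibration": 0.8, "total": 4.0}

breakdown = dict(step_breakdown)
breakdown.update({f"term_{k}": v for k, v in terminal_breakdown.items()})

# Both totals survive side by side.
assert breakdown["total"] == 0.5
assert breakdown["term_total"] == 4.0
```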
server/requirements.txt CHANGED
@@ -1,6 +1,6 @@
- openenv[core]>=0.2.0
- fastapi>=0.115.0
- uvicorn>=0.24.0
-
-
-
+ openenv[core]>=0.2.0
+ fastapi>=0.115.0
+ uvicorn>=0.24.0
+
+
+
server/rewards/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .reward import RewardBreakdown, RewardComputer
+
+ __all__ = ["RewardBreakdown", "RewardComputer"]
server/rewards/reward.py ADDED
@@ -0,0 +1,285 @@
+ """Decomposable reward function for the bio-experiment planning POMDP.
+
+ Reward components
+ ─────────────────
+ r_validity — biological validity of the chosen action
+ r_ordering — correct ordering of experiment steps
+ r_info_gain — information gain from the step's output
+ r_efficiency — resource efficiency (budget & time normalised)
+ r_novelty — bonus for non-redundant, non-trivial actions
+ r_penalty — penalties for violations, redundancy, waste
+ r_terminal — terminal quality & calibration against hidden truth
+
+ Potential-based shaping
+ φ(s) — progress potential used for dense shaping signal
+
+ The final step reward is:
+ R_t = r_validity + r_ordering + r_info_gain + r_efficiency
+ + r_novelty + r_penalty + γ[φ(s_{t+1}) − φ(s_t)]
+
+ The terminal reward adds:
+ R_T += r_terminal
+ """
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass, field
+ from typing import Any, Dict, List, Optional
+
+ from models import (
+ ActionType,
+ ConclusionClaim,
+ ExperimentAction,
+ IntermediateOutput,
+ META_ACTIONS,
+ WET_LAB_ACTIONS,
+ )
+
+ from server.simulator.latent_state import FullLatentState
+
+
+ @dataclass
+ class RewardBreakdown:
+ validity: float = 0.0
+ ordering: float = 0.0
+ info_gain: float = 0.0
+ efficiency: float = 0.0
+ novelty: float = 0.0
+ penalty: float = 0.0
+ shaping: float = 0.0
+ terminal: float = 0.0
+ components: Dict[str, float] = field(default_factory=dict)
+
+ @property
+ def total(self) -> float:
+ return (
+ self.validity
+ + self.ordering
+ + self.info_gain
+ + self.efficiency
+ + self.novelty
+ + self.penalty
+ + self.shaping
+ + self.terminal
+ )
+
+ def to_dict(self) -> Dict[str, float]:
+ d = {
+ "validity": self.validity,
+ "ordering": self.ordering,
+ "info_gain": self.info_gain,
+ "efficiency": self.efficiency,
+ "novelty": self.novelty,
+ "penalty": self.penalty,
+ "shaping": self.shaping,
+ "terminal": self.terminal,
+ "total": self.total,
+ }
+ d.update(self.components)
+ return d
+
+
+ class RewardComputer:
+ """Computes step-wise and terminal rewards.
+
+ Parameters
+ ----------
+ gamma : float
+ Discount factor for potential-based shaping (default 0.99).
+ efficiency_weight : float
+ Relative importance of resource efficiency.
+ """
+
+ def __init__(
+ self,
+ gamma: float = 0.99,
+ efficiency_weight: float = 0.3,
+ info_gain_weight: float = 0.4,
+ validity_weight: float = 0.3,
+ ):
+ self.gamma = gamma
+ self.w_eff = efficiency_weight
+ self.w_ig = info_gain_weight
+ self.w_val = validity_weight
+
+ # ── step reward ─────────────────────────────────────────────────────
+
+ def step_reward(
+ self,
+ action: ExperimentAction,
+ prev_state: FullLatentState,
+ next_state: FullLatentState,
+ output: IntermediateOutput,
+ hard_violations: List[str],
+ soft_violations: List[str],
+ ) -> RewardBreakdown:
+ rb = RewardBreakdown()
+
+ # validity
+ if hard_violations:
+ rb.validity = -1.0
+ rb.penalty = -0.5 * len(hard_violations)
+ rb.components["hard_violations"] = len(hard_violations)
+ return rb
+
+ rb.validity = self.w_val * (1.0 if output.success else 0.0)
+
+ # ordering bonus: +0.2 if the step was a natural next step
+ rb.ordering = 0.2 * self._ordering_score(action, prev_state)
+
+ # information gain proxy: quality × (1 - uncertainty)
+ rb.info_gain = self.w_ig * output.quality_score * (1.0 - output.uncertainty)
+
+ # efficiency: normalised cost relative to budget
+ budget_frac = (
+ (next_state.resources.budget_used - prev_state.resources.budget_used)
+ / max(next_state.resources.budget_total, 1)
+ )
+ rb.efficiency = self.w_eff * max(0.0, 1.0 - 5.0 * budget_frac)
+
+ # novelty: small bonus for non-redundant steps
+ if not soft_violations:
+ rb.novelty = 0.1
+
+ # penalties
+ rb.penalty = -0.15 * len(soft_violations)
+
+ # potential-based shaping
+ phi_prev = self._potential(prev_state)
+ phi_next = self._potential(next_state)
+ rb.shaping = self.gamma * phi_next - phi_prev
+
+ return rb
+
+ # ── terminal reward ─────────────────────────────────────────────────
+
+ def terminal_reward(
+ self,
+ state: FullLatentState,
+ conclusions: List[ConclusionClaim],
+ task_success_criteria: List[str],
+ ) -> RewardBreakdown:
+ rb = RewardBreakdown()
+
+ # pipeline completeness (0-1)
+ completeness = self._completeness(state)
+ rb.components["completeness"] = completeness
+
+ # calibration: how well conclusions align with hidden ground truth
+ calibration = self._calibration(state, conclusions)
+ rb.components["calibration"] = calibration
+
+ # efficiency bonus at terminal
+ budget_eff = state.resources.budget_remaining / max(
+ state.resources.budget_total, 1
+ )
+ time_eff = state.resources.time_remaining_days / max(
+ state.resources.time_limit_days, 1
+ )
+ rb.components["budget_efficiency"] = budget_eff
+ rb.components["time_efficiency"] = time_eff
+
+ # over-confidence penalty
+ overconf = self._overconfidence_penalty(state, conclusions)
+ rb.components["overconfidence_penalty"] = overconf
+
+ rb.terminal = (
+ 3.0 * completeness
+ + 4.0 * calibration
+ + 1.0 * (budget_eff + time_eff) / 2.0
+ + overconf
+ )
+ return rb
+
+ # ── helpers ─────────────────────────────────────────────────────────
+
+ def _ordering_score(
+ self, action: ExperimentAction, s: FullLatentState
+ ) -> float:
+ """Heuristic: 1.0 if this step naturally follows the current progress."""
+ at = action.action_type
+ p = s.progress
+ NATURAL_NEXT = {
+ ActionType.COLLECT_SAMPLE: not p.samples_collected,
+ ActionType.PREPARE_LIBRARY: p.samples_collected and not p.library_prepared,
+ ActionType.SEQUENCE_CELLS: p.library_prepared and not p.cells_sequenced,
+ ActionType.RUN_QC: p.cells_sequenced and not p.qc_performed,
+ ActionType.FILTER_DATA: p.qc_performed and not p.data_filtered,
+ ActionType.NORMALIZE_DATA: p.data_filtered and not p.data_normalized,
+ ActionType.CLUSTER_CELLS: p.data_normalized and not p.cells_clustered,
+ ActionType.DIFFERENTIAL_EXPRESSION: p.data_normalized and not p.de_performed,
+ ActionType.PATHWAY_ENRICHMENT: p.de_performed and not p.pathways_analyzed,
+ ActionType.MARKER_SELECTION: p.de_performed and not p.markers_discovered,
+ ActionType.VALIDATE_MARKER: p.markers_discovered and not p.markers_validated,
+ ActionType.SYNTHESIZE_CONCLUSION: (
+ p.de_performed or p.cells_clustered
+ ) and not p.conclusion_reached,
+ }
+ return 1.0 if NATURAL_NEXT.get(at, False) else 0.3
+
+ def _potential(self, s: FullLatentState) -> float:
+ """Progress potential φ(s) — counts completed milestones."""
+ p = s.progress
+ milestones = [
+ p.samples_collected,
+ p.library_prepared,
+ p.cells_sequenced,
+ p.qc_performed,
+ p.data_filtered,
+ p.data_normalized,
+ p.cells_clustered,
+ p.de_performed,
+ p.pathways_analyzed,
+ p.markers_discovered,
+ p.markers_validated,
+ p.conclusion_reached,
+ ]
+ return sum(milestones) / len(milestones)
+
+ def _completeness(self, s: FullLatentState) -> float:
+ p = s.progress
+ core = [
+ p.samples_collected,
+ p.cells_sequenced,
+ p.qc_performed,
+ p.data_filtered,
+ p.data_normalized,
+ p.de_performed or p.cells_clustered,
+ p.conclusion_reached,
+ ]
+ return sum(core) / len(core)
+
+ def _calibration(
+ self, s: FullLatentState, conclusions: List[ConclusionClaim]
+ ) -> float:
+ if not conclusions:
+ return 0.0
+
+ true_mechanisms = set(s.biology.causal_mechanisms)
+ true_markers = set(s.biology.true_markers)
+ score = 0.0
+ n = len(conclusions)
+
+ for c in conclusions:
+ claim_lower = c.claim.lower()
+ match = any(m.lower() in claim_lower for m in true_mechanisms)
+ marker_match = any(m.lower() in claim_lower for m in true_markers)
+ if match or marker_match:
+ score += 1.0
+ else:
+ score -= 0.3
+ return max(0.0, min(1.0, score / max(n, 1)))
+
+ def _overconfidence_penalty(
+ self, s: FullLatentState, conclusions: List[ConclusionClaim]
+ ) -> float:
+ """Penalise high-confidence claims that disagree with ground truth."""
+ penalty = 0.0
+ true_set = set(
+ m.lower() for m in s.biology.causal_mechanisms + s.biology.true_markers
+ )
+ for c in conclusions:
+ is_correct = any(t in c.claim.lower() for t in true_set)
+ if c.confidence > 0.8 and not is_correct:
+ penalty -= 0.5 * c.confidence
+ return penalty
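`step_reward` adds a potential-based shaping term of the form F(s, s') = γ·φ(s') − φ(s) on top of the raw components. A quick standalone check of why that is safe: the terms telescope over an episode, so (with γ = 1) the total shaping contribution depends only on the first and last potentials and cannot change which policy is optimal. Sketch with made-up milestone fractions, not the project's classes:

```python
def shaping(phi_prev: float, phi_next: float, gamma: float = 0.99) -> float:
    # Same form as rb.shaping in step_reward: gamma * phi(s') - phi(s).
    return gamma * phi_next - phi_prev

# Hypothetical per-step potentials, analogous to _potential's milestone fractions.
potentials = [0.0, 0.25, 0.5, 0.75, 1.0]

total = sum(shaping(a, b, gamma=1.0) for a, b in zip(potentials, potentials[1:]))

# Telescoping: with gamma == 1 only the endpoints matter.
assert abs(total - (potentials[-1] - potentials[0])) < 1e-9
```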
server/rules/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .engine import RuleEngine, RuleViolation
+
+ __all__ = ["RuleEngine", "RuleViolation"]
server/rules/engine.py ADDED
@@ -0,0 +1,208 @@
+ """Biological rule engine — hard and soft constraint checking.
+
+ Hard constraints block action execution entirely.
+ Soft constraints allow execution but degrade output quality and incur penalties.
+ """
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from enum import Enum
+ from typing import List
+
+ from models import ActionType, ExperimentAction
+
+ from server.simulator.latent_state import FullLatentState
+
+
+ class Severity(str, Enum):
+ HARD = "hard"
+ SOFT = "soft"
+
+
+ @dataclass
+ class RuleViolation:
+ rule_id: str
+ severity: Severity
+ message: str
+
+
+ class RuleEngine:
+ """Evaluates biological and resource constraints against the current
+ latent state before each action is applied.
+ """
+
+ def check(
+ self, action: ExperimentAction, state: FullLatentState
+ ) -> List[RuleViolation]:
+ violations: List[RuleViolation] = []
+ violations.extend(self._check_prerequisites(action, state))
+ violations.extend(self._check_resource_constraints(action, state))
+ violations.extend(self._check_redundancy(action, state))
+ violations.extend(self._check_causal_validity(action, state))
+ return violations
+
+ def hard_violations(self, violations: List[RuleViolation]) -> List[str]:
+ return [v.message for v in violations if v.severity == Severity.HARD]
+
+ def soft_violations(self, violations: List[RuleViolation]) -> List[str]:
+ return [v.message for v in violations if v.severity == Severity.SOFT]
+
+ # ── prerequisite rules ──────────────────────────────────────────────
+
+ def _check_prerequisites(
+ self, action: ExperimentAction, s: FullLatentState
+ ) -> List[RuleViolation]:
+ vs: List[RuleViolation] = []
+ at = action.action_type
+ p = s.progress
+
+ REQUIRES = {
+ ActionType.PREPARE_LIBRARY: [
+ ("samples_collected", "Cannot prepare library without collected samples"),
+ ],
+ ActionType.SEQUENCE_CELLS: [
+ ("library_prepared", "Cannot sequence without library preparation"),
+ ],
+ ActionType.RUN_QC: [
+ ("cells_sequenced", "Cannot run QC before sequencing"),
+ ],
+ ActionType.FILTER_DATA: [
+ ("qc_performed", "Cannot filter data before QC"),
+ ],
+ ActionType.NORMALIZE_DATA: [
+ ("data_filtered", "Cannot normalise before filtering"),
+ ],
+ ActionType.INTEGRATE_BATCHES: [
+ ("data_normalized", "Cannot integrate batches before normalisation"),
+ ],
+ ActionType.CLUSTER_CELLS: [
+ ("data_normalized", "Cannot cluster before normalisation"),
+ ],
+ ActionType.DIFFERENTIAL_EXPRESSION: [
+ ("data_normalized", "Cannot run DE before normalisation"),
+ ],
+ ActionType.TRAJECTORY_ANALYSIS: [
+ ("data_normalized", "Cannot infer trajectories before normalisation"),
+ ],
+ ActionType.PATHWAY_ENRICHMENT: [
+ ("de_performed", "Cannot run pathway enrichment without DE results"),
+ ],
+ ActionType.REGULATORY_NETWORK_INFERENCE: [
+ ("data_normalized", "Cannot infer networks before normalisation"),
+ ],
+ ActionType.MARKER_SELECTION: [
+ ("de_performed", "Cannot select markers without DE results"),
+ ],
+ ActionType.VALIDATE_MARKER: [
+ ("markers_discovered", "Cannot validate markers before discovery"),
+ ],
+ ActionType.PERTURB_GENE: [
+ ("samples_collected", "Cannot perturb without samples"),
+ ],
+ ActionType.PERTURB_COMPOUND: [
+ ("samples_collected", "Cannot perturb without samples"),
+ ],
+ ActionType.CULTURE_CELLS: [
+ ("samples_collected", "Cannot culture without samples"),
+ ],
+ }
+
+ for flag, msg in REQUIRES.get(at, []):
+ if not getattr(p, flag, False):
+ vs.append(RuleViolation(
+ rule_id=f"prereq_{at.value}_{flag}",
+ severity=Severity.HARD,
+ message=msg,
+ ))
+ return vs
+
+ # ── resource constraints ────────────────────────────────────────────
+
+ def _check_resource_constraints(
+ self, action: ExperimentAction, s: FullLatentState
+ ) -> List[RuleViolation]:
+ vs: List[RuleViolation] = []
+ if s.resources.budget_exhausted:
+ vs.append(RuleViolation(
+ rule_id="budget_exhausted",
+ severity=Severity.HARD,
+ message="Budget exhausted — no further actions possible",
+ ))
+ if s.resources.time_exhausted:
+ vs.append(RuleViolation(
+ rule_id="time_exhausted",
+ severity=Severity.HARD,
+ message="Time limit reached — no further actions possible",
+ ))
+
+ remaining = s.resources.budget_remaining
+ from server.simulator.transition import ACTION_COSTS
+ cost, _ = ACTION_COSTS.get(action.action_type, (0, 0))
+ if cost > remaining and remaining > 0:
+ vs.append(RuleViolation(
+ rule_id="budget_insufficient",
+ severity=Severity.SOFT,
+ message=f"Action costs ${cost:,.0f} but only ${remaining:,.0f} remains",
+ ))
+ return vs
+
+ # ── redundancy checks ───────────────────────────────────────────────
+
+ def _check_redundancy(
+ self, action: ExperimentAction, s: FullLatentState
+ ) -> List[RuleViolation]:
+ vs: List[RuleViolation] = []
+ at = action.action_type
+ p = s.progress
+
+ REDUNDANT = {
+ ActionType.COLLECT_SAMPLE: "samples_collected",
+ ActionType.PREPARE_LIBRARY: "library_prepared",
+ ActionType.SEQUENCE_CELLS: "cells_sequenced",
+ ActionType.RUN_QC: "qc_performed",
+ ActionType.FILTER_DATA: "data_filtered",
+ ActionType.NORMALIZE_DATA: "data_normalized",
+ }
+ flag = REDUNDANT.get(at)
+ if flag and getattr(p, flag, False):
+ vs.append(RuleViolation(
+ rule_id=f"redundant_{at.value}",
+ severity=Severity.SOFT,
+ message=f"Step '{at.value}' already completed — redundant action",
+ ))
+ return vs
+
+ # ── causal validity ─────────────────────────────────────────────────
+
+ def _check_causal_validity(
+ self, action: ExperimentAction, s: FullLatentState
+ ) -> List[RuleViolation]:
+ vs: List[RuleViolation] = []
+ if action.action_type == ActionType.SYNTHESIZE_CONCLUSION:
+ if not s.progress.de_performed and not s.progress.cells_clustered:
+ vs.append(RuleViolation(
+ rule_id="premature_conclusion",
+ severity=Severity.SOFT,
+ message="Synthesising conclusion without substantive analysis",
+ ))
+
+ claims = action.parameters.get("claims", [])
+ for claim in claims:
+ if isinstance(claim, dict) and claim.get("claim_type") == "causal":
+ if not s.progress.markers_validated and not s.progress.networks_inferred:
+ vs.append(RuleViolation(
+ rule_id="unsupported_causal_claim",
+ severity=Severity.SOFT,
+ message="Causal claim without validation or network evidence",
+ ))
+ break
+
+ if action.action_type == ActionType.PATHWAY_ENRICHMENT:
+ if not s.progress.de_performed:
+ vs.append(RuleViolation(
+ rule_id="pathway_without_de",
+ severity=Severity.SOFT,
+ message="Pathway enrichment without DE may yield unreliable results",
+ ))
+ return vs
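The prerequisite table in `_check_prerequisites` is a data-driven pattern: each action maps to the progress flag it requires plus a violation message, and one generic loop enforces all of them. A stripped-down, self-contained sketch of the same pattern (simplified names, not the project's API):

```python
from dataclasses import dataclass
from typing import List

@dataclass
class Progress:
    samples_collected: bool = False
    library_prepared: bool = False

# action -> (required progress flag, violation message)
REQUIRES = {
    "prepare_library": ("samples_collected", "Cannot prepare library without collected samples"),
    "sequence_cells": ("library_prepared", "Cannot sequence without library preparation"),
}

def check_prerequisites(action: str, p: Progress) -> List[str]:
    # Unknown actions have no prerequisites and pass through.
    rule = REQUIRES.get(action)
    if rule is not None and not getattr(p, rule[0], False):
        return [rule[1]]
    return []

assert check_prerequisites("prepare_library", Progress()) == [
    "Cannot prepare library without collected samples"
]
assert check_prerequisites("prepare_library", Progress(samples_collected=True)) == []
```

Keeping the rules in a table means adding a new constraint is one dictionary entry rather than another `if` branch.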
server/simulator/__init__.py ADDED
@@ -0,0 +1,25 @@
+ from .latent_state import (
+ CellPopulation,
+ ExperimentProgress,
+ FullLatentState,
+ GeneProgram,
+ LatentBiologicalState,
+ ResourceState,
+ TechnicalState,
+ )
+ from .noise import NoiseModel
+ from .output_generator import OutputGenerator
+ from .transition import TransitionEngine
+
+ __all__ = [
+ "CellPopulation",
+ "ExperimentProgress",
+ "FullLatentState",
+ "GeneProgram",
+ "LatentBiologicalState",
+ "NoiseModel",
+ "OutputGenerator",
+ "ResourceState",
+ "TechnicalState",
+ "TransitionEngine",
+ ]
server/simulator/latent_state.py ADDED
@@ -0,0 +1,143 @@
+ """Latent biological and technical state — hidden from the agent."""
+
+ from __future__ import annotations
+
+ from typing import Any, Dict, List, Optional
+
+ from pydantic import BaseModel, Field
+
+
+ class CellPopulation(BaseModel):
+ """Ground-truth cell sub-population in the simulated tissue."""
+
+ name: str
+ proportion: float = Field(ge=0.0, le=1.0)
+ marker_genes: List[str] = Field(default_factory=list)
+ state: str = "quiescent"
+ condition_response: Dict[str, float] = Field(default_factory=dict)
+
+
+ class GeneProgram(BaseModel):
+ """A latent gene-regulatory programme."""
+
+ name: str
+ genes: List[str] = Field(default_factory=list)
+ activity_level: float = Field(0.5, ge=0.0, le=1.0)
+ condition_dependent: bool = False
+ conditions_active: List[str] = Field(default_factory=list)
+
+
+ class LatentBiologicalState(BaseModel):
+ """Hidden ground-truth biology the agent cannot directly observe."""
+
+ cell_populations: List[CellPopulation] = Field(default_factory=list)
+ true_de_genes: Dict[str, Dict[str, float]] = Field(
+ default_factory=dict,
+ description="comparison_key → {gene: log2FC}",
+ )
+ true_pathways: Dict[str, float] = Field(
+ default_factory=dict,
+ description="pathway → activity level",
+ )
+ gene_programs: List[GeneProgram] = Field(default_factory=list)
+ true_trajectory: Optional[Dict[str, Any]] = None
+ true_regulatory_network: Dict[str, List[str]] = Field(
+ default_factory=dict,
+ description="TF → target genes",
+ )
+ perturbation_effects: Dict[str, Dict[str, float]] = Field(
+ default_factory=dict,
+ description="perturbation → {gene: effect_size}",
+ )
+ confounders: Dict[str, float] = Field(default_factory=dict)
+ true_markers: List[str] = Field(default_factory=list)
+ causal_mechanisms: List[str] = Field(default_factory=list)
+ n_true_cells: int = 10_000
+
+
+ class TechnicalState(BaseModel):
+ """Hidden technical parameters that shape experimental noise."""
+
+ batch_effects: Dict[str, float] = Field(default_factory=dict)
+ ambient_rna_fraction: float = 0.05
+ doublet_rate: float = 0.04
+ dropout_rate: float = 0.1
+ sample_quality: float = Field(0.9, ge=0.0, le=1.0)
+ library_complexity: float = Field(0.8, ge=0.0, le=1.0)
+ sequencing_depth_factor: float = 1.0
+ capture_efficiency: float = 0.6
+
+
+ class ExperimentProgress(BaseModel):
+ """Flags tracking which experiment stages have been completed."""
+
+ samples_collected: bool = False
+ cohort_selected: bool = False
+ cells_cultured: bool = False
+ library_prepared: bool = False
+ perturbation_applied: bool = False
+ cells_sequenced: bool = False
+ qc_performed: bool = False
+ data_filtered: bool = False
+ data_normalized: bool = False
+ batches_integrated: bool = False
+ cells_clustered: bool = False
+ de_performed: bool = False
+ trajectories_inferred: bool = False
+ pathways_analyzed: bool = False
+ networks_inferred: bool = False
+ markers_discovered: bool = False
+ markers_validated: bool = False
+ conclusion_reached: bool = False
+
+ n_cells_after_filter: Optional[int] = None
+ n_clusters_found: Optional[int] = None
+ n_de_genes_found: Optional[int] = None
+ n_markers_found: Optional[int] = None
+
+
+ class ResourceState(BaseModel):
+ """Full internal resource tracking (superset of agent-visible ResourceUsage)."""
+
+ budget_total: float = 100_000.0
+ budget_used: float = 0.0
+ time_limit_days: float = 180.0
+ time_used_days: float = 0.0
+ samples_available: int = 0
+ samples_consumed: int = 0
+ compute_hours_used: float = 0.0
+ sequencing_lanes_used: int = 0
+ reagent_kits_used: int = 0
+
+ @property
+ def budget_remaining(self) -> float:
+ return max(0.0, self.budget_total - self.budget_used)
+
+ @property
+ def time_remaining_days(self) -> float:
+ return max(0.0, self.time_limit_days - self.time_used_days)
+
+ @property
+ def budget_exhausted(self) -> bool:
+ return self.budget_remaining <= 0
+
+ @property
+ def time_exhausted(self) -> bool:
+ return self.time_remaining_days <= 0
+
+
+ class FullLatentState(BaseModel):
+ """Complete hidden state of the simulated biological world."""
+
+ biology: LatentBiologicalState = Field(
+ default_factory=LatentBiologicalState
+ )
+ technical: TechnicalState = Field(default_factory=TechnicalState)
+ progress: ExperimentProgress = Field(default_factory=ExperimentProgress)
+ resources: ResourceState = Field(default_factory=ResourceState)
+ hidden_failure_conditions: List[str] = Field(default_factory=list)
+ mechanism_confidence: Dict[str, float] = Field(default_factory=dict)
+ discovered_de_genes: List[str] = Field(default_factory=list)
+ discovered_clusters: List[str] = Field(default_factory=list)
+ step_count: int = 0
+ rng_seed: int = 42
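`ResourceState` stores only the raw counters (`budget_total`, `budget_used`, …) and derives `budget_remaining`/`budget_exhausted` as properties, so the derived values can never drift out of sync with the counters. The same arithmetic with a plain stdlib dataclass instead of pydantic (illustrative sketch only):

```python
from dataclasses import dataclass

@dataclass
class Resources:
    budget_total: float = 100_000.0
    budget_used: float = 0.0

    @property
    def budget_remaining(self) -> float:
        # Clamped at zero, matching ResourceState above.
        return max(0.0, self.budget_total - self.budget_used)

    @property
    def budget_exhausted(self) -> bool:
        return self.budget_remaining <= 0

r = Resources(budget_used=30_000.0)
assert r.budget_remaining == 70_000.0
assert not r.budget_exhausted
assert Resources(budget_used=120_000.0).budget_exhausted
```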
server/simulator/noise.py ADDED
@@ -0,0 +1,124 @@
+ """Stochastic noise models for the biological simulator."""
+
+ from __future__ import annotations
+
+ from typing import Dict, List, Tuple
+
+ import numpy as np
+
+
+ class NoiseModel:
+ """Generates calibrated noise for simulated experimental outputs.
+
+ All randomness is funnelled through a single ``numpy.Generator``
+ so that episodes are reproducible given the same seed.
+ """
+
+ def __init__(self, seed: int = 42):
+ self.rng = np.random.default_rng(seed)
+
+ def reseed(self, seed: int) -> None:
+ self.rng = np.random.default_rng(seed)
+
+ # ── expression-level noise ──────────────────────────────────────────
+
+ def add_expression_noise(
+ self,
+ true_values: Dict[str, float],
+ noise_level: float,
+ dropout_rate: float,
+ ) -> Dict[str, float]:
+ noisy: Dict[str, float] = {}
+ for gene, value in true_values.items():
+ if self.rng.random() < dropout_rate:
+ noisy[gene] = 0.0
+ else:
+ sigma = noise_level * abs(value) + 0.1
+ noisy[gene] = float(value + self.rng.normal(0, sigma))
+ return noisy
+
+ # ── effect-size sampling ────────────────────────────────────────────
+
+ def sample_effect_sizes(
+ self,
+ true_effects: Dict[str, float],
+ sample_size: int,
+ noise_level: float,
+ ) -> Dict[str, float]:
+ se = noise_level / max(np.sqrt(max(sample_size, 1)), 1e-6)
+ return {
+ gene: float(effect + self.rng.normal(0, se))
+ for gene, effect in true_effects.items()
+ }
+
+ def sample_p_values(
+ self,
+ true_effects: Dict[str, float],
+ sample_size: int,
+ noise_level: float,
+ ) -> Dict[str, float]:
+ """Simulate approximate p-values from z-statistics."""
+ from scipy import stats  # type: ignore[import-untyped]
+
+ p_values: Dict[str, float] = {}
+ se = noise_level / max(np.sqrt(max(sample_size, 1)), 1e-6)
+ for gene, effect in true_effects.items():
+ z = abs(effect) / max(se, 1e-8)
+ p_values[gene] = float(2 * stats.norm.sf(z))
+ return p_values
+
+ # ── false discovery helpers ─────────────────────────────────────────
+
+ def generate_false_positives(
+ self, n_background_genes: int, fdr: float
+ ) -> List[str]:
+ n_fp = int(self.rng.binomial(n_background_genes, fdr))
+ return [f"FP_GENE_{i}" for i in range(n_fp)]
+
+ def generate_false_negatives(
+ self, true_genes: List[str], fnr: float
+ ) -> List[str]:
+ """Return the subset of *true_genes* that are missed."""
+ return [g for g in true_genes if self.rng.random() < fnr]
+
+ # ── quality helpers ─────────────────────────────────────────────────
+
+ def quality_degradation(
+ self, base_quality: float, factors: List[float]
+ ) -> float:
+ q = base_quality
+ for f in factors:
+ q *= f
+ return float(np.clip(q + self.rng.normal(0, 0.02), 0.0, 1.0))
+
+ def sample_qc_metric(
+ self, mean: float, std: float, clip_lo: float = 0.0, clip_hi: float = 1.0
+ ) -> float:
+ return float(np.clip(self.rng.normal(mean, std), clip_lo, clip_hi))
+
+ def sample_count(self, lam: float) -> int:
+ return int(self.rng.poisson(max(lam, 0)))
+
+ def coin_flip(self, p: float) -> bool:
+ return bool(self.rng.random() < p)
+
+ def sample_cluster_count(
+ self, n_true_populations: int, quality: float
+ ) -> int:
+ """Over- or under-clustering depending on preprocessing quality."""
+ delta = self.rng.integers(-2, 3)
+ noise_clusters = max(0, int(round((1.0 - quality) * 3)))
+ return max(1, n_true_populations + delta + noise_clusters)
+
+ def shuffle_ranking(
+ self, items: List[str], noise_level: float
+ ) -> List[str]:
+ """Permute a ranking with Gaussian noise on ordinals."""
+ n = len(items)
+ if n == 0:
+ return []
+ scores = np.arange(n, dtype=float) + self.rng.normal(
+ 0, noise_level * n, size=n
+ )
+ order = np.argsort(scores)
+ return [items[int(i)] for i in order]
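`NoiseModel`'s docstring states its reproducibility contract: every random draw flows through one seeded generator, so the same seed yields the same noise stream. The pattern in miniature, using the stdlib `random` module instead of numpy (a sketch, not the project's class):

```python
import random

class TinyNoise:
    """Every draw goes through self.rng, so two instances built with
    the same seed emit identical noise streams."""

    def __init__(self, seed: int = 42):
        self.rng = random.Random(seed)

    def gauss_noise(self, value: float, sigma: float) -> float:
        return value + self.rng.gauss(0.0, sigma)

a = TinyNoise(seed=7)
b = TinyNoise(seed=7)
stream_a = [a.gauss_noise(1.0, 0.1) for _ in range(5)]
stream_b = [b.gauss_noise(1.0, 0.1) for _ in range(5)]

# Identical seeds, identical streams: episodes replay deterministically.
assert stream_a == stream_b
```

This is also why `reseed` replaces the generator wholesale rather than mutating it: a fresh generator guarantees the stream restarts from a known point.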
server/simulator/output_generator.py ADDED
@@ -0,0 +1,495 @@
+"""Generate simulated intermediate outputs conditioned on latent state."""
+
+from __future__ import annotations
+
+from typing import Any, Dict, List
+
+from models import (
+    ActionType,
+    ExperimentAction,
+    IntermediateOutput,
+    OutputType,
+)
+
+from .latent_state import FullLatentState
+from .noise import NoiseModel
+
+
+class OutputGenerator:
+    """Creates structured ``IntermediateOutput`` objects conditioned on the
+    hidden latent state, the action taken, and a stochastic noise model.
+    """
+
+    def __init__(self, noise: NoiseModel):
+        self.noise = noise
+
+    def generate(
+        self,
+        action: ExperimentAction,
+        state: FullLatentState,
+        step_index: int,
+    ) -> IntermediateOutput:
+        # Handlers are stored unbound and called with ``self`` explicitly,
+        # so the fallback must also be the unbound method.
+        handler = _HANDLERS.get(action.action_type, OutputGenerator._default)
+        return handler(self, action, state, step_index)
+
+    # ── wet-lab outputs ─────────────────────────────────────────────────
+
+    def _collect_sample(
+        self, action: ExperimentAction, s: FullLatentState, idx: int
+    ) -> IntermediateOutput:
+        n_samples = action.parameters.get("n_samples", 6)
+        quality = self.noise.quality_degradation(
+            s.technical.sample_quality, [s.technical.capture_efficiency]
+        )
+        return IntermediateOutput(
+            output_type=OutputType.SAMPLE_COLLECTION_RESULT,
+            step_index=idx,
+            quality_score=quality,
+            summary=f"Collected {n_samples} samples (quality={quality:.2f})",
+            data={
+                "n_samples": n_samples,
+                "quality": quality,
+                "organism": "human",
+                "tissue": "blood",
+            },
+            artifacts_available=["raw_samples"],
+        )
+
+    def _select_cohort(
+        self, action: ExperimentAction, s: FullLatentState, idx: int
+    ) -> IntermediateOutput:
+        criteria = action.parameters.get("criteria", {})
+        n_selected = action.parameters.get("n_selected", 4)
+        return IntermediateOutput(
+            output_type=OutputType.COHORT_RESULT,
+            step_index=idx,
+            summary=f"Selected cohort of {n_selected} samples with criteria {criteria}",
+            data={"n_selected": n_selected, "criteria": criteria},
+            artifacts_available=["cohort_manifest"],
+        )
+
+    def _prepare_library(
+        self, action: ExperimentAction, s: FullLatentState, idx: int
+    ) -> IntermediateOutput:
+        complexity = self.noise.quality_degradation(
+            s.technical.library_complexity,
+            [s.technical.sample_quality],
+        )
+        return IntermediateOutput(
+            output_type=OutputType.LIBRARY_PREP_RESULT,
+            step_index=idx,
+            quality_score=complexity,
+            summary=f"Library prepared (complexity={complexity:.2f})",
+            data={
+                "library_complexity": complexity,
+                "method": action.method or "10x_chromium",
+            },
+            artifacts_available=["prepared_library"],
+        )
+
+    def _culture_cells(
+        self, action: ExperimentAction, s: FullLatentState, idx: int
+    ) -> IntermediateOutput:
+        days = action.parameters.get("days", 7)
+        viability = self.noise.sample_qc_metric(0.92, 0.05, 0.5, 1.0)
+        return IntermediateOutput(
+            output_type=OutputType.CULTURE_RESULT,
+            step_index=idx,
+            quality_score=viability,
+            summary=f"Cultured for {days}d, viability={viability:.2f}",
+            data={"days": days, "viability": viability},
+            artifacts_available=["cultured_cells"],
+        )
+
+    def _perturb(
+        self, action: ExperimentAction, s: FullLatentState, idx: int
+    ) -> IntermediateOutput:
+        target = action.parameters.get("target", "unknown")
+        efficiency = self.noise.sample_qc_metric(0.75, 0.15, 0.0, 1.0)
+        return IntermediateOutput(
+            output_type=OutputType.PERTURBATION_RESULT,
+            step_index=idx,
+            quality_score=efficiency,
+            summary=f"Perturbation of {target} (efficiency={efficiency:.2f})",
+            data={
+                "target": target,
+                "efficiency": efficiency,
+                "type": action.action_type.value,
+            },
+            artifacts_available=["perturbed_cells"],
+        )
+
+    def _sequence_cells(
+        self, action: ExperimentAction, s: FullLatentState, idx: int
+    ) -> IntermediateOutput:
+        depth = s.technical.sequencing_depth_factor
+        n_cells = self.noise.sample_count(
+            s.biology.n_true_cells * s.technical.capture_efficiency
+        )
+        n_genes = self.noise.sample_count(18_000)
+        median_umi = self.noise.sample_count(int(3000 * depth))
+        quality = self.noise.quality_degradation(
+            s.technical.sample_quality,
+            [s.technical.library_complexity, s.technical.capture_efficiency],
+        )
+        return IntermediateOutput(
+            output_type=OutputType.SEQUENCING_RESULT,
+            step_index=idx,
+            quality_score=quality,
+            summary=(
+                f"Sequenced {n_cells} cells, {n_genes} genes detected, "
+                f"median UMI={median_umi}"
+            ),
+            data={
+                "n_cells": n_cells,
+                "n_genes": n_genes,
+                "median_umi": median_umi,
+                "sequencing_saturation": self.noise.sample_qc_metric(0.7, 0.1),
+            },
+            artifacts_available=["raw_count_matrix"],
+        )
+
+    # ── computational outputs ───────────────────────────────────────────
+
+    def _run_qc(
+        self, action: ExperimentAction, s: FullLatentState, idx: int
+    ) -> IntermediateOutput:
+        doublet_frac = self.noise.sample_qc_metric(
+            s.technical.doublet_rate, 0.01, 0.0, 0.2
+        )
+        mito_frac = self.noise.sample_qc_metric(0.05, 0.02, 0.0, 0.3)
+        ambient_frac = self.noise.sample_qc_metric(
+            s.technical.ambient_rna_fraction, 0.01, 0.0, 0.2
+        )
+        warnings: List[str] = []
+        if doublet_frac > 0.08:
+            warnings.append(f"High doublet rate ({doublet_frac:.1%})")
+        if mito_frac > 0.1:
+            warnings.append(f"High mitochondrial fraction ({mito_frac:.1%})")
+        quality = 1.0 - (doublet_frac + mito_frac + ambient_frac)
+        return IntermediateOutput(
+            output_type=OutputType.QC_METRICS,
+            step_index=idx,
+            quality_score=max(0.0, quality),
+            summary="QC metrics computed",
+            data={
+                "doublet_fraction": doublet_frac,
+                "mitochondrial_fraction": mito_frac,
+                "ambient_rna_fraction": ambient_frac,
+                "median_genes_per_cell": self.noise.sample_count(2500),
+                "median_umi_per_cell": self.noise.sample_count(8000),
+            },
+            warnings=warnings,
+            artifacts_available=["qc_report"],
+        )
+
+    def _filter_data(
+        self, action: ExperimentAction, s: FullLatentState, idx: int
+    ) -> IntermediateOutput:
+        retain_frac = self.noise.sample_qc_metric(0.85, 0.05, 0.5, 1.0)
+        n_before = s.biology.n_true_cells
+        n_after = max(100, int(n_before * retain_frac))
+        return IntermediateOutput(
+            output_type=OutputType.COUNT_MATRIX_SUMMARY,
+            step_index=idx,
+            quality_score=retain_frac,
+            summary=f"Filtered {n_before} → {n_after} cells ({retain_frac:.0%} retained)",
+            data={
+                "n_cells_before": n_before,
+                "n_cells_after": n_after,
+                "n_genes_retained": self.noise.sample_count(15_000),
+                "retain_fraction": retain_frac,
+            },
+            artifacts_available=["filtered_count_matrix"],
+        )
+
+    def _normalize_data(
+        self, action: ExperimentAction, s: FullLatentState, idx: int
+    ) -> IntermediateOutput:
+        method = action.method or "log_normalize"
+        return IntermediateOutput(
+            output_type=OutputType.COUNT_MATRIX_SUMMARY,
+            step_index=idx,
+            summary=f"Normalized with {method}",
+            data={"method": method, "n_hvg": self.noise.sample_count(2000)},
+            artifacts_available=["normalized_matrix", "hvg_list"],
+        )
+
+    def _integrate_batches(
+        self, action: ExperimentAction, s: FullLatentState, idx: int
+    ) -> IntermediateOutput:
+        method = action.method or "harmony"
+        residual = self.noise.sample_qc_metric(0.05, 0.03, 0.0, 0.3)
+        return IntermediateOutput(
+            output_type=OutputType.EMBEDDING_SUMMARY,
+            step_index=idx,
+            quality_score=1.0 - residual,
+            summary=f"Batch integration ({method}), residual batch effect={residual:.2f}",
+            data={
+                "method": method,
+                "residual_batch_effect": residual,
+                "n_batches": len(s.technical.batch_effects) or 1,
+            },
+            artifacts_available=["integrated_embedding"],
+        )
+
+    def _cluster_cells(
+        self, action: ExperimentAction, s: FullLatentState, idx: int
+    ) -> IntermediateOutput:
+        n_true = len(s.biology.cell_populations) or 5
+        quality = self.noise.quality_degradation(0.8, [0.95])
+        n_clusters = self.noise.sample_cluster_count(n_true, quality)
+        cluster_names = [f"cluster_{i}" for i in range(n_clusters)]
+        sizes = self._random_partition(s.biology.n_true_cells, n_clusters)
+        return IntermediateOutput(
+            output_type=OutputType.CLUSTER_RESULT,
+            step_index=idx,
+            quality_score=quality,
+            summary=f"Found {n_clusters} clusters (ground-truth populations: {n_true})",
+            data={
+                "n_clusters": n_clusters,
+                "cluster_names": cluster_names,
+                "cluster_sizes": sizes,
+                "silhouette_score": self.noise.sample_qc_metric(0.35, 0.1, -1.0, 1.0),
+            },
+            uncertainty=abs(n_clusters - n_true) / max(n_true, 1),
+            artifacts_available=["cluster_assignments", "umap_embedding"],
+        )
+
+    def _differential_expression(
+        self, action: ExperimentAction, s: FullLatentState, idx: int
+    ) -> IntermediateOutput:
+        comparison = action.parameters.get("comparison", "disease_vs_healthy")
+        true_effects = s.biology.true_de_genes.get(comparison, {})
+
+        n_cells = s.progress.n_cells_after_filter or s.biology.n_true_cells
+        noise_level = s.technical.dropout_rate + 0.1 * (1.0 - s.technical.sample_quality)
+        observed = self.noise.sample_effect_sizes(true_effects, n_cells, noise_level)
+
+        fp_genes = self.noise.generate_false_positives(5000, 0.002 + noise_level * 0.01)
+        for g in fp_genes:
+            observed[g] = float(self.noise.rng.normal(0, 0.3))
+
+        fn_genes = self.noise.generate_false_negatives(list(true_effects.keys()), 0.15)
+        for g in fn_genes:
+            observed.pop(g, None)
+
+        top_genes = sorted(observed.items(), key=lambda kv: abs(kv[1]), reverse=True)[:50]
+        return IntermediateOutput(
+            output_type=OutputType.DE_RESULT,
+            step_index=idx,
+            quality_score=self.noise.quality_degradation(0.8, [1.0 - noise_level]),
+            summary=f"DE analysis ({comparison}): {len(observed)} genes tested, {len(top_genes)} top hits",
+            data={
+                "comparison": comparison,
+                "n_tested": len(observed),
+                "top_genes": [
+                    {"gene": g, "log2FC": round(fc, 3)} for g, fc in top_genes
+                ],
+                "n_significant": sum(1 for _, fc in observed.items() if abs(fc) > 0.5),
+            },
+            uncertainty=noise_level,
+            artifacts_available=["de_table"],
+        )
+
+    def _trajectory_analysis(
+        self, action: ExperimentAction, s: FullLatentState, idx: int
+    ) -> IntermediateOutput:
+        has_trajectory = s.biology.true_trajectory is not None
+        quality = self.noise.quality_degradation(0.7 if has_trajectory else 0.3, [0.9])
+        summary_data: Dict[str, Any] = {"method": action.method or "monocle3"}
+        if has_trajectory:
+            summary_data.update({
+                "n_lineages": s.biology.true_trajectory.get("n_lineages", 1),
+                "pseudotime_range": [0.0, 1.0],
+                "branching_detected": s.biology.true_trajectory.get("branching", False),
+            })
+        else:
+            summary_data["n_lineages"] = self.noise.sample_count(1) + 1
+            summary_data["pseudotime_range"] = [0.0, 1.0]
+            summary_data["branching_detected"] = self.noise.coin_flip(0.3)
+
+        return IntermediateOutput(
+            output_type=OutputType.TRAJECTORY_RESULT,
+            step_index=idx,
+            quality_score=quality,
+            summary="Trajectory / pseudotime analysis complete",
+            data=summary_data,
+            uncertainty=0.2 if has_trajectory else 0.6,
+            artifacts_available=["pseudotime_values", "lineage_graph"],
+        )
+
+    def _pathway_enrichment(
+        self, action: ExperimentAction, s: FullLatentState, idx: int
+    ) -> IntermediateOutput:
+        true_pathways = s.biology.true_pathways
+        noise_level = 0.15
+        observed: Dict[str, float] = {}
+        for pw, activity in true_pathways.items():
+            observed[pw] = activity + float(self.noise.rng.normal(0, noise_level))
+
+        for i in range(self.noise.sample_count(2)):
+            observed[f"FP_PATHWAY_{i}"] = float(self.noise.rng.uniform(0.3, 0.6))
+
+        top = sorted(observed.items(), key=lambda kv: kv[1], reverse=True)[:15]
+        return IntermediateOutput(
+            output_type=OutputType.PATHWAY_RESULT,
+            step_index=idx,
+            quality_score=self.noise.quality_degradation(0.8, [0.95]),
+            summary=f"Pathway enrichment: {len(top)} significant pathways",
+            data={
+                "method": action.method or "GSEA",
+                "top_pathways": [
+                    {"pathway": p, "score": round(score, 3)} for p, score in top
+                ],
+            },
+            uncertainty=noise_level,
+            artifacts_available=["enrichment_table"],
+        )
+
+    def _regulatory_network(
+        self, action: ExperimentAction, s: FullLatentState, idx: int
+    ) -> IntermediateOutput:
+        true_net = s.biology.true_regulatory_network
+        n_edges_true = sum(len(v) for v in true_net.values())
+        noise_edges = self.noise.sample_count(max(5, int(n_edges_true * 0.3)))
+        return IntermediateOutput(
+            output_type=OutputType.NETWORK_RESULT,
+            step_index=idx,
+            quality_score=self.noise.quality_degradation(0.6, [0.9]),
+            summary=f"Regulatory network inferred: {n_edges_true + noise_edges} edges",
+            data={
+                "method": action.method or "SCENIC",
+                "n_regulons": len(true_net) + self.noise.sample_count(3),
+                "n_edges": n_edges_true + noise_edges,
+                "top_regulators": list(true_net.keys())[:10],
+            },
+            uncertainty=0.35,
+            artifacts_available=["regulon_table", "grn_adjacency"],
+        )
+
+    def _marker_selection(
+        self, action: ExperimentAction, s: FullLatentState, idx: int
+    ) -> IntermediateOutput:
+        true_markers = list(s.biology.true_markers)
+        noise_level = 0.2
+        observed_markers = [
+            m for m in true_markers if not self.noise.coin_flip(noise_level)
+        ]
+        fp = self.noise.generate_false_positives(200, 0.01)
+        observed_markers.extend(fp)
+        return IntermediateOutput(
+            output_type=OutputType.MARKER_RESULT,
+            step_index=idx,
+            quality_score=self.noise.quality_degradation(0.75, [0.9]),
+            summary=f"Selected {len(observed_markers)} candidate markers",
+            data={
+                "markers": observed_markers[:20],
+                "n_candidates": len(observed_markers),
+            },
+            uncertainty=noise_level,
+            artifacts_available=["marker_list"],
+        )
+
+    def _validate_marker(
+        self, action: ExperimentAction, s: FullLatentState, idx: int
+    ) -> IntermediateOutput:
+        marker = action.parameters.get("marker", "unknown")
+        is_true = marker in s.biology.true_markers
+        validation_correct = not self.noise.coin_flip(0.1)
+        validated = is_true == validation_correct
+        return IntermediateOutput(
+            output_type=OutputType.VALIDATION_RESULT,
+            step_index=idx,
+            quality_score=0.9 if validation_correct else 0.4,
+            summary=f"Marker {marker}: {'validated' if validated else 'not validated'}",
+            data={
+                "marker": marker,
+                "validated": validated,
+                "assay": action.method or "qPCR",
+                "effect_size": self.noise.sample_qc_metric(
+                    1.5 if is_true else 0.2, 0.3, -0.5, 5.0
+                ),
+            },
+            artifacts_available=["validation_data"],
+        )
+
+    def _design_followup(
+        self, action: ExperimentAction, s: FullLatentState, idx: int
+    ) -> IntermediateOutput:
+        return IntermediateOutput(
+            output_type=OutputType.FOLLOWUP_DESIGN,
+            step_index=idx,
+            summary="Follow-up experiment design proposed",
+            data={"proposal": action.parameters},
+            artifacts_available=["followup_proposal"],
+        )
+
+    def _subagent_review(
+        self, action: ExperimentAction, s: FullLatentState, idx: int
+    ) -> IntermediateOutput:
+        return IntermediateOutput(
+            output_type=OutputType.SUBAGENT_REPORT,
+            step_index=idx,
+            summary=f"Subagent review ({action.invoked_subagent or 'general'})",
+            data={"subagent": action.invoked_subagent, "notes": "Review complete."},
+            artifacts_available=["subagent_report"],
+        )
+
+    def _synthesize_conclusion(
+        self, action: ExperimentAction, s: FullLatentState, idx: int
+    ) -> IntermediateOutput:
+        return IntermediateOutput(
+            output_type=OutputType.CONCLUSION,
+            step_index=idx,
+            summary="Conclusion synthesised from pipeline evidence",
+            data={"claims": action.parameters.get("claims", [])},
+            artifacts_available=["conclusion_report"],
+        )
+
+    def _default(
+        self, action: ExperimentAction, s: FullLatentState, idx: int
+    ) -> IntermediateOutput:
+        return IntermediateOutput(
+            output_type=OutputType.FAILURE_REPORT,
+            step_index=idx,
+            success=False,
+            summary=f"Unhandled action type: {action.action_type}",
+            data={},
+        )
+
+    # ── helpers ─────────────────────────────────────────────────────────
+
+    def _random_partition(self, total: int, k: int) -> List[int]:
+        if k <= 0:
+            return []
+        fracs = self.noise.rng.dirichlet(alpha=[1.0] * k)
+        sizes = [max(1, int(total * f)) for f in fracs]
+        diff = total - sum(sizes)
+        sizes[0] += diff
+        return sizes
+
+
+_HANDLERS = {
+    ActionType.COLLECT_SAMPLE: OutputGenerator._collect_sample,
+    ActionType.SELECT_COHORT: OutputGenerator._select_cohort,
+    ActionType.PREPARE_LIBRARY: OutputGenerator._prepare_library,
+    ActionType.CULTURE_CELLS: OutputGenerator._culture_cells,
+    ActionType.PERTURB_GENE: OutputGenerator._perturb,
+    ActionType.PERTURB_COMPOUND: OutputGenerator._perturb,
+    ActionType.SEQUENCE_CELLS: OutputGenerator._sequence_cells,
+    ActionType.RUN_QC: OutputGenerator._run_qc,
+    ActionType.FILTER_DATA: OutputGenerator._filter_data,
+    ActionType.NORMALIZE_DATA: OutputGenerator._normalize_data,
+    ActionType.INTEGRATE_BATCHES: OutputGenerator._integrate_batches,
+    ActionType.CLUSTER_CELLS: OutputGenerator._cluster_cells,
+    ActionType.DIFFERENTIAL_EXPRESSION: OutputGenerator._differential_expression,
+    ActionType.TRAJECTORY_ANALYSIS: OutputGenerator._trajectory_analysis,
+    ActionType.PATHWAY_ENRICHMENT: OutputGenerator._pathway_enrichment,
+    ActionType.REGULATORY_NETWORK_INFERENCE: OutputGenerator._regulatory_network,
+    ActionType.MARKER_SELECTION: OutputGenerator._marker_selection,
+    ActionType.VALIDATE_MARKER: OutputGenerator._validate_marker,
+    ActionType.DESIGN_FOLLOWUP: OutputGenerator._design_followup,
+    ActionType.REQUEST_SUBAGENT_REVIEW: OutputGenerator._subagent_review,
+    ActionType.SYNTHESIZE_CONCLUSION: OutputGenerator._synthesize_conclusion,
+}
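The module-level `_HANDLERS` table maps each `ActionType` to an *unbound* method, so `generate` must pass `self` explicitly and the fallback must likewise be unbound. A minimal standalone sketch of this dispatch pattern, with toy names (`Kind`, `Generator`) not taken from the diff:

```python
from enum import Enum

class Kind(Enum):
    A = "a"
    B = "b"

class Generator:
    def generate(self, kind):
        # Look up the unbound method; fall back to the unbound default so the
        # call signature is identical either way.
        handler = _HANDLERS.get(kind, Generator._default)
        return handler(self)

    def _handle_a(self):
        return "handled A"

    def _default(self):
        return "unhandled"

# Built after the class body, once the methods exist as attributes.
_HANDLERS = {Kind.A: Generator._handle_a}
```

Using a bound default such as `self._default` here would add an extra implicit `self` and raise a `TypeError` for unmapped kinds, which is why the unbound form matters.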
server/simulator/transition.py ADDED
@@ -0,0 +1,216 @@
+"""Transition dynamics engine — the heart of the biological simulator.
+
+Orchestrates latent-state updates, output generation, resource accounting,
+and constraint propagation for every agent action.
+"""
+
+from __future__ import annotations
+
+from copy import deepcopy
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Tuple
+
+from models import (
+    ActionType,
+    ExperimentAction,
+    IntermediateOutput,
+    OutputType,
+)
+
+from .latent_state import FullLatentState
+from .noise import NoiseModel
+from .output_generator import OutputGenerator
+
+
+ACTION_COSTS: Dict[ActionType, Tuple[float, float]] = {
+    ActionType.COLLECT_SAMPLE: (5_000, 7.0),
+    ActionType.SELECT_COHORT: (500, 1.0),
+    ActionType.PREPARE_LIBRARY: (8_000, 3.0),
+    ActionType.CULTURE_CELLS: (3_000, 14.0),
+    ActionType.PERTURB_GENE: (2_000, 3.0),
+    ActionType.PERTURB_COMPOUND: (1_000, 2.0),
+    ActionType.SEQUENCE_CELLS: (15_000, 5.0),
+    ActionType.RUN_QC: (100, 0.5),
+    ActionType.FILTER_DATA: (50, 0.25),
+    ActionType.NORMALIZE_DATA: (50, 0.25),
+    ActionType.INTEGRATE_BATCHES: (100, 0.5),
+    ActionType.CLUSTER_CELLS: (100, 0.5),
+    ActionType.DIFFERENTIAL_EXPRESSION: (100, 0.5),
+    ActionType.TRAJECTORY_ANALYSIS: (200, 1.0),
+    ActionType.PATHWAY_ENRICHMENT: (100, 0.5),
+    ActionType.REGULATORY_NETWORK_INFERENCE: (300, 1.0),
+    ActionType.MARKER_SELECTION: (100, 0.5),
+    ActionType.VALIDATE_MARKER: (5_000, 14.0),
+    ActionType.DESIGN_FOLLOWUP: (0, 0.5),
+    ActionType.REQUEST_SUBAGENT_REVIEW: (0, 0.25),
+    ActionType.SYNTHESIZE_CONCLUSION: (0, 0.5),
+}
+
+
+@dataclass
+class TransitionResult:
+    """Bundle returned by the transition engine after one step."""
+
+    next_state: FullLatentState
+    output: IntermediateOutput
+    reward_components: Dict[str, float] = field(default_factory=dict)
+    hard_violations: List[str] = field(default_factory=list)
+    soft_violations: List[str] = field(default_factory=list)
+    done: bool = False
+
+
+class TransitionEngine:
+    """Applies one action to the latent state, producing the next state
+    and a simulated intermediate output.
+
+    The engine delegates output generation to ``OutputGenerator`` and
+    constraint checking to external rule engines (injected at call time).
+    """
+
+    def __init__(self, noise: NoiseModel):
+        self.noise = noise
+        self.output_gen = OutputGenerator(noise)
+
+    def step(
+        self,
+        state: FullLatentState,
+        action: ExperimentAction,
+        *,
+        hard_violations: Optional[List[str]] = None,
+        soft_violations: Optional[List[str]] = None,
+    ) -> TransitionResult:
+        s = deepcopy(state)
+        s.step_count += 1
+        step_idx = s.step_count
+
+        hard_v = hard_violations or []
+        soft_v = soft_violations or []
+
+        if hard_v:
+            output = IntermediateOutput(
+                output_type=OutputType.FAILURE_REPORT,
+                step_index=step_idx,
+                success=False,
+                summary=f"Action blocked: {'; '.join(hard_v)}",
+            )
+            return TransitionResult(
+                next_state=s,
+                output=output,
+                hard_violations=hard_v,
+                soft_violations=soft_v,
+            )
+
+        self._apply_resource_cost(s, action)
+
+        if s.resources.budget_exhausted or s.resources.time_exhausted:
+            output = IntermediateOutput(
+                output_type=OutputType.FAILURE_REPORT,
+                step_index=step_idx,
+                success=False,
+                summary="Resources exhausted",
+            )
+            return TransitionResult(
+                next_state=s, output=output, done=True,
+                hard_violations=["resources_exhausted"],
+            )
+
+        self._update_progress(s, action)
+
+        output = self.output_gen.generate(action, s, step_idx)
+
+        if soft_v:
+            output.quality_score *= 0.5
+            output.warnings.extend(soft_v)
+
+        self._propagate_artifacts(s, action, output)
+
+        done = action.action_type == ActionType.SYNTHESIZE_CONCLUSION
+
+        return TransitionResult(
+            next_state=s,
+            output=output,
+            soft_violations=soft_v,
+            done=done,
+        )
+
+    # ── internals ───────────────────────────────────────────────────────
+
+    def _apply_resource_cost(
+        self, s: FullLatentState, action: ExperimentAction
+    ) -> None:
+        budget_cost, time_cost = ACTION_COSTS.get(
+            action.action_type, (0.0, 0.0)
+        )
+        s.resources.budget_used += budget_cost
+        s.resources.time_used_days += time_cost
+        if action.action_type in {
+            ActionType.RUN_QC, ActionType.FILTER_DATA,
+            ActionType.NORMALIZE_DATA, ActionType.INTEGRATE_BATCHES,
+            ActionType.CLUSTER_CELLS, ActionType.DIFFERENTIAL_EXPRESSION,
+            ActionType.TRAJECTORY_ANALYSIS, ActionType.PATHWAY_ENRICHMENT,
+            ActionType.REGULATORY_NETWORK_INFERENCE, ActionType.MARKER_SELECTION,
+        }:
+            s.resources.compute_hours_used += time_cost * 8
+
+    def _update_progress(
+        self, s: FullLatentState, action: ExperimentAction
+    ) -> None:
+        at = action.action_type
+        p = s.progress
+        _MAP = {
+            ActionType.COLLECT_SAMPLE: "samples_collected",
+            ActionType.SELECT_COHORT: "cohort_selected",
+            ActionType.PREPARE_LIBRARY: "library_prepared",
+            ActionType.CULTURE_CELLS: "cells_cultured",
+            ActionType.PERTURB_GENE: "perturbation_applied",
+            ActionType.PERTURB_COMPOUND: "perturbation_applied",
+            ActionType.SEQUENCE_CELLS: "cells_sequenced",
+            ActionType.RUN_QC: "qc_performed",
+            ActionType.FILTER_DATA: "data_filtered",
+            ActionType.NORMALIZE_DATA: "data_normalized",
+            ActionType.INTEGRATE_BATCHES: "batches_integrated",
+            ActionType.CLUSTER_CELLS: "cells_clustered",
+            ActionType.DIFFERENTIAL_EXPRESSION: "de_performed",
+            ActionType.TRAJECTORY_ANALYSIS: "trajectories_inferred",
+            ActionType.PATHWAY_ENRICHMENT: "pathways_analyzed",
+            ActionType.REGULATORY_NETWORK_INFERENCE: "networks_inferred",
+            ActionType.MARKER_SELECTION: "markers_discovered",
+            ActionType.VALIDATE_MARKER: "markers_validated",
+            ActionType.SYNTHESIZE_CONCLUSION: "conclusion_reached",
+        }
+        flag = _MAP.get(at)
+        if flag:
+            setattr(p, flag, True)
+
+        if at == ActionType.COLLECT_SAMPLE:
+            n = action.parameters.get("n_samples", 6)
+            s.resources.samples_available += n
+
+        if at == ActionType.SEQUENCE_CELLS:
+            s.resources.sequencing_lanes_used += 1
+
+        if at == ActionType.FILTER_DATA:
+            retain = self.noise.sample_qc_metric(0.85, 0.05, 0.5, 1.0)
+            p.n_cells_after_filter = max(
+                100, int(s.biology.n_true_cells * retain)
+            )
+
+        if at == ActionType.CLUSTER_CELLS:
+            n_true = len(s.biology.cell_populations) or 5
+            p.n_clusters_found = self.noise.sample_cluster_count(n_true, 0.8)
+
+    def _propagate_artifacts(
+        self,
+        s: FullLatentState,
+        action: ExperimentAction,
+        output: IntermediateOutput,
+    ) -> None:
+        if action.action_type == ActionType.DIFFERENTIAL_EXPRESSION:
+            top = output.data.get("top_genes", [])
+            s.discovered_de_genes = [g["gene"] for g in top[:20]]
+
+        if action.action_type == ActionType.CLUSTER_CELLS:
+            s.discovered_clusters = output.data.get("cluster_names", [])
+
+        if action.action_type == ActionType.MARKER_SELECTION:
+            s.progress.n_markers_found = output.data.get("n_candidates", 0)
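Each action debits a `(budget, days)` tuple from `ACTION_COSTS`, with unknown actions treated as free via the `dict.get` fallback. A toy sketch of that accounting pattern, with hypothetical names (`COSTS`, `Resources`) rather than the classes from the diff:

```python
from dataclasses import dataclass

# Hypothetical per-action costs: (budget units, days).
COSTS = {"sequence_cells": (15_000, 5.0), "run_qc": (100, 0.5)}

@dataclass
class Resources:
    budget_total: float
    time_limit_days: float
    budget_used: float = 0.0
    time_used_days: float = 0.0

    def apply(self, action: str) -> None:
        # Unknown actions cost nothing, mirroring ACTION_COSTS.get(..., (0, 0)).
        budget, days = COSTS.get(action, (0.0, 0.0))
        self.budget_used += budget
        self.time_used_days += days

    @property
    def exhausted(self) -> bool:
        return (self.budget_used > self.budget_total
                or self.time_used_days > self.time_limit_days)
```

Checking exhaustion *after* debiting, as `step` does, means the action that overruns the budget is the one that fails.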
server/subagents/__init__.py ADDED
File without changes
server/tasks/__init__.py ADDED
@@ -0,0 +1,4 @@
+from .generator import TaskGenerator
+from .scenarios import SCENARIO_LIBRARY, Scenario
+
+__all__ = ["SCENARIO_LIBRARY", "Scenario", "TaskGenerator"]
server/tasks/generator.py ADDED
@@ -0,0 +1,129 @@
1
+ """Task generator β€” produces (TaskSpec, FullLatentState) pairs for episodes.
2
+
3
+ Supports three modes:
4
+ 1. Select from the pre-defined scenario library.
5
+ 2. Randomly perturb a scenario for domain-randomisation.
6
+ 3. Compose a fully procedural scenario (tissue Γ— modality Γ— difficulty).
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from typing import List, Optional, Tuple
12
+
13
+ import numpy as np
14
+
15
+ from models import TaskSpec
16
+
17
+ from server.simulator.latent_state import (
18
+ CellPopulation,
19
+ ExperimentProgress,
20
+ FullLatentState,
21
+ GeneProgram,
22
+ LatentBiologicalState,
23
+ ResourceState,
24
+ TechnicalState,
25
+ )
26
+ from .scenarios import SCENARIO_LIBRARY, Scenario
27
+
28
+
29
+ class TaskGenerator:
30
+ """Generates task + latent-state pairs for environment episodes."""
31
+
32
+ def __init__(
33
+ self,
34
+ scenarios: Optional[List[Scenario]] = None,
35
+ domain_randomise: bool = True,
36
+ ):
37
+ self.scenarios = scenarios or SCENARIO_LIBRARY
38
+ self.domain_randomise = domain_randomise
39
+
40
+ def generate(
41
+ self,
42
+ *,
43
+ seed: Optional[int] = None,
44
+ scenario_name: Optional[str] = None,
45
+ ) -> Tuple[TaskSpec, FullLatentState]:
46
+ rng = np.random.default_rng(seed)
47
+
48
+ if scenario_name:
49
+ scenario = self._find_scenario(scenario_name)
50
+ else:
51
+ idx = int(rng.integers(0, len(self.scenarios)))
52
+ scenario = self.scenarios[idx]
53
+
54
+ task = scenario.task.model_copy(deep=True)
55
+ biology = scenario.biology.model_copy(deep=True)
56
+ technical = scenario.technical.model_copy(deep=True)
57
+
58
+ if self.domain_randomise:
59
+ self._randomise(rng, task, biology, technical)
60
+
61
+ latent = FullLatentState(
62
+ biology=biology,
63
+ technical=technical,
64
+ progress=ExperimentProgress(),
65
+ resources=ResourceState(
66
+                 budget_total=task.budget_limit,
+                 time_limit_days=task.time_limit_days,
+             ),
+             hidden_failure_conditions=list(scenario.hidden_failure_conditions),
+             rng_seed=seed or 0,
+         )
+         return task, latent
+
+     def list_scenarios(self) -> List[str]:
+         return [s.name for s in self.scenarios]
+
+     # ── internals ───────────────────────────────────────────────────────
+
+     def _find_scenario(self, name: str) -> Scenario:
+         for s in self.scenarios:
+             if s.name == name:
+                 return s
+         available = ", ".join(self.list_scenarios())
+         raise ValueError(f"Unknown scenario '{name}'. Available: {available}")
+
+     def _randomise(
+         self,
+         rng: np.random.Generator,
+         task: TaskSpec,
+         bio: LatentBiologicalState,
+         tech: TechnicalState,
+     ) -> None:
+         budget_scale = float(rng.uniform(0.7, 1.3))
+         task.budget_limit *= budget_scale
+         task.time_limit_days *= float(rng.uniform(0.8, 1.2))
+
+         tech.dropout_rate = float(np.clip(
+             tech.dropout_rate + rng.normal(0, 0.02), 0.01, 0.3
+         ))
+         tech.doublet_rate = float(np.clip(
+             tech.doublet_rate + rng.normal(0, 0.01), 0.01, 0.15
+         ))
+         tech.sample_quality = float(np.clip(
+             tech.sample_quality + rng.normal(0, 0.05), 0.5, 1.0
+         ))
+         tech.ambient_rna_fraction = float(np.clip(
+             tech.ambient_rna_fraction + rng.normal(0, 0.01), 0.01, 0.15
+         ))
+         for batch_id in list(tech.batch_effects.keys()):
+             tech.batch_effects[batch_id] = float(np.clip(
+                 tech.batch_effects[batch_id] + rng.normal(0, 0.03), 0.0, 0.4
+             ))
+
+         for pop in bio.cell_populations:
+             pop.proportion = float(np.clip(
+                 pop.proportion * rng.uniform(0.8, 1.2), 0.01, 0.8
+             ))
+         total = sum(p.proportion for p in bio.cell_populations) or 1.0
+         for pop in bio.cell_populations:
+             pop.proportion /= total
+
+         for comparison, effects in bio.true_de_genes.items():
+             for gene in list(effects.keys()):
+                 effects[gene] *= float(rng.uniform(0.8, 1.2))
+
+         bio.n_true_cells = max(
+             1000,
+             int(bio.n_true_cells * rng.uniform(0.6, 1.4)),
+         )
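The `_randomise` helper above jitters each cell population's proportion, clips it, and then renormalises so the proportions sum to one again. A minimal standalone sketch of that jitter-clip-renormalise step, using the stdlib `random` module instead of NumPy (the function name and default bounds here are illustrative; the clip bounds mirror the 0.01/0.8 and 0.8/1.2 values in the diff):

```python
import random


def jitter_and_renormalise(proportions, seed=0, low=0.8, high=1.2):
    """Scale each proportion by a random factor in [low, high],
    clip to [0.01, 0.8], then renormalise so the result sums to 1."""
    rng = random.Random(seed)
    scaled = [min(max(p * rng.uniform(low, high), 0.01), 0.8) for p in proportions]
    total = sum(scaled) or 1.0  # guard against an all-zero sum, as in _randomise
    return [p / total for p in scaled]


props = jitter_and_renormalise([0.35, 0.25, 0.15, 0.10, 0.15], seed=42)
print(sum(props))  # sums to 1.0 (up to float rounding)
```

The renormalisation after clipping is what keeps the latent state valid: without it, independent per-population jitter would drift the composition away from a proper distribution.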
server/tasks/scenarios.py ADDED
@@ -0,0 +1,454 @@
+ """Pre-defined biological scenarios for task generation.
+
+ Each ``Scenario`` bundles a task specification together with the matching
+ hidden ground-truth biology so the simulator can instantiate consistent
+ episodes. The library is intentionally diverse: it covers differential
+ expression, trajectory inference, perturbation response, and biomarker
+ validation across tissues and modalities.
+ """
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass, field
+ from typing import Any, Dict, List, Optional
+
+ from models import ExpectedFinding, PaperReference, TaskSpec
+
+ from server.simulator.latent_state import (
+     CellPopulation,
+     GeneProgram,
+     LatentBiologicalState,
+     TechnicalState,
+ )
+
+
+ @dataclass
+ class Scenario:
+     """A reproducible (task, ground-truth) pair."""
+
+     name: str
+     task: TaskSpec
+     biology: LatentBiologicalState
+     technical: TechnicalState = field(default_factory=TechnicalState)
+     hidden_failure_conditions: List[str] = field(default_factory=list)
+     difficulty: str = "medium"
+     tags: List[str] = field(default_factory=list)
+
+
+ # ── Scenario library ────────────────────────────────────────────────────────
+
+ SCENARIO_LIBRARY: List[Scenario] = [
+     # ── 1. Cardiac disease DE ───────────────────────────────────────────
+     Scenario(
+         name="cardiac_disease_de",
+         difficulty="easy",
+         tags=["de", "scRNA-seq", "cardiac"],
+         task=TaskSpec(
+             problem_statement=(
+                 "Identify differentially expressed genes between diseased "
+                 "and healthy cardiomyocytes using single-cell RNA sequencing."
+             ),
+             modality="scRNA-seq",
+             organism="human",
+             tissue="heart",
+             conditions=["healthy", "dilated_cardiomyopathy"],
+             budget_limit=80_000.0,
+             time_limit_days=120.0,
+             success_criteria=[
+                 "Identify DE genes between conditions",
+                 "Validate at least one candidate marker",
+             ],
+         ),
+         biology=LatentBiologicalState(
+             cell_populations=[
+                 CellPopulation(
+                     name="cardiomyocyte",
+                     proportion=0.35,
+                     marker_genes=["TNNT2", "MYH7", "ACTC1"],
+                     state="contractile",
+                     condition_response={"dilated_cardiomyopathy": 0.8},
+                 ),
+                 CellPopulation(
+                     name="fibroblast",
+                     proportion=0.25,
+                     marker_genes=["COL1A1", "DCN", "LUM"],
+                     state="quiescent",
+                     condition_response={"dilated_cardiomyopathy": 1.3},
+                 ),
+                 CellPopulation(
+                     name="endothelial",
+                     proportion=0.15,
+                     marker_genes=["PECAM1", "VWF", "CDH5"],
+                     state="quiescent",
+                 ),
+                 CellPopulation(
+                     name="macrophage",
+                     proportion=0.10,
+                     marker_genes=["CD68", "CD163", "CSF1R"],
+                     state="activated",
+                     condition_response={"dilated_cardiomyopathy": 1.5},
+                 ),
+                 CellPopulation(
+                     name="smooth_muscle",
+                     proportion=0.15,
+                     marker_genes=["ACTA2", "MYH11", "TAGLN"],
+                     state="quiescent",
+                 ),
+             ],
+             true_de_genes={
+                 "disease_vs_healthy": {
+                     "NPPA": 2.5, "NPPB": 3.1, "MYH7": 1.8,
+                     "COL1A1": 1.6, "COL3A1": 1.4, "POSTN": 2.0,
+                     "CCL2": 1.2, "IL6": 0.9, "TGFB1": 1.1,
+                     "ANKRD1": 2.2, "XIRP2": -1.3, "MYL2": -0.8,
+                 },
+             },
+             true_pathways={
+                 "cardiac_muscle_contraction": 0.4,
+                 "extracellular_matrix_organisation": 0.85,
+                 "inflammatory_response": 0.7,
+                 "TGF_beta_signalling": 0.75,
+                 "apoptosis": 0.55,
+             },
+             true_markers=["NPPA", "NPPB", "POSTN", "COL1A1"],
+             causal_mechanisms=[
+                 "TGF-beta-driven fibrosis",
+                 "inflammatory macrophage infiltration",
+             ],
+             n_true_cells=12_000,
+         ),
+         technical=TechnicalState(
+             batch_effects={"batch_1": 0.15, "batch_2": 0.10},
+             doublet_rate=0.05,
+             dropout_rate=0.08,
+         ),
+     ),
+
+     # ── 2. Developmental trajectory ─────────────────────────────────────
+     Scenario(
+         name="hematopoiesis_trajectory",
+         difficulty="medium",
+         tags=["trajectory", "scRNA-seq", "hematopoiesis"],
+         task=TaskSpec(
+             problem_statement=(
+                 "Infer the developmental trajectory of hematopoietic "
+                 "stem cells differentiating into mature blood lineages."
+             ),
+             modality="scRNA-seq",
+             organism="human",
+             tissue="bone_marrow",
+             conditions=["steady_state"],
+             budget_limit=100_000.0,
+             time_limit_days=150.0,
+             success_criteria=[
+                 "Reconstruct branching lineage structure",
+                 "Identify key transcription factors driving fate decisions",
+             ],
+             paper_references=[
+                 PaperReference(
+                     title=(
+                         "Single-cell RNA-sequencing uncovers transcriptional "
+                         "states and fate decisions in haematopoiesis"
+                     ),
+                     citation="Nature Communications (2018)",
+                     doi="10.1038/s41467-017-02305-6",
+                     url=(
+                         "https://www.nature.com/articles/"
+                         "s41467-017-02305-6"
+                     ),
+                 ),
+             ],
+             expected_findings=[
+                 ExpectedFinding(
+                     finding=(
+                         "Trajectory analysis should recover branching blood "
+                         "lineages rooted in HSCs."
+                     ),
+                     category="trajectory",
+                     keywords=["HSC", "branching", "lineage", "trajectory"],
+                 ),
+                 ExpectedFinding(
+                     finding=(
+                         "GATA1 should appear as a driver of erythroid fate "
+                         "commitment."
+                     ),
+                     category="regulatory_network",
+                     keywords=["GATA1", "erythroid", "commitment"],
+                 ),
+                 ExpectedFinding(
+                     finding=(
+                         "CEBPA and SPI1 should support myeloid branch "
+                         "decisions."
+                     ),
+                     category="regulatory_network",
+                     keywords=["CEBPA", "SPI1", "myeloid", "branch"],
+                 ),
+             ],
+         ),
+         biology=LatentBiologicalState(
+             cell_populations=[
+                 CellPopulation(name="HSC", proportion=0.05,
+                                marker_genes=["CD34", "KIT", "THY1"],
+                                state="stem"),
+                 CellPopulation(name="CMP", proportion=0.10,
+                                marker_genes=["CD34", "FLT3"],
+                                state="progenitor"),
+                 CellPopulation(name="GMP", proportion=0.12,
+                                marker_genes=["CSF3R", "CEBPA"],
+                                state="progenitor"),
+                 CellPopulation(name="MEP", proportion=0.10,
+                                marker_genes=["GATA1", "KLF1"],
+                                state="progenitor"),
+                 CellPopulation(name="erythrocyte", proportion=0.20,
+                                marker_genes=["HBA1", "HBB", "GYPA"],
+                                state="mature"),
+                 CellPopulation(name="neutrophil", proportion=0.18,
+                                marker_genes=["ELANE", "MPO", "CTSG"],
+                                state="mature"),
+                 CellPopulation(name="monocyte", proportion=0.15,
+                                marker_genes=["CD14", "CSF1R", "FCGR3A"],
+                                state="mature"),
+                 CellPopulation(name="megakaryocyte", proportion=0.10,
+                                marker_genes=["ITGA2B", "GP1BA"],
+                                state="mature"),
+             ],
+             true_de_genes={},
+             true_pathways={
+                 "hematopoietic_cell_lineage": 0.9,
+                 "MAPK_signalling": 0.6,
+                 "JAK_STAT_signalling": 0.7,
+             },
+             true_trajectory={
+                 "root": "HSC",
+                 "n_lineages": 3,
+                 "branching": True,
+                 "branches": [
+                     ["HSC", "CMP", "GMP", "neutrophil"],
+                     ["HSC", "CMP", "GMP", "monocyte"],
+                     ["HSC", "MEP", "erythrocyte"],
+                     ["HSC", "MEP", "megakaryocyte"],
+                 ],
+             },
+             true_regulatory_network={
+                 "GATA1": ["KLF1", "HBB", "HBA1", "GYPA"],
+                 "CEBPA": ["CSF3R", "ELANE", "MPO"],
+                 "SPI1": ["CSF1R", "CD14", "FCGR3A"],
+                 "RUNX1": ["CD34", "KIT"],
+             },
+             true_markers=["GATA1", "CEBPA", "SPI1"],
+             causal_mechanisms=[
+                 "GATA1-driven erythroid commitment",
+                 "PU.1/CEBPA antagonism at myeloid branch point",
+             ],
+             n_true_cells=15_000,
+         ),
+         technical=TechnicalState(dropout_rate=0.12, doublet_rate=0.06),
+     ),
+
+     # ── 3. Perturbation response ────────────────────────────────────────
+     Scenario(
+         name="perturbation_immune",
+         difficulty="hard",
+         tags=["perturbation", "scRNA-seq", "immune"],
+         task=TaskSpec(
+             problem_statement=(
+                 "Determine the effect of JAK inhibitor treatment on "
+                 "T-cell activation states in rheumatoid arthritis."
+             ),
+             modality="scRNA-seq",
+             organism="human",
+             tissue="synovial_fluid",
+             conditions=["untreated_RA", "JAK_inhibitor_treated"],
+             budget_limit=120_000.0,
+             time_limit_days=180.0,
+             prior_observations=[
+                 "Elevated JAK-STAT signalling observed in prior bulk RNA-seq",
+             ],
+             success_criteria=[
+                 "Quantify shift in T-cell activation states",
+                 "Identify pathways modulated by JAK inhibitor",
+                 "Propose validation strategy",
+             ],
+         ),
+         biology=LatentBiologicalState(
+             cell_populations=[
+                 CellPopulation(name="CD4_Th1", proportion=0.20,
+                                marker_genes=["IFNG", "TBX21", "IL2"],
+                                state="activated",
+                                condition_response={"JAK_inhibitor_treated": 0.5}),
+                 CellPopulation(name="CD4_Th17", proportion=0.15,
+                                marker_genes=["IL17A", "RORC", "CCR6"],
+                                state="activated",
+                                condition_response={"JAK_inhibitor_treated": 0.6}),
+                 CellPopulation(name="CD4_Treg", proportion=0.08,
+                                marker_genes=["FOXP3", "IL2RA", "CTLA4"],
+                                state="regulatory",
+                                condition_response={"JAK_inhibitor_treated": 1.2}),
+                 CellPopulation(name="CD8_cytotoxic", proportion=0.18,
+                                marker_genes=["GZMB", "PRF1", "CD8A"],
+                                state="activated",
+                                condition_response={"JAK_inhibitor_treated": 0.7}),
+                 CellPopulation(name="macrophage", proportion=0.15,
+                                marker_genes=["CD68", "CD163", "MARCO"],
+                                state="inflammatory"),
+                 CellPopulation(name="fibroblast", proportion=0.14,
+                                marker_genes=["COL1A1", "FAP", "THY1"],
+                                state="activated"),
+                 CellPopulation(name="B_cell", proportion=0.10,
+                                marker_genes=["CD19", "MS4A1", "CD79A"],
+                                state="quiescent"),
+             ],
+             true_de_genes={
+                 "treated_vs_untreated": {
+                     "IFNG": -1.8, "TBX21": -1.2, "IL17A": -1.5,
+                     "RORC": -0.9, "JAK1": -0.3, "STAT1": -1.0,
+                     "STAT3": -0.8, "SOCS1": 1.5, "SOCS3": 1.3,
+                     "FOXP3": 0.6, "IL10": 0.7,
+                 },
+             },
+             true_pathways={
+                 "JAK_STAT_signalling": 0.3,
+                 "Th1_differentiation": 0.35,
+                 "Th17_differentiation": 0.4,
+                 "cytokine_signalling": 0.45,
+                 "regulatory_T_cell_function": 0.7,
+             },
+             perturbation_effects={
+                 "JAK_inhibitor": {
+                     "STAT1": -0.8, "STAT3": -0.7, "IFNG": -1.5,
+                     "IL17A": -1.3, "SOCS1": 1.2,
+                 },
+             },
+             true_markers=["STAT1", "SOCS1", "IFNG"],
+             causal_mechanisms=[
+                 "JAK-STAT pathway inhibition reduces Th1/Th17 activation",
+                 "Compensatory Treg expansion under JAK inhibition",
+             ],
+             n_true_cells=18_000,
+         ),
+         technical=TechnicalState(
+             batch_effects={"batch_ctrl": 0.12, "batch_treated": 0.18},
+             ambient_rna_fraction=0.07,
+             dropout_rate=0.10,
+         ),
+         hidden_failure_conditions=[
+             "High ambient RNA may confound DE in low-abundance transcripts",
+         ],
+     ),
+
+     # ── 4. Biomarker validation ─────────────────────────────────────────
+     Scenario(
+         name="biomarker_validation_lung",
+         difficulty="medium",
+         tags=["biomarker", "validation", "scRNA-seq", "lung"],
+         task=TaskSpec(
+             problem_statement=(
+                 "Design a follow-up validation experiment for candidate "
+                 "biomarker SPP1 in idiopathic pulmonary fibrosis (IPF)."
+             ),
+             modality="scRNA-seq",
+             organism="human",
+             tissue="lung",
+             conditions=["healthy", "IPF"],
+             budget_limit=90_000.0,
+             time_limit_days=150.0,
+             prior_observations=[
+                 "SPP1 identified as top DE gene in prior pilot study",
+                 "SPP1+ macrophages enriched in fibrotic regions",
+             ],
+             success_criteria=[
+                 "Validate SPP1 as a marker for pro-fibrotic macrophages",
+                 "Confirm spatial localisation in fibrotic tissue",
+             ],
+             paper_references=[
+                 PaperReference(
+                     title=(
+                         "Proliferating SPP1/MERTK-expressing macrophages in "
+                         "idiopathic pulmonary fibrosis"
+                     ),
+                     citation="European Respiratory Journal (2019)",
+                     doi="10.1183/13993003.02441-2018",
+                     pmid="31221805",
+                     url="https://pubmed.ncbi.nlm.nih.gov/31221805/",
+                 ),
+             ],
+             expected_findings=[
+                 ExpectedFinding(
+                     finding=(
+                         "SPP1-positive macrophages should be enriched in IPF "
+                         "fibrotic regions."
+                     ),
+                     category="marker",
+                     keywords=["SPP1", "macrophage", "IPF", "fibrotic"],
+                 ),
+                 ExpectedFinding(
+                     finding=(
+                         "MERTK should co-occur with the profibrotic macrophage "
+                         "state."
+                     ),
+                     category="marker",
+                     keywords=["MERTK", "macrophage", "SPP1"],
+                 ),
+                 ExpectedFinding(
+                     finding=(
+                         "Extracellular matrix organization should emerge as a "
+                         "top fibrotic program."
+                     ),
+                     category="pathway",
+                     keywords=["extracellular_matrix", "fibrosis", "pathway"],
+                 ),
+             ],
+             dataset_metadata={
+                 "literature_grounding": "single_cell_ipf_macrophages",
+             },
+         ),
+         biology=LatentBiologicalState(
+             cell_populations=[
+                 CellPopulation(name="alveolar_macrophage", proportion=0.18,
+                                marker_genes=["MARCO", "FABP4", "MCEMP1"],
+                                state="resident"),
+                 CellPopulation(name="SPP1_macrophage", proportion=0.12,
+                                marker_genes=["SPP1", "MERTK", "MMP9", "TREM2"],
+                                state="pro-fibrotic",
+                                condition_response={"IPF": 2.0}),
+                 CellPopulation(name="AT2", proportion=0.20,
+                                marker_genes=["SFTPC", "SFTPB", "ABCA3"],
+                                state="normal"),
+                 CellPopulation(name="fibroblast", proportion=0.22,
+                                marker_genes=["COL1A1", "COL3A1", "POSTN"],
+                                state="activated",
+                                condition_response={"IPF": 1.5}),
+                 CellPopulation(name="endothelial", proportion=0.13,
+                                marker_genes=["PECAM1", "CLDN5"],
+                                state="quiescent"),
+                 CellPopulation(name="T_cell", proportion=0.15,
+                                marker_genes=["CD3D", "CD3E", "IL7R"],
+                                state="quiescent"),
+             ],
+             true_de_genes={
+                 "IPF_vs_healthy": {
+                     "SPP1": 3.2, "MERTK": 1.4, "MMP9": 1.8, "TREM2": 1.5,
+                     "COL1A1": 2.1, "COL3A1": 1.9, "POSTN": 2.4,
+                     "SFTPC": -1.2, "AGER": -1.6,
+                 },
+             },
+             true_pathways={
+                 "extracellular_matrix_organisation": 0.9,
+                 "integrin_signalling": 0.75,
+                 "macrophage_activation": 0.8,
+                 "Wnt_signalling": 0.6,
+             },
+             true_markers=["SPP1", "MERTK", "POSTN", "MMP9"],
+             causal_mechanisms=[
+                 "SPP1+ macrophage-driven fibroblast activation",
+                 "Integrin-mediated SPP1 signalling in fibrosis",
+             ],
+             n_true_cells=14_000,
+         ),
+         technical=TechnicalState(
+             batch_effects={"batch_1": 0.10},
+             dropout_rate=0.09,
+             sample_quality=0.85,
+         ),
+     ),
+ ]
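`SCENARIO_LIBRARY` is a flat list, so the `_find_scenario` helper elsewhere in this commit resolves a name by linear scan and raises with the available names on a miss. A self-contained sketch of that lookup pattern, using a stripped-down stand-in dataclass (`MiniScenario` and its fields are illustrative, not the real classes):

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class MiniScenario:  # hypothetical stand-in for Scenario, for illustration only
    name: str
    difficulty: str = "medium"
    tags: List[str] = field(default_factory=list)


LIBRARY = [
    MiniScenario("cardiac_disease_de", "easy", ["de", "cardiac"]),
    MiniScenario("hematopoiesis_trajectory", "medium", ["trajectory"]),
]


def find_scenario(name: str) -> MiniScenario:
    """Linear scan by name; raise with the full roster on a miss."""
    for s in LIBRARY:
        if s.name == name:
            return s
    available = ", ".join(s.name for s in LIBRARY)
    raise ValueError(f"Unknown scenario '{name}'. Available: {available}")


print(find_scenario("cardiac_disease_de").difficulty)  # easy
```

Listing the valid names inside the `ValueError` keeps a typo in a scenario name self-diagnosing, which matters when names are passed in as strings from a benchmark config.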
tests/__init__.py ADDED
File without changes
tests/test_environment.py ADDED
@@ -0,0 +1,85 @@
+ """Integration tests for the full BioExperimentEnvironment."""
+
+ from models import ActionType, ExperimentAction
+ from server.hackathon_environment import BioExperimentEnvironment
+
+
+ class TestEnvironmentLifecycle:
+     def test_reset_returns_valid_observation(self):
+         env = BioExperimentEnvironment()
+         obs = env.reset()
+         assert obs.step_index == 0
+         assert obs.done is False
+         assert obs.task.problem_statement != ""
+
+     def test_step_increments_step_count(self):
+         env = BioExperimentEnvironment()
+         env.reset()
+         obs = env.step(ExperimentAction(action_type=ActionType.COLLECT_SAMPLE))
+         assert obs.step_index == 1
+         assert env.state.step_count == 1
+
+     def test_valid_pipeline_trajectory(self):
+         env = BioExperimentEnvironment()
+         env.reset()
+
+         actions = [
+             ExperimentAction(action_type=ActionType.COLLECT_SAMPLE,
+                              parameters={"n_samples": 6}),
+             ExperimentAction(action_type=ActionType.PREPARE_LIBRARY,
+                              method="10x_chromium"),
+             ExperimentAction(action_type=ActionType.SEQUENCE_CELLS),
+             ExperimentAction(action_type=ActionType.RUN_QC),
+             ExperimentAction(action_type=ActionType.FILTER_DATA),
+             ExperimentAction(action_type=ActionType.NORMALIZE_DATA),
+             ExperimentAction(action_type=ActionType.CLUSTER_CELLS),
+             ExperimentAction(action_type=ActionType.DIFFERENTIAL_EXPRESSION,
+                              parameters={"comparison": "disease_vs_healthy"}),
+         ]
+
+         for a in actions:
+             obs = env.step(a)
+             assert obs.latest_output is not None
+             assert obs.latest_output.success is True, (
+                 f"Step {a.action_type} failed: {obs.rule_violations}"
+             )
+
+         assert obs.step_index == len(actions)
+         assert obs.resource_usage.budget_used > 0
+
+     def test_premature_de_blocked(self):
+         env = BioExperimentEnvironment()
+         env.reset()
+         obs = env.step(ExperimentAction(
+             action_type=ActionType.DIFFERENTIAL_EXPRESSION,
+         ))
+         assert obs.latest_output is not None
+         assert obs.latest_output.success is False
+
+     def test_conclusion_ends_episode(self):
+         env = BioExperimentEnvironment()
+         env.reset()
+
+         quick_pipeline = [
+             ExperimentAction(action_type=ActionType.COLLECT_SAMPLE),
+             ExperimentAction(action_type=ActionType.PREPARE_LIBRARY),
+             ExperimentAction(action_type=ActionType.SEQUENCE_CELLS),
+             ExperimentAction(action_type=ActionType.RUN_QC),
+             ExperimentAction(action_type=ActionType.FILTER_DATA),
+             ExperimentAction(action_type=ActionType.NORMALIZE_DATA),
+             ExperimentAction(action_type=ActionType.CLUSTER_CELLS),
+             ExperimentAction(action_type=ActionType.DIFFERENTIAL_EXPRESSION,
+                              parameters={"comparison": "disease_vs_healthy"}),
+             ExperimentAction(
+                 action_type=ActionType.SYNTHESIZE_CONCLUSION,
+                 parameters={"claims": [
+                     {"claim": "Test conclusion", "confidence": 0.7,
+                      "claim_type": "correlational"},
+                 ]},
+             ),
+         ]
+         for a in quick_pipeline:
+             obs = env.step(a)
+
+         assert obs.done is True
+         assert obs.reward != 0.0
tests/test_literature_benchmark.py ADDED
@@ -0,0 +1,36 @@
+ """Tests for literature-grounded benchmark utilities."""
+
+ from training.literature_benchmark import (
+     run_paper_benchmark,
+     select_literature_scenario,
+ )
+
+
+ def test_select_literature_scenario_for_ipf_prompt():
+     scenario = select_literature_scenario(
+         "Validate SPP1-positive macrophage findings in idiopathic pulmonary fibrosis."
+     )
+     assert scenario.name == "biomarker_validation_lung"
+
+
+ def test_select_literature_scenario_for_trajectory_prompt():
+     scenario = select_literature_scenario(
+         "Recover branching hematopoietic lineages and branch point transcription factors."
+     )
+     assert scenario.name == "hematopoiesis_trajectory"
+
+
+ def test_run_paper_benchmark_matches_curated_findings():
+     result = run_paper_benchmark(
+         problem_statement=(
+             "Design a follow-up validation experiment for candidate biomarker "
+             "SPP1 in idiopathic pulmonary fibrosis."
+         ),
+         scenario_name="biomarker_validation_lung",
+         domain_randomise=False,
+     )
+
+     assert result.total_steps >= 1
+     assert result.matched_papers
+     assert result.match_ratio >= (2 / 3)
+     assert any("SPP1" in finding for finding in result.matched_findings)
tests/test_models.py ADDED
@@ -0,0 +1,88 @@
+ """Tests for POMDP schema models."""
+
+ import pytest
+ from models import (
+     ActionType,
+     ConclusionClaim,
+     ExpectedFinding,
+     ExperimentAction,
+     ExperimentObservation,
+     IntermediateOutput,
+     OutputType,
+     PaperReference,
+     PipelineStepRecord,
+     ResourceUsage,
+     TaskSpec,
+ )
+
+
+ def test_experiment_action_roundtrip():
+     a = ExperimentAction(
+         action_type=ActionType.COLLECT_SAMPLE,
+         input_targets=["prior_cohort"],
+         method="10x_chromium",
+         parameters={"n_samples": 6},
+         confidence=0.8,
+     )
+     d = a.model_dump()
+     assert d["action_type"] == "collect_sample"
+     assert d["confidence"] == 0.8
+     reconstructed = ExperimentAction(**d)
+     assert reconstructed.action_type == ActionType.COLLECT_SAMPLE
+
+
+ def test_experiment_observation_defaults():
+     obs = ExperimentObservation(done=False, reward=0.0)
+     assert obs.step_index == 0
+     assert obs.pipeline_history == []
+     assert obs.resource_usage.budget_remaining == 100_000.0
+
+
+ def test_intermediate_output_quality_bounds():
+     with pytest.raises(Exception):
+         IntermediateOutput(
+             output_type=OutputType.QC_METRICS,
+             step_index=1,
+             quality_score=1.5,
+         )
+
+
+ def test_task_spec_defaults():
+     t = TaskSpec()
+     assert "10x_chromium" in t.available_assays
+     assert t.budget_limit == 100_000.0
+     assert t.paper_references == []
+     assert t.expected_findings == []
+
+
+ def test_paper_reference_and_expected_finding_roundtrip():
+     task = TaskSpec(
+         paper_references=[
+             PaperReference(
+                 title="Example paper",
+                 doi="10.0000/example",
+             )
+         ],
+         expected_findings=[
+             ExpectedFinding(
+                 finding="Example marker is enriched",
+                 category="marker",
+                 keywords=["EXAMPLE"],
+             )
+         ],
+     )
+     dumped = task.model_dump()
+     assert dumped["paper_references"][0]["title"] == "Example paper"
+     assert dumped["expected_findings"][0]["category"] == "marker"
+
+
+ def test_conclusion_claim_serialization():
+     c = ConclusionClaim(
+         claim="NPPA is upregulated in disease",
+         evidence_steps=[3, 5],
+         confidence=0.85,
+         claim_type="correlational",
+     )
+     d = c.model_dump()
+     assert d["claim_type"] == "correlational"
+     assert d["confidence"] == 0.85
tests/test_rewards.py ADDED
@@ -0,0 +1,105 @@
+ """Tests for the decomposable reward function."""
+
+ from models import ActionType, ConclusionClaim, ExperimentAction, IntermediateOutput, OutputType
+ from server.rewards.reward import RewardComputer
+ from server.simulator.latent_state import (
+     ExperimentProgress,
+     FullLatentState,
+     LatentBiologicalState,
+     ResourceState,
+ )
+
+
+ def _states(
+     prev_flags: dict | None = None,
+     next_flags: dict | None = None,
+     budget_used: float = 0.0,
+ ):
+     prev = FullLatentState(
+         progress=ExperimentProgress(**(prev_flags or {})),
+         resources=ResourceState(budget_total=100_000, budget_used=budget_used),
+     )
+     nf = dict(prev_flags or {})
+     nf.update(next_flags or {})
+     nxt = FullLatentState(
+         progress=ExperimentProgress(**nf),
+         resources=ResourceState(budget_total=100_000, budget_used=budget_used + 5000),
+     )
+     return prev, nxt
+
+
+ class TestStepReward:
+     def test_valid_step_positive(self):
+         rc = RewardComputer()
+         prev, nxt = _states(
+             prev_flags={"samples_collected": True, "library_prepared": True},
+             next_flags={"cells_sequenced": True},
+         )
+         output = IntermediateOutput(
+             output_type=OutputType.SEQUENCING_RESULT,
+             step_index=1,
+             quality_score=0.85,
+             uncertainty=0.15,
+         )
+         rb = rc.step_reward(
+             ExperimentAction(action_type=ActionType.SEQUENCE_CELLS),
+             prev, nxt, output, [], [],
+         )
+         assert rb.total > 0
+
+     def test_hard_violation_negative(self):
+         rc = RewardComputer()
+         prev, nxt = _states()
+         output = IntermediateOutput(
+             output_type=OutputType.FAILURE_REPORT,
+             step_index=1,
+             success=False,
+         )
+         rb = rc.step_reward(
+             ExperimentAction(action_type=ActionType.SEQUENCE_CELLS),
+             prev, nxt, output, ["blocked"], [],
+         )
+         assert rb.total < 0
+
+
+ class TestTerminalReward:
+     def test_correct_conclusion_rewarded(self):
+         rc = RewardComputer()
+         state = FullLatentState(
+             biology=LatentBiologicalState(
+                 causal_mechanisms=["TGF-beta-driven fibrosis"],
+                 true_markers=["NPPA"],
+             ),
+             progress=ExperimentProgress(
+                 samples_collected=True, cells_sequenced=True,
+                 qc_performed=True, data_filtered=True,
+                 data_normalized=True, de_performed=True,
+                 conclusion_reached=True,
+             ),
+             resources=ResourceState(budget_total=100_000, budget_used=40_000),
+         )
+         claims = [
+             ConclusionClaim(
+                 claim="TGF-beta-driven fibrosis observed",
+                 confidence=0.9,
+                 claim_type="causal",
+             ),
+         ]
+         rb = rc.terminal_reward(state, claims, [])
+         assert rb.terminal > 0
+
+     def test_overconfident_wrong_claim_penalised(self):
+         rc = RewardComputer()
+         state = FullLatentState(
+             biology=LatentBiologicalState(causal_mechanisms=["real_mechanism"]),
+             progress=ExperimentProgress(conclusion_reached=True),
+         )
+         claims = [
+             ConclusionClaim(
+                 claim="completely_wrong_mechanism",
+                 confidence=0.95,
+                 claim_type="causal",
+             ),
+         ]
+         rb = rc.terminal_reward(state, claims, [])
+         assert rb.components.get("overconfidence_penalty", 0) < 0
tests/test_rules.py ADDED
@@ -0,0 +1,79 @@
+ """Tests for the biological rule engine."""
+
+ from models import ActionType, ExperimentAction
+ from server.rules.engine import RuleEngine, Severity
+ from server.simulator.latent_state import (
+     ExperimentProgress,
+     FullLatentState,
+     ResourceState,
+ )
+
+
+ def _state(**progress_flags) -> FullLatentState:
+     return FullLatentState(
+         progress=ExperimentProgress(**progress_flags),
+         resources=ResourceState(budget_total=100_000, time_limit_days=180),
+     )
+
+
+ class TestPrerequisites:
+     def test_sequence_without_library_blocked(self):
+         engine = RuleEngine()
+         violations = engine.check(
+             ExperimentAction(action_type=ActionType.SEQUENCE_CELLS),
+             _state(samples_collected=True),
+         )
+         hard = engine.hard_violations(violations)
+         assert any("library" in m.lower() for m in hard)
+
+     def test_sequence_with_library_allowed(self):
+         engine = RuleEngine()
+         violations = engine.check(
+             ExperimentAction(action_type=ActionType.SEQUENCE_CELLS),
+             _state(samples_collected=True, library_prepared=True),
+         )
+         hard = engine.hard_violations(violations)
+         assert not hard
+
+     def test_de_without_normalization_blocked(self):
+         engine = RuleEngine()
+         violations = engine.check(
+             ExperimentAction(action_type=ActionType.DIFFERENTIAL_EXPRESSION),
+             _state(cells_sequenced=True, qc_performed=True, data_filtered=True),
+         )
+         hard = engine.hard_violations(violations)
+         assert any("normalis" in m.lower() or "normaliz" in m.lower() for m in hard)
+
+     def test_validate_marker_without_discovery_blocked(self):
+         engine = RuleEngine()
+         violations = engine.check(
+             ExperimentAction(action_type=ActionType.VALIDATE_MARKER),
+             _state(de_performed=True),
+         )
+         hard = engine.hard_violations(violations)
+         assert any("marker" in m.lower() for m in hard)
+
+
+ class TestRedundancy:
+     def test_double_qc_is_soft(self):
+         engine = RuleEngine()
+         violations = engine.check(
+             ExperimentAction(action_type=ActionType.RUN_QC),
+             _state(cells_sequenced=True, qc_performed=True),
+         )
+         hard = engine.hard_violations(violations)
+         soft = engine.soft_violations(violations)
+         assert not hard
+         assert any("redundant" in m.lower() for m in soft)
+
+
+ class TestResourceConstraints:
+     def test_exhausted_budget_blocked(self):
+         s = _state()
+         s.resources.budget_used = 100_000
+         engine = RuleEngine()
+         violations = engine.check(
+             ExperimentAction(action_type=ActionType.COLLECT_SAMPLE), s,
+         )
+         hard = engine.hard_violations(violations)
+         assert any("budget" in m.lower() for m in hard)
tests/test_simulator.py ADDED
@@ -0,0 +1,121 @@
+ """Tests for the latent-state simulator modules."""
2
+
3
+ import pytest
4
+
5
+ from models import ActionType, ExperimentAction, OutputType
6
+ from server.simulator.latent_state import (
7
+ CellPopulation,
8
+ ExperimentProgress,
9
+ FullLatentState,
10
+ LatentBiologicalState,
11
+ ResourceState,
12
+ TechnicalState,
13
+ )
14
+from server.simulator.noise import NoiseModel
+from server.simulator.output_generator import OutputGenerator
+from server.simulator.transition import TransitionEngine
+
+
+def _make_state() -> FullLatentState:
+    return FullLatentState(
+        biology=LatentBiologicalState(
+            cell_populations=[
+                CellPopulation(name="A", proportion=0.6, marker_genes=["G1"]),
+                CellPopulation(name="B", proportion=0.4, marker_genes=["G2"]),
+            ],
+            true_de_genes={"disease_vs_healthy": {"G1": 2.0, "G2": -1.5}},
+            true_pathways={"apoptosis": 0.7},
+            true_markers=["G1"],
+            causal_mechanisms=["G1-driven apoptosis"],
+            n_true_cells=5000,
+        ),
+        technical=TechnicalState(dropout_rate=0.1, doublet_rate=0.04),
+        progress=ExperimentProgress(),
+        resources=ResourceState(budget_total=50_000, time_limit_days=90),
+    )
+
+
+class TestNoiseModel:
+    def test_deterministic_with_seed(self):
+        n1 = NoiseModel(seed=42)
+        n2 = NoiseModel(seed=42)
+        assert n1.sample_qc_metric(0.5, 0.1) == n2.sample_qc_metric(0.5, 0.1)
+
+    def test_false_positives(self):
+        n = NoiseModel(seed=0)
+        fps = n.generate_false_positives(1000, 0.01)
+        assert all(g.startswith("FP_GENE_") for g in fps)
+
+    def test_quality_degradation_bounded(self):
+        n = NoiseModel(seed=0)
+        for _ in range(100):
+            q = n.quality_degradation(0.9, [0.8, 0.7])
+            assert 0.0 <= q <= 1.0
+
+
+class TestOutputGenerator:
+    def test_collect_sample(self):
+        noise = NoiseModel(seed=1)
+        gen = OutputGenerator(noise)
+        s = _make_state()
+        action = ExperimentAction(
+            action_type=ActionType.COLLECT_SAMPLE,
+            parameters={"n_samples": 4},
+        )
+        out = gen.generate(action, s, 1)
+        assert out.output_type == OutputType.SAMPLE_COLLECTION_RESULT
+        assert out.data["n_samples"] == 4
+
+    def test_de_includes_true_genes(self):
+        noise = NoiseModel(seed=42)
+        gen = OutputGenerator(noise)
+        s = _make_state()
+        s.progress.data_normalized = True
+        action = ExperimentAction(
+            action_type=ActionType.DIFFERENTIAL_EXPRESSION,
+            parameters={"comparison": "disease_vs_healthy"},
+        )
+        out = gen.generate(action, s, 5)
+        assert out.output_type == OutputType.DE_RESULT
+        gene_names = [g["gene"] for g in out.data["top_genes"]]
+        assert "G1" in gene_names or "G2" in gene_names
+
+
+class TestTransitionEngine:
+    def test_progress_flags_set(self):
+        noise = NoiseModel(seed=0)
+        engine = TransitionEngine(noise)
+        s = _make_state()
+        action = ExperimentAction(action_type=ActionType.COLLECT_SAMPLE)
+        result = engine.step(s, action)
+        assert result.next_state.progress.samples_collected is True
+
+    def test_hard_violation_blocks(self):
+        noise = NoiseModel(seed=0)
+        engine = TransitionEngine(noise)
+        s = _make_state()
+        result = engine.step(
+            s,
+            ExperimentAction(action_type=ActionType.COLLECT_SAMPLE),
+            hard_violations=["test_block"],
+        )
+        assert result.output.success is False
+        assert result.output.output_type == OutputType.FAILURE_REPORT
+
+    def test_resource_deduction(self):
+        noise = NoiseModel(seed=0)
+        engine = TransitionEngine(noise)
+        s = _make_state()
+        action = ExperimentAction(action_type=ActionType.SEQUENCE_CELLS)
+        s.progress.library_prepared = True
+        result = engine.step(s, action)
+        assert result.next_state.resources.budget_used == 15_000
+
+    def test_conclusion_ends_episode(self):
+        noise = NoiseModel(seed=0)
+        engine = TransitionEngine(noise)
+        s = _make_state()
+        s.progress.de_performed = True
+        action = ExperimentAction(action_type=ActionType.SYNTHESIZE_CONCLUSION)
+        result = engine.step(s, action)
+        assert result.done is True
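The `test_deterministic_with_seed` cases above rely on each `NoiseModel` owning a private RNG seeded at construction, so two instances built with the same seed replay identical draws. A minimal stdlib sketch of that pattern (the `SeededNoise` class here is illustrative, not part of the repo):

```python
import random


class SeededNoise:
    """Toy stand-in for a seeded noise model: one private RNG per instance."""

    def __init__(self, seed: int):
        # A private random.Random avoids touching the global RNG state,
        # so two instances with the same seed stay in lockstep.
        self._rng = random.Random(seed)

    def sample_qc_metric(self, mean: float, sd: float) -> float:
        return self._rng.gauss(mean, sd)


n1, n2 = SeededNoise(seed=42), SeededNoise(seed=42)
draws1 = [n1.sample_qc_metric(0.5, 0.1) for _ in range(5)]
draws2 = [n2.sample_qc_metric(0.5, 0.1) for _ in range(5)]
print(draws1 == draws2)  # → True
```

Keeping the RNG instance-local (rather than calling `random.seed`) is what makes the determinism tests safe to run in parallel.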
training/__init__.py ADDED
@@ -0,0 +1,34 @@
+from .evaluation import EvaluationSuite
+from .gym_wrapper import BioExperimentGymEnv
+from .trajectory import Trajectory, TrajectoryDataset
+
+__all__ = [
+    "BioExperimentGymEnv",
+    "EvaluationSuite",
+    "PaperBenchmarkResult",
+    "Trajectory",
+    "TrajectoryDataset",
+    "run_paper_benchmark",
+    "select_literature_scenario",
+]
+
+
+def __getattr__(name: str):
+    if name in {
+        "PaperBenchmarkResult",
+        "run_paper_benchmark",
+        "select_literature_scenario",
+    }:
+        from .literature_benchmark import (
+            PaperBenchmarkResult,
+            run_paper_benchmark,
+            select_literature_scenario,
+        )
+
+        exports = {
+            "PaperBenchmarkResult": PaperBenchmarkResult,
+            "run_paper_benchmark": run_paper_benchmark,
+            "select_literature_scenario": select_literature_scenario,
+        }
+        return exports[name]
+    raise AttributeError(f"module 'training' has no attribute {name!r}")
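The module-level `__getattr__` in `training/__init__.py` uses the PEP 562 lazy-export pattern: the literature-benchmark names are listed in `__all__`, but the submodule is only imported on first access. A self-contained sketch of the same mechanism, using a throwaway module built at runtime (module and attribute names here are illustrative):

```python
import types

# Throwaway module standing in for a package __init__; PEP 562 lets a
# module-level __getattr__ resolve names that are not yet in the module dict.
mod = types.ModuleType("demo_pkg")
mod.eager = "loaded at import time"


def _lazy_getattr(name):
    if name == "heavy_export":
        # In the real package this is where the costly submodule import runs.
        return "loaded on first access"
    raise AttributeError(f"module 'demo_pkg' has no attribute {name!r}")


# Attribute lookups that miss the module dict fall back to this hook.
mod.__getattr__ = _lazy_getattr

print(mod.eager)         # → loaded at import time
print(mod.heavy_export)  # → loaded on first access
```

The payoff is that `import training` stays cheap even when `literature_benchmark` pulls in heavy scientific dependencies.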
training/evaluation.py ADDED
@@ -0,0 +1,160 @@
+"""Evaluation suite for the bio-experiment planning environment.
+
+Separates metrics into four families:
+- online RL metrics (collected during training rollouts)
+- offline benchmark metrics (computed on a fixed held-out set)
+- expert review metrics (for human-in-the-loop evaluation)
+- simulator fidelity metrics (how well the simulator matches reality)
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+
+from .trajectory import Trajectory, TrajectoryDataset
+
+
+@dataclass
+class MetricResult:
+    name: str
+    value: float
+    details: Dict[str, Any] = field(default_factory=dict)
+
+
+class EvaluationSuite:
+    """Computes and aggregates evaluation metrics over trajectory datasets."""
+
+    # ── online RL metrics ───────────────────────────────────────────────
+
+    @staticmethod
+    def online_metrics(trajectories: List[Trajectory]) -> List[MetricResult]:
+        if not trajectories:
+            return []
+
+        rewards = [t.total_reward for t in trajectories]
+        lengths = [len(t.steps) for t in trajectories]
+        successes = [t.success for t in trajectories]
+
+        return [
+            MetricResult("mean_return", float(np.mean(rewards))),
+            MetricResult("median_return", float(np.median(rewards))),
+            MetricResult("std_return", float(np.std(rewards))),
+            MetricResult("mean_episode_length", float(np.mean(lengths))),
+            MetricResult("success_rate", float(np.mean(successes))),
+        ]
+
+    # ── offline benchmark metrics ───────────────────────────────────────
+
+    @staticmethod
+    def benchmark_metrics(dataset: TrajectoryDataset) -> List[MetricResult]:
+        results: List[MetricResult] = []
+        if len(dataset) == 0:
+            return results
+
+        results.append(MetricResult(
+            "pipeline_validity_rate",
+            EvaluationSuite._pipeline_validity_rate(dataset),
+        ))
+        results.append(MetricResult(
+            "ordering_score",
+            EvaluationSuite._ordering_score(dataset),
+        ))
+        results.append(MetricResult(
+            "action_diversity",
+            EvaluationSuite._action_diversity(dataset),
+        ))
+        results.append(MetricResult(
+            "mean_conclusion_confidence",
+            EvaluationSuite._mean_conclusion_confidence(dataset),
+        ))
+        return results
+
+    # ── expert review metrics (stubs) ───────────────────────────────────
+
+    @staticmethod
+    def expert_review_metrics(
+        trajectories: List[Trajectory],
+        expert_scores: Optional[Dict[str, float]] = None,
+    ) -> List[MetricResult]:
+        """Placeholder for human expert review scores.
+
+        In practice, each trajectory would be scored by a domain expert
+        on axes such as scientific validity, creativity, and efficiency.
+        """
+        if not expert_scores:
+            return [MetricResult("expert_review", 0.0, {"note": "no scores provided"})]
+        avg = float(np.mean(list(expert_scores.values())))
+        return [MetricResult("expert_review_mean", avg, expert_scores)]
+
+    # ── simulator fidelity metrics (stubs) ──────────────────────────────
+
+    @staticmethod
+    def simulator_fidelity_metrics(
+        simulated: TrajectoryDataset,
+        real: Optional[TrajectoryDataset] = None,
+    ) -> List[MetricResult]:
+        """Compare simulated trajectories against real experimental data.
+
+        When ``real`` is provided, computes distributional distances
+        between simulated and real output statistics.
+        """
+        if real is None or len(real) == 0:
+            return [MetricResult("fidelity", 0.0, {"note": "no real data"})]
+
+        sim_rewards = [t.total_reward for t in simulated.trajectories]
+        real_rewards = [t.total_reward for t in real.trajectories]
+
+        reward_gap = abs(float(np.mean(sim_rewards)) - float(np.mean(real_rewards)))
+        return [MetricResult("reward_distribution_gap", reward_gap)]
+
+    # ── internal helpers ────────────────────────────────────────────────
+
+    @staticmethod
+    def _pipeline_validity_rate(ds: TrajectoryDataset) -> float:
+        valid = 0
+        for t in ds.trajectories:
+            # A step counts as violating when "rule_violations" is a
+            # non-empty list; None or [] means the step was clean.
+            violations = sum(
+                1 for s in t.steps
+                if s.observation.get("rule_violations")
+            )
+            if violations == 0:
+                valid += 1
+        return valid / max(len(ds), 1)
+
+    @staticmethod
+    def _ordering_score(ds: TrajectoryDataset) -> float:
+        scores: List[float] = []
+        for t in ds.trajectories:
+            breakdown_scores = []
+            for s in t.steps:
+                bd = s.reward_breakdown
+                if "ordering" in bd:
+                    breakdown_scores.append(bd["ordering"])
+            if breakdown_scores:
+                scores.append(float(np.mean(breakdown_scores)))
+        return float(np.mean(scores)) if scores else 0.0
+
+    @staticmethod
+    def _action_diversity(ds: TrajectoryDataset) -> float:
+        all_types: set = set()
+        for t in ds.trajectories:
+            for s in t.steps:
+                at = s.action.get("action_type")
+                if at:
+                    all_types.add(at)
+        return float(len(all_types))
+
+    @staticmethod
+    def _mean_conclusion_confidence(ds: TrajectoryDataset) -> float:
+        confs: List[float] = []
+        for t in ds.trajectories:
+            for s in t.steps:
+                conclusions = s.observation.get("conclusions", [])
+                for c in conclusions:
+                    if isinstance(c, dict) and "confidence" in c:
+                        confs.append(c["confidence"])
+        return float(np.mean(confs)) if confs else 0.0
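`online_metrics` reduces a batch of episodes to summary statistics over returns, lengths, and success flags. The same reduction can be sketched with only the standard library (the `ToyTrajectory` dataclass and the numbers are illustrative stand-ins for `training.trajectory.Trajectory`):

```python
from dataclasses import dataclass
from statistics import mean, median, pstdev


@dataclass
class ToyTrajectory:
    """Minimal stand-in exposing the fields online_metrics reads."""
    total_reward: float
    n_steps: int
    success: bool


trajs = [
    ToyTrajectory(1.0, 10, True),
    ToyTrajectory(3.0, 12, True),
    ToyTrajectory(-1.0, 8, False),
]

metrics = {
    "mean_return": mean(t.total_reward for t in trajs),
    "median_return": median(t.total_reward for t in trajs),
    # pstdev matches np.std's default (population, not sample, std dev)
    "std_return": pstdev(t.total_reward for t in trajs),
    "mean_episode_length": mean(t.n_steps for t in trajs),
    "success_rate": mean(1.0 if t.success else 0.0 for t in trajs),
}
print(metrics["mean_return"])  # → 1.0
print(round(metrics["success_rate"], 3))
```

Note that `np.std` defaults to the population standard deviation, so `statistics.pstdev` (not `stdev`) is the matching stdlib call.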
training/gym_wrapper.py ADDED
@@ -0,0 +1,174 @@
+"""Gymnasium-compatible wrapper around ``BioExperimentEnvironment``.
+
+Provides ``BioExperimentGymEnv`` which wraps the OpenEnv environment for
+local in-process RL training (no HTTP/WebSocket overhead).
+
+Observation and action spaces are represented as ``gymnasium.spaces.Dict``
+so that standard RL libraries (SB3, CleanRL, etc.) can ingest them.
+"""
+
+from __future__ import annotations
+
+from typing import Any, Dict, Optional, Tuple
+
+import gymnasium as gym
+import numpy as np
+from gymnasium import spaces
+
+from models import ActionType, ExperimentAction, ExperimentObservation
+from server.hackathon_environment import BioExperimentEnvironment, MAX_STEPS
+
+
+ACTION_TYPE_LIST = list(ActionType)
+_N_ACTION_TYPES = len(ACTION_TYPE_LIST)
+
+_MAX_OUTPUTS = MAX_STEPS
+_MAX_HISTORY = MAX_STEPS
+_VEC_DIM = 64
+
+
+class BioExperimentGymEnv(gym.Env):
+    """Gymnasium ``Env`` backed by the in-process simulator.
+
+    Observations are flattened into a dictionary of NumPy arrays suitable
+    for RL policy networks. Actions are integer-indexed action types with
+    a continuous confidence scalar.
+
+    For LLM-based agents or planners that prefer structured
+    ``ExperimentAction`` objects, use the underlying
+    ``BioExperimentEnvironment`` directly instead.
+    """
+
+    metadata = {"render_modes": ["human"]}
+
+    def __init__(self, render_mode: Optional[str] = None):
+        super().__init__()
+        self._env = BioExperimentEnvironment()
+        self.render_mode = render_mode
+
+        self.action_space = spaces.Dict({
+            "action_type": spaces.Discrete(_N_ACTION_TYPES),
+            "confidence": spaces.Box(0.0, 1.0, shape=(), dtype=np.float32),
+        })
+
+        self.observation_space = spaces.Dict({
+            "step_index": spaces.Discrete(MAX_STEPS + 1),
+            "budget_remaining_frac": spaces.Box(0.0, 1.0, shape=(), dtype=np.float32),
+            "time_remaining_frac": spaces.Box(0.0, 1.0, shape=(), dtype=np.float32),
+            "progress_flags": spaces.MultiBinary(18),
+            "latest_quality": spaces.Box(0.0, 1.0, shape=(), dtype=np.float32),
+            "latest_uncertainty": spaces.Box(0.0, 1.0, shape=(), dtype=np.float32),
+            "avg_quality": spaces.Box(0.0, 1.0, shape=(), dtype=np.float32),
+            "avg_uncertainty": spaces.Box(0.0, 1.0, shape=(), dtype=np.float32),
+            "n_violations": spaces.Discrete(20),
+            "n_outputs": spaces.Discrete(_MAX_OUTPUTS + 1),
+            "cumulative_reward": spaces.Box(-100.0, 100.0, shape=(), dtype=np.float32),
+        })
+
+        self._last_obs: Optional[ExperimentObservation] = None
+
+    # ── Gymnasium interface ─────────────────────────────────────────────
+
+    def reset(
+        self,
+        *,
+        seed: Optional[int] = None,
+        options: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+        super().reset(seed=seed)
+        obs = self._env.reset()
+        self._last_obs = obs
+        return self._vectorise(obs), self._info(obs)
+
+    def step(
+        self, action: Dict[str, Any]
+    ) -> Tuple[Dict[str, Any], float, bool, bool, Dict[str, Any]]:
+        action_idx = int(action["action_type"])
+        confidence = float(action.get("confidence", 0.5))
+
+        experiment_action = ExperimentAction(
+            action_type=ACTION_TYPE_LIST[action_idx],
+            confidence=confidence,
+        )
+        obs = self._env.step(experiment_action)
+        self._last_obs = obs
+
+        terminated = obs.done
+        truncated = obs.step_index >= MAX_STEPS and not terminated
+        reward = obs.reward
+
+        return (
+            self._vectorise(obs),
+            reward,
+            terminated,
+            truncated,
+            self._info(obs),
+        )
+
+    def render(self) -> Optional[str]:
+        if self.render_mode != "human" or self._last_obs is None:
+            return None
+        obs = self._last_obs
+        lines = [
+            f"Step {obs.step_index}",
+            f"  Task: {obs.task.problem_statement[:80]}",
+            f"  Budget: ${obs.resource_usage.budget_remaining:,.0f} remaining",
+            f"  Time: {obs.resource_usage.time_remaining_days:.0f} days remaining",
+        ]
+        if obs.latest_output:
+            lines.append(f"  Latest: {obs.latest_output.summary}")
+        if obs.rule_violations:
+            lines.append(f"  Violations: {obs.rule_violations}")
+        text = "\n".join(lines)
+        print(text)
+        return text
+
+    # ── helpers ─────────────────────────────────────────────────────────
+
+    def _vectorise(self, obs: ExperimentObservation) -> Dict[str, Any]:
+        progress = self._env._latent.progress if self._env._latent else None
+        flags = np.zeros(18, dtype=np.int8)
+        if progress:
+            flag_names = [
+                "samples_collected", "cohort_selected", "cells_cultured",
+                "library_prepared", "perturbation_applied", "cells_sequenced",
+                "qc_performed", "data_filtered", "data_normalized",
+                "batches_integrated", "cells_clustered", "de_performed",
+                "trajectories_inferred", "pathways_analyzed",
+                "networks_inferred", "markers_discovered",
+                "markers_validated", "conclusion_reached",
+            ]
+            for i, f in enumerate(flag_names):
+                flags[i] = int(getattr(progress, f, False))
+
+        unc = obs.uncertainty_summary
+        lo = obs.latest_output
+
+        return {
+            "step_index": obs.step_index,
+            "budget_remaining_frac": np.float32(
+                obs.resource_usage.budget_remaining
+                / max(obs.task.budget_limit, 1)
+            ),
+            "time_remaining_frac": np.float32(
+                obs.resource_usage.time_remaining_days
+                / max(obs.task.time_limit_days, 1)
+            ),
+            "progress_flags": flags,
+            "latest_quality": np.float32(lo.quality_score if lo else 0.0),
+            "latest_uncertainty": np.float32(lo.uncertainty if lo else 0.0),
+            "avg_quality": np.float32(unc.get("avg_quality", 0.0)),
+            "avg_uncertainty": np.float32(unc.get("avg_uncertainty", 0.0)),
+            "n_violations": min(len(obs.rule_violations), 19),
+            "n_outputs": min(len(obs.all_outputs), _MAX_OUTPUTS),
+            "cumulative_reward": np.float32(
+                obs.metadata.get("cumulative_reward", 0.0)
+                if obs.metadata else 0.0
+            ),
+        }
+
+    def _info(self, obs: ExperimentObservation) -> Dict[str, Any]:
+        return {
+            "structured_obs": obs,
+            "episode_id": obs.metadata.get("episode_id") if obs.metadata else None,
+        }
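`_vectorise` packs the boolean progress flags into a fixed-length binary vector so policy networks always see a stable observation shape; missing attributes default to 0 via `getattr`. A stripped-down sketch of that encoding (the real wrapper fills an `np.int8` array of length 18; the `Progress` class and three flag names here are a toy stand-in):

```python
from dataclasses import dataclass


@dataclass
class Progress:
    """Toy stand-in for the simulator's ExperimentProgress flags."""
    samples_collected: bool = False
    qc_performed: bool = False
    de_performed: bool = False


# Fixed ordering gives the policy a stable feature layout across episodes.
FLAG_NAMES = ["samples_collected", "qc_performed", "de_performed"]


def vectorise(progress) -> list:
    # getattr with a default keeps the encoding robust to missing fields,
    # mirroring the wrapper's getattr(progress, f, False).
    return [int(getattr(progress, name, False)) for name in FLAG_NAMES]


p = Progress(samples_collected=True, de_performed=True)
print(vectorise(p))  # → [1, 0, 1]
```

The fixed `FLAG_NAMES` ordering is the important design choice: shuffling it between runs would silently scramble what each input neuron means.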
training/literature_benchmark.py ADDED
@@ -0,0 +1,557 @@
+"""Literature-grounded experiment benchmark utilities.
+
+This module lets the environment run a paper-backed experiment plan, then
+compare the resulting simulated findings against curated expected findings
+from the literature.
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import re
+from dataclasses import asdict, dataclass, field
+from importlib.metadata import PackageNotFoundError, version
+from typing import Any, Dict, List, Optional, Sequence
+
+from models import (
+    ActionType,
+    ConclusionClaim,
+    ExperimentAction,
+    ExperimentObservation,
+    OutputType,
+    TaskSpec,
+)
+from server.hackathon_environment import BioExperimentEnvironment
+from server.tasks.scenarios import SCENARIO_LIBRARY, Scenario
+
+TOKEN_RE = re.compile(r"[A-Za-z0-9_+\-]+")
+STOPWORDS = {
+    "a", "an", "and", "as", "by", "for", "from", "in", "into",
+    "of", "on", "or", "the", "to", "using", "with",
+}
+
+BIO_LIBRARY_DISTRIBUTIONS = {
+    "scanpy": "scanpy",
+    "gseapy": "gseapy",
+    "biopython": "biopython",
+}
+
+
+@dataclass
+class PaperBenchmarkResult:
+    scenario_name: str
+    problem_statement: str
+    matched_papers: List[str]
+    bio_library_versions: Dict[str, Optional[str]]
+    matched_findings: List[str] = field(default_factory=list)
+    missed_findings: List[str] = field(default_factory=list)
+    discovered_markers: List[str] = field(default_factory=list)
+    candidate_mechanisms: List[str] = field(default_factory=list)
+    conclusions: List[str] = field(default_factory=list)
+    final_reward: float = 0.0
+    total_steps: int = 0
+
+    @property
+    def match_ratio(self) -> float:
+        total = len(self.matched_findings) + len(self.missed_findings)
+        return len(self.matched_findings) / max(total, 1)
+
+    def to_dict(self) -> Dict[str, Any]:
+        data = asdict(self)
+        data["match_ratio"] = self.match_ratio
+        return data
+
+
+def detect_bio_library_versions() -> Dict[str, Optional[str]]:
+    versions: Dict[str, Optional[str]] = {}
+    for name, dist_name in BIO_LIBRARY_DISTRIBUTIONS.items():
+        try:
+            versions[name] = version(dist_name)
+        except PackageNotFoundError:
+            versions[name] = None
+    return versions
+
+
+def select_literature_scenario(problem_statement: str) -> Scenario:
+    """Pick the closest literature-backed scenario for a prompt."""
+
+    prompt_tokens = set(_tokenize(problem_statement))
+    best_score = -1
+    best_scenario: Optional[Scenario] = None
+
+    for scenario in SCENARIO_LIBRARY:
+        if not scenario.task.paper_references:
+            continue
+        corpus = [
+            scenario.task.problem_statement,
+            *(ref.title for ref in scenario.task.paper_references),
+            *(finding.finding for finding in scenario.task.expected_findings),
+            scenario.task.tissue,
+            scenario.task.modality,
+            *scenario.task.conditions,
+        ]
+        score = len(prompt_tokens & set(_tokenize(" ".join(corpus))))
+        if scenario.task.problem_statement.lower() in problem_statement.lower():
+            score += 4
+        if score > best_score:
+            best_score = score
+            best_scenario = scenario
+
+    if best_scenario is None:
+        raise ValueError("No literature-backed scenarios are available.")
+    return best_scenario
+
+
+def run_paper_benchmark(
+    *,
+    problem_statement: str,
+    scenario_name: Optional[str] = None,
+    domain_randomise: bool = False,
+) -> PaperBenchmarkResult:
+    """Run a literature-backed episode and compare outputs to paper results."""
+
+    scenario = _resolve_scenario(problem_statement, scenario_name)
+    env = BioExperimentEnvironment(
+        scenario_name=scenario.name,
+        domain_randomise=domain_randomise,
+    )
+    obs = env.reset()
+
+    for action in build_paper_aligned_actions(obs.task):
+        obs = env.step(action)
+
+    claims = infer_conclusion_claims(obs)
+    obs = env.step(
+        ExperimentAction(
+            action_type=ActionType.SYNTHESIZE_CONCLUSION,
+            parameters={"claims": [claim.model_dump() for claim in claims]},
+            justification=(
+                "Summarize the simulated experimental evidence and compare it "
+                "with the paper-backed expected findings."
+            ),
+            confidence=0.8,
+            tool_call_spec=_tool_context(
+                obs.task,
+                libraries=["biopython"],
+                include_expected_findings=True,
+            ),
+        )
+    )
+
+    matched, missed = compare_expected_findings(obs.task, obs)
+    return PaperBenchmarkResult(
+        scenario_name=scenario.name,
+        problem_statement=obs.task.problem_statement,
+        matched_papers=[ref.title for ref in obs.task.paper_references],
+        bio_library_versions=detect_bio_library_versions(),
+        matched_findings=matched,
+        missed_findings=missed,
+        discovered_markers=list(obs.discovered_markers),
+        candidate_mechanisms=list(obs.candidate_mechanisms),
+        conclusions=[c.claim for c in obs.conclusions],
+        final_reward=float(obs.metadata.get("cumulative_reward", 0.0)),
+        total_steps=obs.step_index,
+    )
+
+
+def build_paper_aligned_actions(task: TaskSpec) -> List[ExperimentAction]:
+    """Construct a pragmatic analysis plan aligned to the task modality."""
+
+    actions: List[ExperimentAction] = [
+        ExperimentAction(
+            action_type=ActionType.COLLECT_SAMPLE,
+            parameters={"n_samples": 8},
+            justification="Collect enough samples to support downstream analysis.",
+            confidence=0.75,
+            tool_call_spec=_tool_context(task, libraries=["biopython"]),
+        ),
+        ExperimentAction(
+            action_type=ActionType.PREPARE_LIBRARY,
+            method="10x_chromium",
+            justification="Use a standard single-cell library prep workflow.",
+            confidence=0.8,
+            tool_call_spec=_tool_context(task, libraries=["scanpy"]),
+        ),
+        ExperimentAction(
+            action_type=ActionType.SEQUENCE_CELLS,
+            method="NovaSeq",
+            justification="Generate sufficient single-cell read depth.",
+            confidence=0.8,
+            tool_call_spec=_tool_context(task, libraries=["scanpy"]),
+        ),
+        ExperimentAction(
+            action_type=ActionType.RUN_QC,
+            method="scanpy.pp.calculate_qc_metrics",
+            justification="Check technical quality before downstream inference.",
+            confidence=0.85,
+            tool_call_spec=_tool_context(task, libraries=["scanpy"]),
+        ),
+        ExperimentAction(
+            action_type=ActionType.FILTER_DATA,
+            method="scanpy.pp.filter_cells",
+            justification="Remove low-quality cells and reduce technical noise.",
+            confidence=0.85,
+            tool_call_spec=_tool_context(task, libraries=["scanpy"]),
+        ),
+        ExperimentAction(
+            action_type=ActionType.NORMALIZE_DATA,
+            method="scanpy.pp.normalize_total",
+            justification="Normalize expression to prepare comparable profiles.",
+            confidence=0.85,
+            tool_call_spec=_tool_context(task, libraries=["scanpy"]),
+        ),
+        ExperimentAction(
+            action_type=ActionType.CLUSTER_CELLS,
+            method="scanpy.tl.leiden",
+            justification="Resolve cell states before focused interpretation.",
+            confidence=0.8,
+            tool_call_spec=_tool_context(task, libraries=["scanpy"]),
+        ),
+    ]
+
+    categories = {finding.category for finding in task.expected_findings}
+    if "trajectory" in categories:
+        actions.extend([
+            ExperimentAction(
+                action_type=ActionType.TRAJECTORY_ANALYSIS,
+                method="scanpy.tl.dpt",
+                justification="Recover pseudotime structure and lineage branches.",
+                confidence=0.8,
+                tool_call_spec=_tool_context(task, libraries=["scanpy"]),
+            ),
+            ExperimentAction(
+                action_type=ActionType.REGULATORY_NETWORK_INFERENCE,
+                method="pySCENIC",
+                justification="Infer branch-associated regulators from the trajectory.",
+                confidence=0.75,
+                tool_call_spec=_tool_context(task, libraries=["scanpy"]),
+            ),
+            ExperimentAction(
+                action_type=ActionType.MARKER_SELECTION,
+                method="scanpy.tl.rank_genes_groups",
+                justification="Summarize lineage markers and branch-state genes.",
+                confidence=0.75,
+                tool_call_spec=_tool_context(task, libraries=["scanpy"]),
+            ),
+        ])
+        return actions
+
+    actions.extend([
+        ExperimentAction(
+            action_type=ActionType.DIFFERENTIAL_EXPRESSION,
+            method="scanpy.tl.rank_genes_groups",
+            parameters={"comparison": _default_comparison_name(task)},
+            justification="Identify genes associated with the focal phenotype.",
+            confidence=0.85,
+            tool_call_spec=_tool_context(task, libraries=["scanpy"]),
+        ),
+        ExperimentAction(
+            action_type=ActionType.PATHWAY_ENRICHMENT,
+            method="gseapy.prerank",
+            justification="Translate DE hits into pathway-level interpretation.",
+            confidence=0.8,
+            tool_call_spec=_tool_context(task, libraries=["gseapy"]),
+        ),
+        ExperimentAction(
+            action_type=ActionType.MARKER_SELECTION,
+            method="scanpy.tl.rank_genes_groups",
+            justification="Nominate candidate markers for follow-up validation.",
+            confidence=0.8,
+            tool_call_spec=_tool_context(task, libraries=["scanpy"]),
+        ),
+        ExperimentAction(
+            action_type=ActionType.VALIDATE_MARKER,
+            method="immunofluorescence",
+            parameters={"marker": _preferred_marker(task)},
+            justification="Check whether the leading marker reproduces in validation.",
+            confidence=0.75,
+            tool_call_spec=_tool_context(task, libraries=["biopython"]),
+        ),
+    ])
+    return actions
+
+
+def infer_conclusion_claims(obs: ExperimentObservation) -> List[ConclusionClaim]:
+    """Turn accumulated evidence into concise, paper-comparable claims."""
+
+    markers = set(obs.discovered_markers)
+    mechanisms = set(obs.candidate_mechanisms)
+    network_regulators = set(_extract_network_regulators(obs))
+    trajectory_output = _latest_output_data(obs, OutputType.TRAJECTORY_RESULT)
+
+    claims: List[ConclusionClaim] = []
+
+    if "SPP1" in markers:
+        claims.append(ConclusionClaim(
+            claim="SPP1-positive macrophages are enriched in IPF fibrotic tissue.",
+            confidence=0.84,
+            claim_type="marker",
+            evidence_steps=_evidence_steps(obs, {
+                OutputType.DE_RESULT,
+                OutputType.MARKER_RESULT,
+                OutputType.VALIDATION_RESULT,
+            }),
+        ))
+    if {"SPP1", "MERTK"} <= markers:
+        claims.append(ConclusionClaim(
+            claim="MERTK co-occurs with the SPP1-positive profibrotic macrophage state.",
+            confidence=0.8,
+            claim_type="marker",
+            evidence_steps=_evidence_steps(obs, {
+                OutputType.DE_RESULT,
+                OutputType.MARKER_RESULT,
+            }),
+        ))
+    if "extracellular_matrix_organisation" in mechanisms:
+        claims.append(ConclusionClaim(
+            claim=(
+                "Extracellular matrix organization is a dominant fibrotic "
+                "program in the IPF samples."
+            ),
+            confidence=0.78,
+            claim_type="pathway",
+            evidence_steps=_evidence_steps(obs, {OutputType.PATHWAY_RESULT}),
+        ))
+
+    if trajectory_output.get("branching_detected"):
+        claims.append(ConclusionClaim(
+            claim=(
+                "Trajectory analysis recovered branching blood lineages rooted "
+                "in HSCs."
+            ),
+            confidence=0.82,
+            claim_type="trajectory",
+            evidence_steps=_evidence_steps(obs, {OutputType.TRAJECTORY_RESULT}),
+        ))
+    if "GATA1" in network_regulators:
+        claims.append(ConclusionClaim(
+            claim="GATA1 emerges as a driver of erythroid fate commitment.",
+            confidence=0.8,
+            claim_type="regulatory_network",
+            evidence_steps=_evidence_steps(obs, {OutputType.NETWORK_RESULT}),
+        ))
+    if {"CEBPA", "SPI1"} & network_regulators:
+        claims.append(ConclusionClaim(
+            claim="CEBPA and SPI1 support myeloid branch decisions.",
+            confidence=0.78,
+            claim_type="regulatory_network",
+            evidence_steps=_evidence_steps(obs, {OutputType.NETWORK_RESULT}),
+        ))
+
+    if claims:
+        return claims
+
+    # Fallback: preserve the strongest expected findings verbatim if the
+    # heuristic extractors do not recover enough signal from the episode.
+    return [
+        ConclusionClaim(
+            claim=finding.finding,
+            confidence=0.65,
+            claim_type=finding.category,
+        )
+        for finding in obs.task.expected_findings[:3]
+    ]
+
+
+def compare_expected_findings(
+    task: TaskSpec,
+    obs: ExperimentObservation,
+) -> tuple[List[str], List[str]]:
+    """Compare the episode evidence against literature-backed findings."""
+
+    evidence_text = _evidence_text(obs)
+    matched: List[str] = []
+    missed: List[str] = []
+
+    for finding in task.expected_findings:
+        keywords = [kw.lower() for kw in finding.keywords]
+        if not keywords:
+            keywords = _tokenize(finding.finding)
+        hits = sum(1 for kw in keywords if kw in evidence_text)
+        threshold = max(1, (len(keywords) + 1) // 2)
+        if hits >= threshold:
+            matched.append(finding.finding)
+        else:
+            missed.append(finding.finding)
+
+    return matched, missed
+
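`compare_expected_findings` marks a finding as matched when at least half of its keywords, rounded up via `(len + 1) // 2` and floored at one, appear as substrings of the pooled lowercase evidence text. The thresholding can be sketched standalone (the evidence string and keyword lists are illustrative):

```python
def matches(evidence_text: str, keywords: list[str]) -> bool:
    """Case-insensitive substring match with a ceil(len/2) hit threshold."""
    hits = sum(1 for kw in keywords if kw.lower() in evidence_text)
    # At least half the keywords (rounded up), but always at least one.
    return hits >= max(1, (len(keywords) + 1) // 2)


evidence = "spp1 macrophages enriched in fibrotic tissue"
print(matches(evidence, ["SPP1", "macrophage", "cartilage"]))  # → True (2 of 3 hit)
print(matches(evidence, ["GATA1", "erythroid"]))               # → False (0 of 2 hit)
```

Because matching is substring-based, `"macrophage"` hits inside `"macrophages"`; the flip side is that very short keywords can match spuriously, which the stopword filter in `_tokenize` partly mitigates.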
395
+ def _resolve_scenario(
396
+ problem_statement: str,
397
+ scenario_name: Optional[str],
398
+ ) -> Scenario:
399
+ if scenario_name:
400
+ for scenario in SCENARIO_LIBRARY:
401
+ if scenario.name == scenario_name:
402
+ return scenario
403
+ raise ValueError(f"Unknown scenario_name '{scenario_name}'.")
404
+ return select_literature_scenario(problem_statement)
405
+
406
+
407
+ def _tool_context(
408
+ task: TaskSpec,
409
+ *,
410
+ libraries: Sequence[str],
411
+ include_expected_findings: bool = False,
412
+ ) -> Dict[str, Any]:
413
+ context: Dict[str, Any] = {
414
+ "literature_query": task.problem_statement,
415
+ "paper_references": [
416
+ {
417
+ "title": ref.title,
418
+ "doi": ref.doi,
419
+ "pmid": ref.pmid,
420
+ "url": ref.url,
421
+ }
422
+ for ref in task.paper_references
423
+ ],
424
+ "bioinformatics_libraries": list(libraries),
425
+ }
426
+ if include_expected_findings:
427
+ context["expected_findings"] = [
428
+ finding.finding for finding in task.expected_findings
429
+ ]
430
+ return context
431
+
432
+
433
+ def _default_comparison_name(task: TaskSpec) -> str:
434
+ conditions = {condition.lower() for condition in task.conditions}
435
+ if {"healthy", "ipf"} <= conditions:
436
+ return "IPF_vs_healthy"
437
+ if any("treated" in condition for condition in conditions) and any(
438
+ "untreated" in condition for condition in conditions
439
+ ):
440
+ return "treated_vs_untreated"
441
+ if any("healthy" in condition for condition in conditions):
442
+ return "disease_vs_healthy"
443
+ return "disease_vs_healthy"
444
+
445
+
446
+ def _preferred_marker(task: TaskSpec) -> str:
447
+ for finding in task.expected_findings:
448
+ for keyword in finding.keywords:
449
+ if keyword.isupper():
450
+ return keyword
451
+ return "SPP1"
452
+
453
+
454
+ def _latest_output_data(
455
+ obs: ExperimentObservation,
456
+ output_type: OutputType,
457
+ ) -> Dict[str, Any]:
458
+ for output in reversed(obs.all_outputs):
459
+ if output.output_type == output_type:
460
+ return output.data
461
+ return {}
462
+
463
+
464
+ def _extract_network_regulators(obs: ExperimentObservation) -> List[str]:
465
+ for output in reversed(obs.all_outputs):
466
+ if output.output_type == OutputType.NETWORK_RESULT:
467
+ return output.data.get("top_regulators", [])
468
+ return []
469
+
470
+
471
+ def _evidence_steps(
472
+ obs: ExperimentObservation,
473
+ output_types: set[OutputType],
474
+ ) -> List[int]:
475
+ return [
476
+ output.step_index
477
+ for output in obs.all_outputs
478
+ if output.output_type in output_types
479
+ ]
480
+
481
+
482
+ def _evidence_text(obs: ExperimentObservation) -> str:
483
+ parts: List[str] = []
484
+ parts.extend(obs.discovered_markers)
485
+ parts.extend(obs.candidate_mechanisms)
486
+ parts.extend(conclusion.claim for conclusion in obs.conclusions)
487
+
488
+ for output in obs.all_outputs:
489
+ parts.append(output.summary)
490
+ if output.output_type == OutputType.DE_RESULT:
491
+ parts.extend(
492
+ gene["gene"]
493
+ for gene in output.data.get("top_genes", [])
494
+ if isinstance(gene, dict) and "gene" in gene
495
+ )
496
+ elif output.output_type == OutputType.PATHWAY_RESULT:
497
+ parts.extend(
498
+ pathway["pathway"]
499
+ for pathway in output.data.get("top_pathways", [])
500
+ if isinstance(pathway, dict) and "pathway" in pathway
501
+ )
502
+ elif output.output_type == OutputType.NETWORK_RESULT:
503
+ parts.extend(output.data.get("top_regulators", []))
504
+ elif output.output_type == OutputType.TRAJECTORY_RESULT:
505
+ if output.data.get("branching_detected"):
506
+ parts.append("branching lineage HSC trajectory")
507
+
508
+ return " ".join(parts).lower()
509
+
510
+
511
+ def _tokenize(text: str) -> List[str]:
512
+ return [
513
+ token.lower()
514
+ for token in TOKEN_RE.findall(text)
515
+ if token and token.lower() not in STOPWORDS
516
+ ]
517
+
518
+
519
+ def main() -> None:
520
+ parser = argparse.ArgumentParser()
521
+ parser.add_argument(
522
+ "--problem-statement",
523
+ default=(
524
+ "Design a follow-up validation experiment for candidate biomarker "
525
+ "SPP1 in idiopathic pulmonary fibrosis."
526
+ ),
527
+ )
528
+ parser.add_argument("--scenario-name", default=None)
529
+ parser.add_argument("--domain-randomise", action="store_true")
530
+ parser.add_argument("--json", action="store_true")
531
+ args = parser.parse_args()
532
+
533
+ result = run_paper_benchmark(
534
+ problem_statement=args.problem_statement,
535
+ scenario_name=args.scenario_name,
536
+ domain_randomise=args.domain_randomise,
537
+ )
538
+
539
+ if args.json:
540
+ print(json.dumps(result.to_dict(), indent=2))
541
+ return
542
+
543
+ print(f"Scenario: {result.scenario_name}")
544
+ print(f"Problem: {result.problem_statement}")
545
+ print(f"Paper: {', '.join(result.matched_papers)}")
546
+ print(f"Match ratio: {result.match_ratio:.2%}")
547
+ print(f"Matched findings: {len(result.matched_findings)}")
548
+ print(f"Missed findings: {len(result.missed_findings)}")
549
+ print(f"Discovered markers: {', '.join(result.discovered_markers[:8])}")
550
+ print(f"Candidate mechanisms: {', '.join(result.candidate_mechanisms[:5])}")
551
+ print(f"Conclusions: {len(result.conclusions)}")
552
+ print(f"Final reward: {result.final_reward:+.3f}")
553
+ print(f"Bio libraries: {json.dumps(result.bio_library_versions, sort_keys=True)}")
554
+
555
+
556
+ if __name__ == "__main__":
557
+ main()
training/trajectory.py ADDED
@@ -0,0 +1,159 @@
+"""Trajectory serialisation and dataset utilities.
+
+A ``Trajectory`` stores the full history of one episode (task, actions,
+observations, rewards, latent-state snapshots) in a format that supports:
+- offline RL training
+- imitation learning from expert demonstrations
+- evaluation / replay
+- simulator calibration
+"""
+
+from __future__ import annotations
+
+import json
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+from models import (
+    ExperimentAction,
+    ExperimentObservation,
+    TaskSpec,
+)
+
+
+@dataclass
+class TrajectoryStep:
+    step_index: int
+    action: Dict[str, Any]
+    observation: Dict[str, Any]
+    reward: float
+    done: bool
+    reward_breakdown: Dict[str, float] = field(default_factory=dict)
+    latent_snapshot: Optional[Dict[str, Any]] = None
+
+
+@dataclass
+class Trajectory:
+    """Complete record of one environment episode."""
+
+    episode_id: str
+    task: Dict[str, Any]
+    steps: List[TrajectoryStep] = field(default_factory=list)
+    total_reward: float = 0.0
+    success: bool = False
+    metadata: Dict[str, Any] = field(default_factory=dict)
+
+    # ── construction helpers ────────────────────────────────────────────
+
+    def add_step(
+        self,
+        action: ExperimentAction,
+        observation: ExperimentObservation,
+        reward: float,
+        done: bool,
+        reward_breakdown: Optional[Dict[str, float]] = None,
+        latent_snapshot: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        self.steps.append(TrajectoryStep(
+            step_index=len(self.steps),
+            action=action.model_dump(),
+            observation=observation.model_dump(),
+            reward=reward,
+            done=done,
+            reward_breakdown=reward_breakdown or {},
+            latent_snapshot=latent_snapshot,
+        ))
+        self.total_reward += reward
+        if done:
+            self.success = reward > 0
+
+    # ── serialisation ───────────────────────────────────────────────────
+
+    def to_dict(self) -> Dict[str, Any]:
+        return {
+            "episode_id": self.episode_id,
+            "task": self.task,
+            "steps": [
+                {
+                    "step_index": s.step_index,
+                    "action": s.action,
+                    "observation": s.observation,
+                    "reward": s.reward,
+                    "done": s.done,
+                    "reward_breakdown": s.reward_breakdown,
+                    "latent_snapshot": s.latent_snapshot,
+                }
+                for s in self.steps
+            ],
+            "total_reward": self.total_reward,
+            "success": self.success,
+            "metadata": self.metadata,
+        }
+
+    def save(self, path: str | Path) -> None:
+        p = Path(path)
+        p.parent.mkdir(parents=True, exist_ok=True)
+        with open(p, "w") as f:
+            json.dump(self.to_dict(), f, indent=2, default=str)
+
+    @classmethod
+    def load(cls, path: str | Path) -> "Trajectory":
+        with open(path) as f:
+            d = json.load(f)
+        traj = cls(
+            episode_id=d["episode_id"],
+            task=d["task"],
+            total_reward=d.get("total_reward", 0.0),
+            success=d.get("success", False),
+            metadata=d.get("metadata", {}),
+        )
+        for s in d.get("steps", []):
+            traj.steps.append(TrajectoryStep(**s))
+        return traj
+
+
+class TrajectoryDataset:
+    """In-memory collection of trajectories with convenience accessors."""
+
+    def __init__(self, trajectories: Optional[List[Trajectory]] = None):
+        self.trajectories: List[Trajectory] = trajectories or []
+
+    def add(self, traj: Trajectory) -> None:
+        self.trajectories.append(traj)
+
+    def __len__(self) -> int:
+        return len(self.trajectories)
+
+    def __getitem__(self, idx: int) -> Trajectory:
+        return self.trajectories[idx]
+
+    def filter_successful(self) -> "TrajectoryDataset":
+        return TrajectoryDataset([t for t in self.trajectories if t.success])
+
+    def save_dir(self, directory: str | Path) -> None:
+        d = Path(directory)
+        d.mkdir(parents=True, exist_ok=True)
+        for t in self.trajectories:
+            t.save(d / f"{t.episode_id}.json")
+
+    @classmethod
+    def load_dir(cls, directory: str | Path) -> "TrajectoryDataset":
+        d = Path(directory)
+        trajs = [Trajectory.load(p) for p in sorted(d.glob("*.json"))]
+        return cls(trajs)
+
+    def summary(self) -> Dict[str, Any]:
+        if not self.trajectories:
+            return {"n": 0}
+        rewards = [t.total_reward for t in self.trajectories]
+        lengths = [len(t.steps) for t in self.trajectories]
+        success_rate = sum(1 for t in self.trajectories if t.success) / len(self.trajectories)
+        return {
+            "n": len(self.trajectories),
+            "success_rate": success_rate,
+            "mean_reward": sum(rewards) / len(rewards),
+            "mean_length": sum(lengths) / len(lengths),
+            "max_reward": max(rewards),
+            "min_reward": min(rewards),
+        }
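The file above persists one JSON document per episode via `Trajectory.save`/`Trajectory.load`. A minimal sketch of that on-disk format and its round-trip, using an illustrative record (the field values are hypothetical, not taken from a real episode):

```python
import json
import tempfile
from pathlib import Path

# Hypothetical record matching the shape emitted by Trajectory.to_dict();
# values are illustrative only.
record = {
    "episode_id": "ep-0001",
    "task": {"problem_statement": "validate SPP1 in IPF"},
    "steps": [
        {
            "step_index": 0,
            "action": {"action_type": "run_de"},
            "observation": {"summary": "DE computed"},
            "reward": 0.25,
            "done": True,
            "reward_breakdown": {"evidence": 0.25},
            "latent_snapshot": None,
        }
    ],
    "total_reward": 0.25,
    "success": True,
    "metadata": {},
}

# Write and re-read the record the same way save()/load() do: one
# <episode_id>.json file, pretty-printed, with non-JSON types stringified.
with tempfile.TemporaryDirectory() as d:
    path = Path(d) / f"{record['episode_id']}.json"
    path.write_text(json.dumps(record, indent=2, default=str))
    loaded = json.loads(path.read_text())

assert loaded == record
```

Because `load` rebuilds each step with `TrajectoryStep(**s)`, any extra key added to a step dict in a newer writer would raise a `TypeError` in an older reader, so the step schema above should be treated as fixed.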
uv.lock CHANGED
The diff for this file is too large to render. See raw diff