Spaces:
Running
Running
Prepare MolForge OpenEnv Docker Space submission
Browse files- .dockerignore +23 -0
- .gitignore +47 -0
- Dockerfile +45 -0
- EVALUATION_PROTOCOL.md +103 -0
- HF_RL_JOBS_NOTES.md +92 -0
- README.md +159 -6
- REAL_WORLD_WORKFLOW_MAPPING.md +221 -0
- RL_TRAINING_COLAB.md +51 -0
- Requirements_before_submitting.md +521 -0
- TRAINING_INSTRUCTIONS.md +253 -0
- __init__.py +11 -0
- client.py +31 -0
- inference.py +209 -0
- inference_common.py +831 -0
- local_inference.py +203 -0
- lora_inference.py +244 -0
- mlx_lora_inference.py +457 -0
- models.py +216 -0
- molforge_grpo_official_submission.ipynb +277 -0
- molforge_oracles.py +274 -0
- openenv.yaml +6 -0
- openenv_shim.py +114 -0
- pyproject.toml +32 -0
- scenarios.py +504 -0
- scripts/convert_peft_lora_to_mlx.py +98 -0
- scripts/generate_sft_all_actions_dataset.py +180 -0
- scripts/generate_sft_compact_policy_v4_dataset.py +446 -0
- scripts/validate_sft_traces.py +97 -0
- server/Dockerfile +45 -0
- server/__init__.py +1 -0
- server/actions.py +414 -0
- server/app.py +36 -0
- server/governance.py +576 -0
- server/molforge_environment.py +342 -0
- server/requirements.txt +8 -0
- server/shared.py +227 -0
- server/views.py +436 -0
- uv.lock +0 -0
.dockerignore
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.git
|
| 2 |
+
.gitignore
|
| 3 |
+
.venv
|
| 4 |
+
__pycache__/
|
| 5 |
+
*.py[cod]
|
| 6 |
+
.DS_Store
|
| 7 |
+
|
| 8 |
+
artifacts/
|
| 9 |
+
adapters/
|
| 10 |
+
data/
|
| 11 |
+
openenv_molforge.egg-info/
|
| 12 |
+
*.egg-info/
|
| 13 |
+
|
| 14 |
+
qwen3_5_2b_lora_adapters*/
|
| 15 |
+
*.safetensors
|
| 16 |
+
*.pt
|
| 17 |
+
*.pth
|
| 18 |
+
*.bin
|
| 19 |
+
|
| 20 |
+
analysis_results.md
|
| 21 |
+
help_guide/
|
| 22 |
+
issue/*.ipynb
|
| 23 |
+
scripts/__pycache__/
|
.gitignore
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# macOS / editor noise
|
| 2 |
+
.DS_Store
|
| 3 |
+
|
| 4 |
+
# Output runs
|
| 5 |
+
molforge_rl_runs/
|
| 6 |
+
molforge_grpo_*/
|
| 7 |
+
*.egg-info/
|
| 8 |
+
|
| 9 |
+
*.swp
|
| 10 |
+
*.swo
|
| 11 |
+
|
| 12 |
+
# Python caches and local environments
|
| 13 |
+
__pycache__/
|
| 14 |
+
*.py[cod]
|
| 15 |
+
.pytest_cache/
|
| 16 |
+
.mypy_cache/
|
| 17 |
+
.ruff_cache/
|
| 18 |
+
.venv/
|
| 19 |
+
venv/
|
| 20 |
+
env/
|
| 21 |
+
|
| 22 |
+
# Build/package outputs
|
| 23 |
+
build/
|
| 24 |
+
dist/
|
| 25 |
+
*.egg-info/
|
| 26 |
+
|
| 27 |
+
# Local secrets and notebooks
|
| 28 |
+
.env
|
| 29 |
+
.env.*
|
| 30 |
+
*.ipynb_checkpoints/
|
| 31 |
+
|
| 32 |
+
# Generated model/adapters/checkpoints
|
| 33 |
+
qwen3_5_2b_lora_adapters*/
|
| 34 |
+
artifacts/
|
| 35 |
+
outputs/
|
| 36 |
+
checkpoints/
|
| 37 |
+
*.safetensors
|
| 38 |
+
*.bin
|
| 39 |
+
*.pt
|
| 40 |
+
*.pth
|
| 41 |
+
*.ckpt
|
| 42 |
+
|
| 43 |
+
# Legacy/generated SFT artifacts. Keep the current v4 dataset in issue/.
|
| 44 |
+
data/*.jsonl
|
| 45 |
+
issue/molforge_sft_compact_policy_v3.jsonl
|
| 46 |
+
qwen3_5_2b_unsloth_sft.py
|
| 47 |
+
analysis_results.md
|
Dockerfile
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
|
| 2 |
+
FROM ${BASE_IMAGE} AS builder
|
| 3 |
+
ARG INSTALL_TDC=0
|
| 4 |
+
|
| 5 |
+
WORKDIR /app
|
| 6 |
+
|
| 7 |
+
RUN apt-get update && \
|
| 8 |
+
apt-get install -y --no-install-recommends git && \
|
| 9 |
+
rm -rf /var/lib/apt/lists/*
|
| 10 |
+
|
| 11 |
+
COPY . /app/env
|
| 12 |
+
WORKDIR /app/env
|
| 13 |
+
|
| 14 |
+
RUN if ! command -v uv >/dev/null 2>&1; then \
|
| 15 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
| 16 |
+
mv /root/.local/bin/uv /usr/local/bin/uv && \
|
| 17 |
+
mv /root/.local/bin/uvx /usr/local/bin/uvx; \
|
| 18 |
+
fi
|
| 19 |
+
|
| 20 |
+
ENV UV_LINK_MODE=copy
|
| 21 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 22 |
+
if [ "$INSTALL_TDC" = "1" ]; then \
|
| 23 |
+
uv sync --no-editable --extra tdc; \
|
| 24 |
+
else \
|
| 25 |
+
uv sync --no-editable; \
|
| 26 |
+
fi
|
| 27 |
+
|
| 28 |
+
FROM ${BASE_IMAGE}
|
| 29 |
+
|
| 30 |
+
WORKDIR /app
|
| 31 |
+
|
| 32 |
+
RUN apt-get update && \
|
| 33 |
+
apt-get install -y --no-install-recommends curl && \
|
| 34 |
+
rm -rf /var/lib/apt/lists/*
|
| 35 |
+
|
| 36 |
+
COPY --from=builder /app/env/.venv /app/.venv
|
| 37 |
+
COPY --from=builder /app/env /app/env
|
| 38 |
+
|
| 39 |
+
ENV PATH="/app/.venv/bin:$PATH"
|
| 40 |
+
ENV PYTHONPATH="/app/env:$PYTHONPATH"
|
| 41 |
+
|
| 42 |
+
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
| 43 |
+
CMD curl -f http://localhost:8000/health || exit 1
|
| 44 |
+
|
| 45 |
+
CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
|
EVALUATION_PROTOCOL.md
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MolForge Evaluation Protocol
|
| 2 |
+
|
| 3 |
+
Use two reward settings for different purposes.
|
| 4 |
+
|
| 5 |
+
## 1. Training / RL Warmup
|
| 6 |
+
|
| 7 |
+
Use curriculum mode:
|
| 8 |
+
|
| 9 |
+
```bash
|
| 10 |
+
MOLFORGE_REWARD_MODE=curriculum
|
| 11 |
+
MOLFORGE_TRAINING_RANDOMIZATION=1
|
| 12 |
+
```
|
| 13 |
+
|
| 14 |
+
Track:
|
| 15 |
+
|
| 16 |
+
- mean episode reward;
|
| 17 |
+
- valid JSON/action rate;
|
| 18 |
+
- policy veto rate;
|
| 19 |
+
- evidence score;
|
| 20 |
+
- number of oracle calls;
|
| 21 |
+
- budget remaining at submit;
|
| 22 |
+
- submit rate;
|
| 23 |
+
- missed-nomination rate;
|
| 24 |
+
- strict terminal `submission_score`.
|
| 25 |
+
|
| 26 |
+
Curriculum reward is allowed to be generous because its purpose is learning.
|
| 27 |
+
It rewards useful evidence collection and evidence-supported submit timing.
|
| 28 |
+
|
| 29 |
+
## 2. Judge-Facing Evaluation
|
| 30 |
+
|
| 31 |
+
Use strict/default mode:
|
| 32 |
+
|
| 33 |
+
```bash
|
| 34 |
+
unset MOLFORGE_TRAINING_RANDOMIZATION
|
| 35 |
+
export MOLFORGE_REWARD_MODE=assay_gated
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
Report:
|
| 39 |
+
|
| 40 |
+
- `average_submission_score`;
|
| 41 |
+
- `average_final_score`;
|
| 42 |
+
- per-task `final_score`;
|
| 43 |
+
- per-task `submission_score`;
|
| 44 |
+
- `candidate_score`;
|
| 45 |
+
- `progress_score`;
|
| 46 |
+
- `constraint_margin_score`;
|
| 47 |
+
- `evidence_score`;
|
| 48 |
+
- `coordination_score`;
|
| 49 |
+
- `budget_score`;
|
| 50 |
+
- submitted vs not submitted;
|
| 51 |
+
- invalid action count;
|
| 52 |
+
- policy veto count.
|
| 53 |
+
|
| 54 |
+
The official score should not be the minimum number of steps. Real drug discovery
|
| 55 |
+
does not reward the fastest project if it skips necessary evidence. Instead,
|
| 56 |
+
MolForge rewards finishing within the available budget and decision horizon.
|
| 57 |
+
`final_score` is the single scalar to optimize and headline. It equals
|
| 58 |
+
`submission_score` for submitted episodes and gives only small capped partial
|
| 59 |
+
credit to non-submitted episodes. `progress_score` is useful for debugging but
|
| 60 |
+
is not a substitute for `final_score` or `submission_score`: it is capped when
|
| 61 |
+
constraints fail, when the hard trap scenario is not restarted, or when the
|
| 62 |
+
model loops through repeated assays and vetoes.
|
| 63 |
+
|
| 64 |
+
## Budget And Step Interpretation
|
| 65 |
+
|
| 66 |
+
MolForge has both:
|
| 67 |
+
|
| 68 |
+
- `max_steps`: the project decision deadline;
|
| 69 |
+
- `remaining_budget`: the assay/resource budget.
|
| 70 |
+
|
| 71 |
+
The agent must finish inside both limits.
|
| 72 |
+
|
| 73 |
+
Budget effects:
|
| 74 |
+
|
| 75 |
+
- assays subtract from `remaining_budget`;
|
| 76 |
+
- over-budget assays are invalid;
|
| 77 |
+
- budget exhaustion terminates the episode;
|
| 78 |
+
- valid submissions receive a transition-level `budget_efficiency` reward;
|
| 79 |
+
- formal `submission_score` receives a small bonus for unused budget only when
|
| 80 |
+
the submission has required evidence, passes constraints, and beats baseline;
|
| 81 |
+
- curriculum near-miss reward includes `budget_score`, but missed nomination is
|
| 82 |
+
penalized if the evidence package was ready and the model failed to submit.
|
| 83 |
+
|
| 84 |
+
Step effects:
|
| 85 |
+
|
| 86 |
+
- reaching `max_steps` without submission ends the episode;
|
| 87 |
+
- there is a step-limit penalty;
|
| 88 |
+
- no extra score is given merely for fewer steps;
|
| 89 |
+
- faster is better only if the candidate is supported by evidence and budget is
|
| 90 |
+
preserved.
|
| 91 |
+
|
| 92 |
+
## Recommended Comparison Table
|
| 93 |
+
|
| 94 |
+
For the README/demo, compare:
|
| 95 |
+
|
| 96 |
+
| Model | Reward mode | Submit rate | Avg final_score | Avg submission_score | Avg evidence_score | Avg budget_score | Veto rate |
|
| 97 |
+
| --- | --- | ---: | ---: | ---: | ---: | ---: | ---: |
|
| 98 |
+
| Base model | assay_gated | low | low | low | low/medium | variable | high |
|
| 99 |
+
| SFT v4 | assay_gated | better | better | better | better | variable | lower |
|
| 100 |
+
| SFT v4 + RL | assay_gated | best | best | best | high | healthy | low |
|
| 101 |
+
|
| 102 |
+
For training plots, show curriculum reward increasing, but always pair it with
|
| 103 |
+
strict `submission_score` before/after so the improvement is credible.
|
HF_RL_JOBS_NOTES.md
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hugging Face RL Jobs Notes
|
| 2 |
+
|
| 3 |
+
This file tracks the remote RL training attempts for the MolForge OpenEnv GRPO run.
|
| 4 |
+
|
| 5 |
+
## Jobs Tried
|
| 6 |
+
|
| 7 |
+
| Job | Hardware | Result | Notes |
|
| 8 |
+
| --- | --- | --- | --- |
|
| 9 |
+
| `69ed7260d70108f37acdf4b8` | `a100-large` | Canceled | Stayed in `SCHEDULING`, so we canceled it before it used GPU time. |
|
| 10 |
+
| `69ed73d3d70108f37acdf4e1` | `l40sx1` | Failed | Started but exited during Python import before model load or training. |
|
| 11 |
+
| `69ed74f6d70108f37acdf504` | `l40sx1` | **Failed** | `--with mergekit` caused unsolvable pydantic conflict with `openenv-core`. |
|
| 12 |
+
| `69ed7be5d2c8bd8662bcef00` | `l40sx1` | Canceled | Incorrect CLI usage (missing image name). |
|
| 13 |
+
| `69ed9440d70108f37acdf83b` | `l40sx1` | Failed | `uv run` couldn't find the script path `issue/script.py`. |
|
| 14 |
+
| `69ed94add2c8bd8662bcf215` | `l40sx1` | Submitted | Fixed script path to just filename and used explicit `python` call. |
|
| 15 |
+
|
| 16 |
+
## Failure History
|
| 17 |
+
|
| 18 |
+
### Job 2 (`69ed73d3`) — `ModuleNotFoundError: No module named 'mergekit'`
|
| 19 |
+
|
| 20 |
+
TRL internally imports `mergekit` for GRPO model-merging callbacks even though we don't use merging. The fix was to add `--with mergekit`.
|
| 21 |
+
|
| 22 |
+
### Job 3 (`69ed74f6`) — **pydantic version conflict** (CURRENT)
|
| 23 |
+
|
| 24 |
+
Adding `--with mergekit` broke the resolver:
|
| 25 |
+
|
| 26 |
+
- `mergekit` (all versions) requires `pydantic < 2.11`
|
| 27 |
+
- `openenv-core==0.2.3` → `fastmcp>=3.0.0` → `pydantic >= 2.11.7`
|
| 28 |
+
|
| 29 |
+
**No version of pydantic satisfies both.** uv correctly refuses to resolve.
|
| 30 |
+
|
| 31 |
+
## Fix
|
| 32 |
+
|
| 33 |
+
**Do NOT pass `--with mergekit`** in the HF Jobs command. Instead, the script now installs mergekit at runtime with `--no-deps` before importing TRL:
|
| 34 |
+
|
| 35 |
+
```python
|
| 36 |
+
try:
|
| 37 |
+
import mergekit
|
| 38 |
+
except ImportError:
|
| 39 |
+
subprocess.check_call([sys.executable, "-m", "pip", "install", "mergekit", "--no-deps", "-q"])
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
This makes `mergekit` importable (satisfying TRL) without pulling in its conflicting pydantic constraint.
|
| 43 |
+
|
| 44 |
+
## Checkpoint and Artifact Persistence
|
| 45 |
+
|
| 46 |
+
The OpenEnv GRPO script saves the final trained adapter and tokenizer to:
|
| 47 |
+
|
| 48 |
+
```text
|
| 49 |
+
<run_dir>/adapters/
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
It also writes logs, metrics, plots, before/after evaluator JSON, and a zip archive under the run directory. When `HF_OUTPUT_REPO=Adhitya122/molforge-rl-runs` is set, the full run folder is uploaded to:
|
| 53 |
+
|
| 54 |
+
```text
|
| 55 |
+
hf://datasets/Adhitya122/molforge-rl-runs/<run_name>
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
## Safer Next Runs
|
| 59 |
+
|
| 60 |
+
Recommended next HF Jobs command (NO `--with mergekit`):
|
| 61 |
+
|
| 62 |
+
```bash
|
| 63 |
+
--env RL_MAX_STEPS=20
|
| 64 |
+
--env RL_DATASET_SIZE=30
|
| 65 |
+
--env MAX_COMPLETION_LENGTH=1024
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
Use this as a smoke run first. Once it reaches at least one trainer log line and uploads artifacts, scale up to:
|
| 69 |
+
|
| 70 |
+
```bash
|
| 71 |
+
--env RL_MAX_STEPS=80
|
| 72 |
+
--env RL_DATASET_SIZE=120
|
| 73 |
+
--env MAX_COMPLETION_LENGTH=2048
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
Good hardware choices:
|
| 77 |
+
|
| 78 |
+
| Hardware | Use |
|
| 79 |
+
| --- | --- |
|
| 80 |
+
| `l40sx1` | Best next smoke test: 48 GB VRAM, cheaper than A100. |
|
| 81 |
+
| `a100-large` | Good full run if scheduling is available. |
|
| 82 |
+
| `h200` | Highest headroom, more expensive, useful if A100 scheduling stalls. |
|
| 83 |
+
| `a10g-large` | Cheap fallback, but may need shorter completion length and fewer steps. |
|
| 84 |
+
|
| 85 |
+
## Monitoring Commands
|
| 86 |
+
|
| 87 |
+
```bash
|
| 88 |
+
hf jobs inspect <job_id>
|
| 89 |
+
hf jobs logs <job_id> --tail 200
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
Use logs without `inspect` when searching for the real traceback, because `inspect` prints the full base64-encoded submitted script and makes the useful error harder to see.
|
README.md
CHANGED
|
@@ -1,10 +1,163 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
-
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: MolForge
|
| 3 |
+
emoji: 🧪
|
| 4 |
+
colorFrom: green
|
| 5 |
+
colorTo: indigo
|
| 6 |
sdk: docker
|
| 7 |
+
app_port: 8000
|
| 8 |
---
|
| 9 |
|
| 10 |
+
# MolForge
|
| 11 |
+
|
| 12 |
+
This repository implements an OpenEnv-compatible reinforcement learning environment for **medicinal chemistry lead optimization**. The agent does not directly see the true biological properties of the candidate molecule. Instead, a specialist team iteratively edits a KRAS G12C candidate under a limited assay budget, partial observability, and strict safety constraints; it receives noisy simulated assay readouts and is rewarded for discovering a highly potent, synthesizable, and safe drug candidate.
|
| 13 |
+
|
| 14 |
+
The environment is designed as a **partially observable Markov decision process (POMDP)** with:
|
| 15 |
+
- hidden ground-truth molecular properties and scenario constraints
|
| 16 |
+
- hidden target mutation traps (e.g. KRAS resistance panel shifts)
|
| 17 |
+
- visible task metadata, team communication, assay results, and remaining budget
|
| 18 |
+
- simulated `RDKit` descriptors and `TDC` (Therapeutics Data Commons) predictions (QED, SA_Score, LogP, TPSA)
|
| 19 |
+
- dense step-wise reward (in curriculum mode) plus terminal reward for submission quality
|
| 20 |
+
|
| 21 |
+
At a high level, each episode looks like this:
|
| 22 |
+
1. `reset()` picks a biological scenario (e.g. `level_1_medium`) and seeds the simulator.
|
| 23 |
+
2. The agent receives a `MolForgeObservation` describing the task, the starting molecule scaffold, and the current visible state.
|
| 24 |
+
3. The agent (acting as different roles) submits a `MolForgeAction` such as `edit`, `run_assay`, `propose_nomination`, or `submit`.
|
| 25 |
+
4. The **Governance rule engine** checks whether the action is valid, requiring multi-agent consensus for final decisions.
|
| 26 |
+
5. The transition engine updates the molecule, spends the assay budget, and returns oracle readings.
|
| 27 |
+
6. The reward computer scores the step based on whether the action was invalid, vetoed, or successful.
|
| 28 |
+
7. The environment returns a new observation with updated history, assay readings, and reward.
|
| 29 |
+
8. The episode ends when the agent successfully submits the molecule, exhausts its budget, or reaches the maximum step horizon.
|
| 30 |
+
|
| 31 |
+
---
|
| 32 |
+
|
| 33 |
+
## Hidden state vs Visible state
|
| 34 |
+
|
| 35 |
+
### Hidden state
|
| 36 |
+
The simulator keeps ground-truth properties that the agent never directly sees. It contains:
|
| 37 |
+
- The true underlying scoring functions for `potency`, `safety`, and `synthesizability`.
|
| 38 |
+
- Sunk-cost traps and late-stage target mutations (e.g., in `level_2_hard`).
|
| 39 |
+
- The strict constraints required for a valid submission.
|
| 40 |
+
- The remaining hidden milestones for the scenario.
|
| 41 |
+
|
| 42 |
+
### Visible state
|
| 43 |
+
The agent only sees `MolForgeObservation`, which includes:
|
| 44 |
+
- The current `TaskSpec` and `scenario_id`.
|
| 45 |
+
- Pipeline history and previous actions.
|
| 46 |
+
- The current molecular scaffold (in SMILES format).
|
| 47 |
+
- The `budget_used` and `remaining_budget`.
|
| 48 |
+
- Responses from the `run_assay` oracle (TDC predictors and RDKit descriptors).
|
| 49 |
+
- The `GovernanceStatus` showing which specialist agents have approved or objected.
|
| 50 |
+
- The `step_reward_breakdown`.
|
| 51 |
+
|
| 52 |
+
This separation is what makes the environment a POMDP rather than a fully observed simulator.
|
| 53 |
+
|
| 54 |
+
---
|
| 55 |
+
|
| 56 |
+
## Repository files navigation
|
| 57 |
+
|
| 58 |
+
### `models.py`
|
| 59 |
+
Defines the Pydantic contracts that all other modules use:
|
| 60 |
+
- `MolForgeAction`: One structured step chosen by the agent. Fields include `action_type`, `acting_role`, `tool_name`, `slot`, `fragment`, and `rationale`.
|
| 61 |
+
- `MolForgeObservation`: What the agent can see after each step; includes `current_molecule`, `last_transition_summary`, `reward_breakdown`, and `governance_status`.
|
| 62 |
+
- `MolForgeState`: The internal tracked state including `episode_id`, `step_count`, and `invalid_action_count`.
|
| 63 |
+
|
| 64 |
+
### `server/scenarios.py`
|
| 65 |
+
This is where episodes come from. It defines a curated library of three biological scenarios, each bundling a starting scaffold, a budget, and a specific molecular target:
|
| 66 |
+
- `level_0_easy`: Potency-first optimization with a generous budget and a starting scaffold that is one or two edits from success.
|
| 67 |
+
- `level_1_medium`: Multi-objective optimization with safety as a hard constraint and moderate budget pressure.
|
| 68 |
+
- `level_2_hard`: A sunk-cost trap plus late target mutation. The initial scaffold family has a hidden liability, and the best policy is often to restart early.
|
| 69 |
+
|
| 70 |
+
### `server/actions.py` & `server/governance.py`
|
| 71 |
+
The rule engines enforcing scientific and procedural constraints before each action is applied:
|
| 72 |
+
- `run_assay`: Costs budget. Assembles the fragments into a valid `SMILES` string and evaluates the current molecule using `TDC` Oracles and `RDKit` logic (e.g. `MolLogP`, `TPSA`, `NumRotatableBonds`, `QED`).
|
| 73 |
+
- `edit`: Replaces a specific R-group slot (`warhead`, `hinge`, `solvent_tail`, `back_pocket`) with a new chemical fragment (e.g. `acrylamide`, `fluorophenyl`, `morpholine`). Clears previously gathered evidence.
|
| 74 |
+
- `submit`: Ends the episode. Triggers the final evaluation grader against the scenario's strict hard constraints (`potency_min`, `toxicity_max`, `synth_min`).
|
| 75 |
+
- **Governance**: Certain actions require multi-agent consensus. If the `Lead Chemist` tries to submit without the `Safety Specialist`'s approval, the action is vetoed.
|
| 76 |
+
|
| 77 |
+
### `server/molforge_environment.py`
|
| 78 |
+
This is the orchestration layer that ties everything together.
|
| 79 |
+
On `reset()` it:
|
| 80 |
+
- Generates a task scenario.
|
| 81 |
+
- Clears the message log, history, and resets the molecule to the default scaffold.
|
| 82 |
+
|
| 83 |
+
On `step()` it:
|
| 84 |
+
- Checks governance rules and validates the action.
|
| 85 |
+
- Executes the action (e.g. replacing an R-group fragment or running an assay).
|
| 86 |
+
- Computes reward (via Curriculum or Assay-Gated mode).
|
| 87 |
+
- Builds the next `MolForgeObservation`.
|
| 88 |
+
|
| 89 |
+
---
|
| 90 |
+
|
| 91 |
+
## What actually happens on one step
|
| 92 |
+
Here is the concrete order of operations for `env.step(action)`:
|
| 93 |
+
1. Increment the step counter.
|
| 94 |
+
2. Run validation checks. If the action format is invalid, return a failure report and a `-1.0` reward.
|
| 95 |
+
3. Assess **Governance**. If a required specialist agent vetoes the action, the action is blocked and penalized.
|
| 96 |
+
4. Execute the action (`edit`, `run_assay`, `submit`).
|
| 97 |
+
5. Deduct oracle budget if `run_assay` was called.
|
| 98 |
+
6. Compute decomposed reward from the state transition (e.g., getting penalized for redundant assays).
|
| 99 |
+
7. If the episode is ending (via `submit`, max steps, or zero budget), compute the terminal `submission_score`.
|
| 100 |
+
8. Return an observation that exposes the visible summary but not the hidden truth.
|
| 101 |
+
|
| 102 |
+
---
|
| 103 |
+
|
| 104 |
+
## Typical successful pipeline
|
| 105 |
+
Most scenarios reward a sensible experiment order similar to:
|
| 106 |
+
1. `run_assay` (Assay potency and safety of the baseline molecule).
|
| 107 |
+
2. `edit` (Swap an R-group fragment to improve a weak property).
|
| 108 |
+
3. `run_assay` (Gather new evidence for the modified molecule).
|
| 109 |
+
4. `propose_nomination` (Discuss the findings with the multi-agent review board).
|
| 110 |
+
5. `submit` (Finalize the candidate).
|
| 111 |
+
|
| 112 |
+
The exact best sequence depends on the scenario. In `level_2_hard`, the best strategy is often to `restart` the entire scaffold immediately rather than wasting budget on a doomed trajectory.
|
| 113 |
+
|
| 114 |
+
---
|
| 115 |
+
|
| 116 |
+
## Reward Strategy & Episode termination
|
| 117 |
+
|
| 118 |
+
MolForge uses two distinct reward settings for different purposes:
|
| 119 |
+
|
| 120 |
+
**1. Training / RL Warmup (`MOLFORGE_REWARD_MODE=curriculum`)**
|
| 121 |
+
- Gives partial credit at the end of an episode even if the model didn't submit, provided it gathered useful evidence.
|
| 122 |
+
- It actively discourages "reward hacking" by penalizing assay-spamming, while reserving large reward multipliers for successful submissions.
|
| 123 |
+
|
| 124 |
+
**2. Judge-Facing Evaluation (`MOLFORGE_REWARD_MODE=assay_gated`)**
|
| 125 |
+
- Strict OpenEnv hackathon rules.
|
| 126 |
+
- If the agent does not formally `submit` the candidate, the final score is `0.0`.
|
| 127 |
+
- No partial credit is given for just gathering evidence.
|
| 128 |
+
|
| 129 |
+
An episode ends when one of the following happens:
|
| 130 |
+
- The agent explicitly chooses `submit`.
|
| 131 |
+
- Resources (oracle budget) are exhausted.
|
| 132 |
+
- The environment reaches `MAX_STEPS`.
|
| 133 |
+
|
| 134 |
+
---
|
| 135 |
+
|
| 136 |
+
## Installation & Usage
|
| 137 |
+
The package requires Python ≥ 3.10.
|
| 138 |
+
```bash
|
| 139 |
+
pip install "openenv-core[core]>=0.2.3" pydantic transformers trl peft datasets
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
### 1. In-process environment
|
| 143 |
+
Use `MolForgeEnvironment` when you want direct Python access with full structured observations:
|
| 144 |
+
```python
|
| 145 |
+
from models import MolForgeAction
|
| 146 |
+
from server.molforge_environment import MolForgeEnvironment
|
| 147 |
+
|
| 148 |
+
env = MolForgeEnvironment()
|
| 149 |
+
obs = env.reset()
|
| 150 |
+
|
| 151 |
+
action = MolForgeAction(
|
| 152 |
+
action_type="run_assay",
|
| 153 |
+
acting_role="Lead Chemist",
|
| 154 |
+
tool_name="potency_oracle",
|
| 155 |
+
rationale="Need to gather baseline potency evidence."
|
| 156 |
+
)
|
| 157 |
+
obs = env.step(action)
|
| 158 |
+
print(obs.reward)
|
| 159 |
+
print(obs.last_transition_summary)
|
| 160 |
+
```
|
| 161 |
+
|
| 162 |
+
### 2. RL Training Notebook
|
| 163 |
+
We have provided a cleanly documented `issue/molforge_grpo_official_submission.ipynb` which demonstrates exactly how to fine-tune a Qwen3.5 model using TRL's GRPO trainer natively against this OpenEnv environment.
|
REAL_WORLD_WORKFLOW_MAPPING.md
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MolForge Real-World Workflow Mapping
|
| 2 |
+
|
| 3 |
+
MolForge should feel like a compressed medicinal-chemistry lead-optimization
|
| 4 |
+
program, not a one-shot molecule generator.
|
| 5 |
+
|
| 6 |
+
The real-world pattern is:
|
| 7 |
+
|
| 8 |
+
1. A team starts with a scaffold.
|
| 9 |
+
2. Chemists propose edits based on structure-activity reasoning.
|
| 10 |
+
3. Assay teams spend limited budget to measure uncertain properties.
|
| 11 |
+
4. Safety and process specialists veto risky or impractical candidates.
|
| 12 |
+
5. The team decides whether to keep optimizing, restart, or nominate a lead.
|
| 13 |
+
6. Success depends on evidence, not only on the final molecule.
|
| 14 |
+
|
| 15 |
+
This is exactly the shape MolForge should copy.
|
| 16 |
+
|
| 17 |
+
## Real-World Loop
|
| 18 |
+
|
| 19 |
+
### 1. Design Hypothesis
|
| 20 |
+
|
| 21 |
+
Real teams do not mutate molecules randomly. A medicinal chemist proposes a
|
| 22 |
+
change with an intended purpose:
|
| 23 |
+
|
| 24 |
+
- improve potency;
|
| 25 |
+
- reduce toxicity;
|
| 26 |
+
- improve solubility or ADME;
|
| 27 |
+
- simplify synthesis;
|
| 28 |
+
- escape a known scaffold liability.
|
| 29 |
+
|
| 30 |
+
MolForge equivalent:
|
| 31 |
+
|
| 32 |
+
- `edit`
|
| 33 |
+
- `rationale`
|
| 34 |
+
- `expected_effects`
|
| 35 |
+
- `evidence`
|
| 36 |
+
|
| 37 |
+
The model should not only choose a fragment. It should say what scientific
|
| 38 |
+
pressure that edit is meant to address.
|
| 39 |
+
|
| 40 |
+
### 2. Cheap Triage Before Expensive Assays
|
| 41 |
+
|
| 42 |
+
Real projects usually run cheap computational or low-cost screens before
|
| 43 |
+
expensive experiments.
|
| 44 |
+
|
| 45 |
+
MolForge equivalent:
|
| 46 |
+
|
| 47 |
+
- `evaluate_properties`
|
| 48 |
+
- `search_literature`
|
| 49 |
+
- `estimate_synthesizability`
|
| 50 |
+
- `dock_target`
|
| 51 |
+
|
| 52 |
+
These should be useful but imperfect. They help the model decide where to spend
|
| 53 |
+
more serious assay budget.
|
| 54 |
+
|
| 55 |
+
### 3. Expensive Evidence Gates
|
| 56 |
+
|
| 57 |
+
Real lead candidates require stronger evidence before nomination:
|
| 58 |
+
|
| 59 |
+
- potency evidence;
|
| 60 |
+
- toxicity/safety evidence;
|
| 61 |
+
- synthesis or route feasibility evidence;
|
| 62 |
+
- sometimes post-mutation or resistance-panel evidence.
|
| 63 |
+
|
| 64 |
+
MolForge equivalent:
|
| 65 |
+
|
| 66 |
+
- `assay_toxicity`
|
| 67 |
+
- `dock_target`
|
| 68 |
+
- `estimate_synthesizability`
|
| 69 |
+
- hard evidence requirements in `submit`
|
| 70 |
+
- `evidence_score`
|
| 71 |
+
|
| 72 |
+
This is why `submission_score` should remain strict. A molecule that looks good
|
| 73 |
+
but was never properly assayed is not a real lead candidate.
|
| 74 |
+
|
| 75 |
+
### 4. Cross-Functional Decision Board
|
| 76 |
+
|
| 77 |
+
Real projects are not controlled by one chemist. A lead-optimization meeting
|
| 78 |
+
usually includes:
|
| 79 |
+
|
| 80 |
+
- medicinal chemistry;
|
| 81 |
+
- assay biology;
|
| 82 |
+
- toxicology/safety;
|
| 83 |
+
- process chemistry or manufacturability;
|
| 84 |
+
- project leadership.
|
| 85 |
+
|
| 86 |
+
MolForge equivalent:
|
| 87 |
+
|
| 88 |
+
- `lead_chemist`
|
| 89 |
+
- `assay_planner`
|
| 90 |
+
- `toxicologist`
|
| 91 |
+
- `process_chemist`
|
| 92 |
+
- governance messages;
|
| 93 |
+
- hard vetoes;
|
| 94 |
+
- `coordination_score`
|
| 95 |
+
|
| 96 |
+
This is one of MolForge's strongest environment-innovation points. The agent is
|
| 97 |
+
not just optimizing a molecule; it is coordinating a scientific team.
|
| 98 |
+
|
| 99 |
+
### 5. Stop, Submit, or Restart
|
| 100 |
+
|
| 101 |
+
Real teams must decide when to stop spending money. Sometimes the right answer
|
| 102 |
+
is to abandon a scaffold early because the series is a trap.
|
| 103 |
+
|
| 104 |
+
MolForge equivalent:
|
| 105 |
+
|
| 106 |
+
- `submit`
|
| 107 |
+
- `restart`
|
| 108 |
+
- budget limits;
|
| 109 |
+
- max decision horizon;
|
| 110 |
+
- hard scenario target shift;
|
| 111 |
+
- sunk-cost trap in `level_2_hard`
|
| 112 |
+
|
| 113 |
+
This lets the environment test project judgment, not just local molecule edits.
|
| 114 |
+
|
| 115 |
+
## How To Use This In MolForge
|
| 116 |
+
|
| 117 |
+
### Keep Two Scores
|
| 118 |
+
|
| 119 |
+
Use two kinds of reward:
|
| 120 |
+
|
| 121 |
+
1. **Training reward**
|
| 122 |
+
Helps the model learn the workflow.
|
| 123 |
+
|
| 124 |
+
2. **Formal submission score**
|
| 125 |
+
Measures whether the agent actually nominated a valid candidate.
|
| 126 |
+
|
| 127 |
+
That means:
|
| 128 |
+
|
| 129 |
+
- `MOLFORGE_REWARD_MODE=curriculum` for early RL;
|
| 130 |
+
- default `assay_gated` mode for final reporting;
|
| 131 |
+
- `submission_score` stays `0.0` without a formal submit.
|
| 132 |
+
|
| 133 |
+
This mirrors the real world: a project can make progress without nominating a
|
| 134 |
+
lead, but it cannot claim lead success without a nomination package.
|
| 135 |
+
|
| 136 |
+
### Make Rewards Stage-Gated
|
| 137 |
+
|
| 138 |
+
A good real-world reward should not be one giant final number only.
|
| 139 |
+
|
| 140 |
+
Useful reward components:
|
| 141 |
+
|
| 142 |
+
- valid action/schema;
|
| 143 |
+
- useful design edit;
|
| 144 |
+
- useful first assay;
|
| 145 |
+
- evidence coverage;
|
| 146 |
+
- safety improvement;
|
| 147 |
+
- synthesis improvement;
|
| 148 |
+
- avoiding repeated assays;
|
| 149 |
+
- avoiding vetoed decisions;
|
| 150 |
+
- submitting only with enough support;
|
| 151 |
+
- restarting from a bad scaffold when appropriate.
|
| 152 |
+
|
| 153 |
+
This gives RL a learnable path while preserving strict final success.
|
| 154 |
+
|
| 155 |
+
### Make The Demo Story Simple
|
| 156 |
+
|
| 157 |
+
Judges should understand this in one sentence:
|
| 158 |
+
|
| 159 |
+
> MolForge tests whether an LLM can run a miniature drug-discovery project:
|
| 160 |
+
> design molecules, buy assays, respect safety vetoes, manage budget, and
|
| 161 |
+
> nominate a candidate only when the evidence package is strong enough.
|
| 162 |
+
|
| 163 |
+
Then show:
|
| 164 |
+
|
| 165 |
+
- baseline model repeats invalid or vetoed actions;
|
| 166 |
+
- SFT model learns the action language;
|
| 167 |
+
- RL model learns better evidence and submit timing;
|
| 168 |
+
- final candidate report card shows potency, toxicity, synthesis, evidence,
|
| 169 |
+
budget, and coordination.
|
| 170 |
+
|
| 171 |
+
## What We Already Have
|
| 172 |
+
|
| 173 |
+
MolForge already contains most of this real-world structure:
|
| 174 |
+
|
| 175 |
+
- molecule slot edits;
|
| 176 |
+
- RDKit/TDC-backed surrogate oracle path;
|
| 177 |
+
- limited assay budget;
|
| 178 |
+
- cheap and expensive tools;
|
| 179 |
+
- hidden true properties;
|
| 180 |
+
- visible assay estimates;
|
| 181 |
+
- toxicity and synthesis constraints;
|
| 182 |
+
- multi-agent specialist governance;
|
| 183 |
+
- safety vetoes;
|
| 184 |
+
- restart action;
|
| 185 |
+
- hard target-shift scenario;
|
| 186 |
+
- decomposed report card;
|
| 187 |
+
- strict terminal `submission_score`;
|
| 188 |
+
- curriculum reward mode for early RL.
|
| 189 |
+
|
| 190 |
+
## What To Strengthen Next
|
| 191 |
+
|
| 192 |
+
The next useful additions should make the environment feel even more like a
|
| 193 |
+
real project:
|
| 194 |
+
|
| 195 |
+
1. **Assay uncertainty**
|
| 196 |
+
Repeated assays should narrow confidence intervals, but cost budget.
|
| 197 |
+
|
| 198 |
+
2. **Stage labels**
|
| 199 |
+
Mark states as `design`, `triage`, `evidence_package`, `nomination`, or
|
| 200 |
+
`no-go`.
|
| 201 |
+
|
| 202 |
+
3. **No-go decisions**
|
| 203 |
+
Reward a model for stopping or restarting when the evidence says the series
|
| 204 |
+
is unsafe or infeasible.
|
| 205 |
+
|
| 206 |
+
4. **Portfolio-style report**
|
| 207 |
+
At terminal time, show why the candidate was nominated or rejected.
|
| 208 |
+
|
| 209 |
+
5. **Holdout variants**
|
| 210 |
+
Randomize scaffold starts and budgets so the model cannot memorize only
|
| 211 |
+
three paths.
|
| 212 |
+
|
| 213 |
+
For the hackathon, the best near-term path is:
|
| 214 |
+
|
| 215 |
+
```text
|
| 216 |
+
SFT v4 for action/workflow competence
|
| 217 |
+
-> curriculum RL for observable reward improvement
|
| 218 |
+
-> strict assay_gated evaluation for final submission_score
|
| 219 |
+
-> README/demo framed as a real drug-discovery decision board
|
| 220 |
+
```
|
| 221 |
+
|
RL_TRAINING_COLAB.md
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MolForge RL Training in Colab
|
| 2 |
+
|
| 3 |
+
Use [issue/molforge_grpo_colab_training.ipynb](issue/molforge_grpo_colab_training.ipynb) for the judge-rerunnable workflow.
|
| 4 |
+
|
| 5 |
+
The notebook trains from the Qwen3.5 2B SFT v4 adapter with TRL GRPO against the real MolForge environment reward. It uses the TRL/OpenEnv `environment_factory` pattern from the Wordle/Sudoku examples: MolForge exposes tool methods for `edit`, `run_assay`, `submit`, `restart`, and `defer`, and reward functions read scores from the environment instances. It is set up for short evidence runs on A100/H100 rather than full convergence.
|
| 6 |
+
|
| 7 |
+
## Outputs
|
| 8 |
+
|
| 9 |
+
Each run writes to `/content/molforge_rl_runs/<run_name>/` and copies the same folder to `DRIVE_OUTPUT_DIR` when set.
|
| 10 |
+
|
| 11 |
+
Important artifacts:
|
| 12 |
+
|
| 13 |
+
- `logs/openenv_tool_rollouts.jsonl`: every tool call, reward, governance status, and score diagnostics.
|
| 14 |
+
- `logs/trainer_log_history.jsonl`: trainer loss, grad norm, learning rate, and step timing.
|
| 15 |
+
- `openenv_tool_metrics.csv`: spreadsheet-friendly tool rollout reward table.
|
| 16 |
+
- `eval_before_training.json`: full 3-task rollout before GRPO.
|
| 17 |
+
- `eval_after_training.json`: full 3-task rollout after GRPO.
|
| 18 |
+
- `plots/reward_curve.png`: completion reward curve and moving average.
|
| 19 |
+
- `plots/loss_curve.png`: trainer loss curve.
|
| 20 |
+
- `plots/eval_before_after.png`: before/after final_score comparison.
|
| 21 |
+
- `plots/action_distribution.png`: sampled action mix.
|
| 22 |
+
- `adapters/`: trained LoRA adapter checkpoint.
|
| 23 |
+
- `<run_name>.zip`: portable archive of the run outputs.
|
| 24 |
+
|
| 25 |
+
## Fast Demo Settings
|
| 26 |
+
|
| 27 |
+
For a quick A100/H100 proof run:
|
| 28 |
+
|
| 29 |
+
```python
|
| 30 |
+
os.environ["RL_MAX_STEPS"] = "80"
|
| 31 |
+
os.environ["NUM_GENERATIONS"] = "2"
|
| 32 |
+
os.environ["RL_DATASET_SIZE"] = "120"
|
| 33 |
+
os.environ["RL_BATCH_SIZE"] = "2"
|
| 34 |
+
os.environ["RL_GRAD_ACCUM"] = "4"
|
| 35 |
+
os.environ["RL_LEARNING_RATE"] = "2e-6"
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
For a stronger run, try `RL_MAX_STEPS=200` and `NUM_GENERATIONS=4` on H100.
|
| 39 |
+
If Colab runs out of memory, reduce `MAX_COMPLETION_LENGTH` to `1024`; keep `RL_BATCH_SIZE` divisible by `NUM_GENERATIONS`.
|
| 40 |
+
|
| 41 |
+
If TRL fails during import with `No module named 'mergekit'`, install `mergekit` in the same setup cell as `trl`.
|
| 42 |
+
|
| 43 |
+
## What to Show Judges
|
| 44 |
+
|
| 45 |
+
Use the before/after rollout JSON plus these plots:
|
| 46 |
+
|
| 47 |
+
- `reward_curve.png` for reward improvement during RL.
|
| 48 |
+
- `loss_curve.png` for actual training evidence.
|
| 49 |
+
- `eval_before_after.png` for task-level behavior change.
|
| 50 |
+
|
| 51 |
+
The official environment score remains `final_score`; `progress_score` and per-step rewards are debugging signals.
|
Requirements_before_submitting.md
ADDED
|
@@ -0,0 +1,521 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
Evaluation Criteria
|
| 3 |
+
|
| 4 |
+
Phase 1: Automated Validation
|
| 5 |
+
|
| 6 |
+
Pass/fail gate — HF Space deploys, OpenEnv spec compliance, Dockerfile builds, baseline reproduces, 3+ tasks with graders.
|
| 7 |
+
|
| 8 |
+
Phase 2: Agentic Evaluation
|
| 9 |
+
|
| 10 |
+
Scored — baseline agent re-run, standard Open LLM agent (e.g. Nemotron 3 Super) run against all environments, score variance check.
|
| 11 |
+
|
| 12 |
+
Phase 3: Human Review
|
| 13 |
+
|
| 14 |
+
Top submissions reviewed by Meta and Hugging Face engineers for real-world utility, creativity, and exploit checks.
|
| 15 |
+
|
| 16 |
+
Disqualification Criteria
|
| 17 |
+
|
| 18 |
+
Environment does not deploy or respond
|
| 19 |
+
|
| 20 |
+
Plagiarized or trivially modified existing environments
|
| 21 |
+
|
| 22 |
+
Graders that always return the same score
|
| 23 |
+
|
| 24 |
+
No baseline inference script
|
| 25 |
+
|
| 26 |
+
How Judging works
|
| 27 |
+
|
| 28 |
+
Pre-Submission Checklist — all must pass or you're disqualified
|
| 29 |
+
|
| 30 |
+
HF Space deploys
|
| 31 |
+
|
| 32 |
+
Automated ping to the Space URL — must return 200 and respond to reset()
|
| 33 |
+
|
| 34 |
+
OpenEnv spec compliance
|
| 35 |
+
|
| 36 |
+
Validate openenv.yaml, typed models, step()/reset()/state() endpoints
|
| 37 |
+
|
| 38 |
+
Dockerfile builds
|
| 39 |
+
|
| 40 |
+
Automated docker build on the submitted repo
|
| 41 |
+
|
| 42 |
+
Baseline reproduces
|
| 43 |
+
|
| 44 |
+
Run the submitted inference script — must complete without error and produce scores
|
| 45 |
+
|
| 46 |
+
3+ tasks with graders
|
| 47 |
+
|
| 48 |
+
Enumerate tasks, run each grader, verify scores in 0.0–1.0 range
|
| 49 |
+
|
| 50 |
+
Additional Instructions
|
| 51 |
+
|
| 52 |
+
Before submitting, ensure the following variables are defined in your environment configuration:
|
| 53 |
+
|
| 54 |
+
API\_BASE\_URL The API endpoint for the LLM.
|
| 55 |
+
|
| 56 |
+
MODEL\_NAME The model identifier to use for inference.
|
| 57 |
+
|
| 58 |
+
HF\_TOKEN Your Hugging Face / API key.
|
| 59 |
+
|
| 60 |
+
The inference script must be named \`inference.py\` and placed in the root directory of the project
|
| 61 |
+
|
| 62 |
+
Participants must use OpenAI Client for all LLM calls using above variables
|
| 63 |
+
|
| 64 |
+
Infra Restrictions
|
| 65 |
+
|
| 66 |
+
Runtime of inference script should be less than 20min
|
| 67 |
+
|
| 68 |
+
Make sure your env and inference can run on a machine with vcpu=2, memory=8gb
|
| 69 |
+
|
| 70 |
+
Validator
|
| 71 |
+
|
| 72 |
+
Run the pre-submission validation script before submitting
|
| 73 |
+
|
| 74 |
+
Sample Inference Script
|
| 75 |
+
|
| 76 |
+
"""
|
| 77 |
+
Inference Script Example
|
| 78 |
+
===================================
|
| 79 |
+
MANDATORY
|
| 80 |
+
- Before submitting, ensure the following variables are defined in your environment configuration:
|
| 81 |
+
API_BASE_URL The API endpoint for the LLM.
|
| 82 |
+
MODEL_NAME The model identifier to use for inference.
|
| 83 |
+
HF_TOKEN Your Hugging Face / API key.
|
| 84 |
+
|
| 85 |
+
- The inference script must be named `inference.py` and placed in the root directory of the project
|
| 86 |
+
- Participants must use OpenAI Client for all LLM calls using above variables
|
| 87 |
+
"""
|
| 88 |
+
|
| 89 |
+
import os
|
| 90 |
+
import re
|
| 91 |
+
import base64
|
| 92 |
+
import textwrap
|
| 93 |
+
from io import BytesIO
|
| 94 |
+
from typing import List, Optional, Dict
|
| 95 |
+
|
| 96 |
+
from openai import OpenAI
|
| 97 |
+
import numpy as np
|
| 98 |
+
from PIL import Image
|
| 99 |
+
|
| 100 |
+
from browsergym_env import BrowserGymAction, BrowserGymEnv
|
| 101 |
+
|
| 102 |
+
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
|
| 103 |
+
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
|
| 104 |
+
MODEL_NAME = os.getenv("MODEL_NAME")
|
| 105 |
+
MAX_STEPS = 8
|
| 106 |
+
MAX_DOM_CHARS = 3500
|
| 107 |
+
TEMPERATURE = 0.2
|
| 108 |
+
MAX_TOKENS = 200
|
| 109 |
+
FALLBACK_ACTION = "noop()"
|
| 110 |
+
|
| 111 |
+
DEBUG = True
|
| 112 |
+
ACTION_PREFIX_RE = re.compile(
|
| 113 |
+
r"^(action|next action)\s*[:\-]\s*",
|
| 114 |
+
re.IGNORECASE,
|
| 115 |
+
)
|
| 116 |
+
ACTION_PATTERN = re.compile(r"[A-Za-z_]+\s*\(.*\)", re.DOTALL)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
SYSTEM_PROMPT = textwrap.dedent(
|
| 120 |
+
"""
|
| 121 |
+
You control a web browser through BrowserGym.
|
| 122 |
+
Reply with exactly one action string.
|
| 123 |
+
The action must be a valid BrowserGym command such as:
|
| 124 |
+
- noop()
|
| 125 |
+
- click('<BID>')
|
| 126 |
+
- type('selector', 'text to enter')
|
| 127 |
+
- fill('selector', 'text to enter')
|
| 128 |
+
- send_keys('Enter')
|
| 129 |
+
- scroll('down')
|
| 130 |
+
Use single quotes around string arguments.
|
| 131 |
+
When clicking, use the BrowserGym element IDs (BIDs) listed in the user message.
|
| 132 |
+
If you are unsure, respond with noop().
|
| 133 |
+
Do not include explanations or additional text.
|
| 134 |
+
"""
|
| 135 |
+
).strip()
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def build_history_lines(history: List[str]) -> str:
|
| 139 |
+
if not history:
|
| 140 |
+
return "None"
|
| 141 |
+
return "\n".join(history[-4:])
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def extract_screenshot_uri(observation) -> Optional[str]:
|
| 145 |
+
if observation.screenshot is None:
|
| 146 |
+
return None
|
| 147 |
+
screen_array = np.array(observation.screenshot, dtype=np.uint8)
|
| 148 |
+
image = Image.fromarray(screen_array)
|
| 149 |
+
buffer = BytesIO()
|
| 150 |
+
image.save(buffer, format="PNG")
|
| 151 |
+
buffer.seek(0)
|
| 152 |
+
data_uri = base64.b64encode(buffer.read()).decode("utf-8")
|
| 153 |
+
return f"data:image/png;base64,{data_uri}"
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def extract_clickable_elements(observation) -> List[Dict[str, str]]:
|
| 157 |
+
"""Collect BrowserGym element IDs that can be clicked."""
|
| 158 |
+
|
| 159 |
+
metadata = getattr(observation, "metadata", {}) or {}
|
| 160 |
+
obs_dict = metadata.get("browsergym_obs", {}) or {}
|
| 161 |
+
extra_props = obs_dict.get("extra_element_properties", {}) or {}
|
| 162 |
+
|
| 163 |
+
clickables: List[Dict[str, str]] = []
|
| 164 |
+
for bid, props in extra_props.items():
|
| 165 |
+
if not props.get("clickable"):
|
| 166 |
+
continue
|
| 167 |
+
|
| 168 |
+
bbox = props.get("bbox") or []
|
| 169 |
+
bbox_str = ", ".join(bbox) if bbox else "?"
|
| 170 |
+
clickables.append(
|
| 171 |
+
{
|
| 172 |
+
"bid": str(bid),
|
| 173 |
+
"bbox": bbox_str,
|
| 174 |
+
}
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
# Keep a stable ordering for readability
|
| 178 |
+
clickables.sort(key=lambda item: item["bid"])
|
| 179 |
+
return clickables
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def build_user_prompt(step: int, observation, history: List[str]) -> str:
|
| 183 |
+
goal = observation.goal or "(not provided)"
|
| 184 |
+
url = observation.url or "(unknown)"
|
| 185 |
+
error_note = "Yes" if observation.last_action_error else "No"
|
| 186 |
+
|
| 187 |
+
clickables = extract_clickable_elements(observation)
|
| 188 |
+
if clickables:
|
| 189 |
+
actions_hint = "\n".join(
|
| 190 |
+
f" - {item['bid']} (bbox: {item['bbox']})" for item in clickables
|
| 191 |
+
)
|
| 192 |
+
else:
|
| 193 |
+
actions_hint = " (none detected)"
|
| 194 |
+
|
| 195 |
+
prompt = textwrap.dedent(
|
| 196 |
+
f"""
|
| 197 |
+
Step: {step}
|
| 198 |
+
Goal: {goal}
|
| 199 |
+
Current URL: {url}
|
| 200 |
+
Previous steps:
|
| 201 |
+
{build_history_lines(history)}
|
| 202 |
+
Last action error: {error_note}
|
| 203 |
+
Available clickable element IDs: {actions_hint}
|
| 204 |
+
Reply with exactly one BrowserGym action string.
|
| 205 |
+
"""
|
| 206 |
+
).strip()
|
| 207 |
+
return prompt
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def parse_model_action(response_text: str) -> str:
|
| 211 |
+
if not response_text:
|
| 212 |
+
return FALLBACK_ACTION
|
| 213 |
+
|
| 214 |
+
# Prefer the first line that looks like an action string
|
| 215 |
+
lines = response_text.splitlines()
|
| 216 |
+
for raw_line in lines:
|
| 217 |
+
line = raw_line.strip()
|
| 218 |
+
if not line:
|
| 219 |
+
continue
|
| 220 |
+
line = ACTION_PREFIX_RE.sub("", line)
|
| 221 |
+
match = ACTION_PATTERN.search(line)
|
| 222 |
+
if match:
|
| 223 |
+
action = match.group(0).strip()
|
| 224 |
+
# Collapse internal whitespace
|
| 225 |
+
action = re.sub(r"\s+", " ", action)
|
| 226 |
+
# NOTE: the model may try to click by natural-language description even though
|
| 227 |
+
# we only expose numeric BrowserGym IDs; no fallback is applied here.
|
| 228 |
+
return action
|
| 229 |
+
|
| 230 |
+
# Fall back to searching the whole response
|
| 231 |
+
match = ACTION_PATTERN.search(response_text)
|
| 232 |
+
if match:
|
| 233 |
+
action = match.group(0).strip()
|
| 234 |
+
action = re.sub(r"\s+", " ", action)
|
| 235 |
+
return action
|
| 236 |
+
|
| 237 |
+
return FALLBACK_ACTION
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def main() -> None:
|
| 241 |
+
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
|
| 242 |
+
|
| 243 |
+
env = BrowserGymEnv.from_docker_image(
|
| 244 |
+
image="browsergym-env:latest",
|
| 245 |
+
env_vars={
|
| 246 |
+
"BROWSERGYM_BENCHMARK": "miniwob",
|
| 247 |
+
"BROWSERGYM_TASK_NAME": "click-test",
|
| 248 |
+
},
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
history: List[str] = []
|
| 252 |
+
|
| 253 |
+
try:
|
| 254 |
+
result = env.reset()
|
| 255 |
+
observation = result.observation
|
| 256 |
+
print(f"Episode goal: {observation.goal}")
|
| 257 |
+
|
| 258 |
+
for step in range(1, MAX_STEPS + 1):
|
| 259 |
+
if result.done:
|
| 260 |
+
print("Environment signalled done. Stopping early.")
|
| 261 |
+
break
|
| 262 |
+
|
| 263 |
+
user_prompt = build_user_prompt(step, observation, history)
|
| 264 |
+
user_content = [{"type": "text", "text": user_prompt}]
|
| 265 |
+
screenshot_uri = extract_screenshot_uri(observation)
|
| 266 |
+
if screenshot_uri:
|
| 267 |
+
user_content.append(
|
| 268 |
+
{
|
| 269 |
+
"type": "image_url",
|
| 270 |
+
"image_url": {"url": screenshot_uri},
|
| 271 |
+
}
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
messages = [
|
| 275 |
+
{
|
| 276 |
+
"role": "system",
|
| 277 |
+
"content": [{"type": "text", "text": SYSTEM_PROMPT}],
|
| 278 |
+
},
|
| 279 |
+
{
|
| 280 |
+
"role": "user",
|
| 281 |
+
"content": user_content,
|
| 282 |
+
},
|
| 283 |
+
]
|
| 284 |
+
|
| 285 |
+
try:
|
| 286 |
+
completion = client.chat.completions.create(
|
| 287 |
+
model=MODEL_NAME,
|
| 288 |
+
messages=messages,
|
| 289 |
+
temperature=TEMPERATURE,
|
| 290 |
+
max_tokens=MAX_TOKENS,
|
| 291 |
+
stream=False,
|
| 292 |
+
)
|
| 293 |
+
response_text = completion.choices[0].message.content or ""
|
| 294 |
+
# pylint: disable=broad-except
|
| 295 |
+
except Exception as exc: # noqa: BLE001
|
| 296 |
+
failure_msg = f"Model request failed ({exc}). Using fallback action."
|
| 297 |
+
print(failure_msg)
|
| 298 |
+
response_text = FALLBACK_ACTION
|
| 299 |
+
|
| 300 |
+
action_str = parse_model_action(response_text)
|
| 301 |
+
print(f"Step {step}: model suggested -> {action_str}")
|
| 302 |
+
|
| 303 |
+
result = env.step(BrowserGymAction(action_str=action_str))
|
| 304 |
+
observation = result.observation
|
| 305 |
+
|
| 306 |
+
reward = result.reward or 0.0
|
| 307 |
+
error_flag = " ERROR" if observation.last_action_error else ""
|
| 308 |
+
history_line = (
|
| 309 |
+
f"Step {step}: {action_str} -> reward {reward:+.2f}{error_flag}"
|
| 310 |
+
)
|
| 311 |
+
history.append(history_line)
|
| 312 |
+
print(
|
| 313 |
+
" Reward: "
|
| 314 |
+
f"{reward:+.2f} | Done: {result.done} | Last action error: "
|
| 315 |
+
f"{observation.last_action_error}"
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
if result.done:
|
| 319 |
+
print("Episode complete.")
|
| 320 |
+
break
|
| 321 |
+
|
| 322 |
+
else:
|
| 323 |
+
print(f"Reached max steps ({MAX_STEPS}).")
|
| 324 |
+
|
| 325 |
+
finally:
|
| 326 |
+
env.close()
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
if __name__ == "__main__":
|
| 330 |
+
main()
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
Pre Validation Script
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
#!/usr/bin/env bash
|
| 338 |
+
#
|
| 339 |
+
# validate-submission.sh — OpenEnv Submission Validator
|
| 340 |
+
#
|
| 341 |
+
# Checks that your HF Space is live, Docker image builds, and openenv validate passes.
|
| 342 |
+
#
|
| 343 |
+
# Prerequisites:
|
| 344 |
+
# - Docker: https://docs.docker.com/get-docker/
|
| 345 |
+
# - openenv-core: pip install openenv-core
|
| 346 |
+
# - curl (usually pre-installed)
|
| 347 |
+
#
|
| 348 |
+
# Run:
|
| 349 |
+
# curl -fsSL https://raw.githubusercontent.com/<owner>/<repo>/main/scripts/validate-submission.sh | bash -s -- <ping_url> [repo_dir]
|
| 350 |
+
#
|
| 351 |
+
# Or download and run locally:
|
| 352 |
+
# chmod +x validate-submission.sh
|
| 353 |
+
# ./validate-submission.sh <ping_url> [repo_dir]
|
| 354 |
+
#
|
| 355 |
+
# Arguments:
|
| 356 |
+
# ping_url Your HuggingFace Space URL (e.g. https://your-space.hf.space)
|
| 357 |
+
# repo_dir Path to your repo (default: current directory)
|
| 358 |
+
#
|
| 359 |
+
# Examples:
|
| 360 |
+
# ./validate-submission.sh https://my-team.hf.space
|
| 361 |
+
# ./validate-submission.sh https://my-team.hf.space ./my-repo
|
| 362 |
+
#
|
| 363 |
+
|
| 364 |
+
set -uo pipefail
|
| 365 |
+
|
| 366 |
+
DOCKER_BUILD_TIMEOUT=600
|
| 367 |
+
if [ -t 1 ]; then
|
| 368 |
+
RED='\033[0;31m'
|
| 369 |
+
GREEN='\033[0;32m'
|
| 370 |
+
YELLOW='\033[1;33m'
|
| 371 |
+
BOLD='\033[1m'
|
| 372 |
+
NC='\033[0m'
|
| 373 |
+
else
|
| 374 |
+
RED='' GREEN='' YELLOW='' BOLD='' NC=''
|
| 375 |
+
fi
|
| 376 |
+
|
| 377 |
+
run_with_timeout() {
|
| 378 |
+
local secs="$1"; shift
|
| 379 |
+
if command -v timeout &>/dev/null; then
|
| 380 |
+
timeout "$secs" "$@"
|
| 381 |
+
elif command -v gtimeout &>/dev/null; then
|
| 382 |
+
gtimeout "$secs" "$@"
|
| 383 |
+
else
|
| 384 |
+
"$@" &
|
| 385 |
+
local pid=$!
|
| 386 |
+
( sleep "$secs" && kill "$pid" 2>/dev/null ) &
|
| 387 |
+
local watcher=$!
|
| 388 |
+
wait "$pid" 2>/dev/null
|
| 389 |
+
local rc=$?
|
| 390 |
+
kill "$watcher" 2>/dev/null
|
| 391 |
+
wait "$watcher" 2>/dev/null
|
| 392 |
+
return $rc
|
| 393 |
+
fi
|
| 394 |
+
}
|
| 395 |
+
|
| 396 |
+
portable_mktemp() {
|
| 397 |
+
local prefix="${1:-validate}"
|
| 398 |
+
mktemp "${TMPDIR:-/tmp}/${prefix}-XXXXXX" 2>/dev/null || mktemp
|
| 399 |
+
}
|
| 400 |
+
|
| 401 |
+
CLEANUP_FILES=()
|
| 402 |
+
cleanup() { rm -f "${CLEANUP_FILES[@]+"${CLEANUP_FILES[@]}"}"; }
|
| 403 |
+
trap cleanup EXIT
|
| 404 |
+
|
| 405 |
+
PING_URL="${1:-}"
|
| 406 |
+
REPO_DIR="${2:-.}"
|
| 407 |
+
|
| 408 |
+
if [ -z "$PING_URL" ]; then
|
| 409 |
+
printf "Usage: %s <ping_url> [repo_dir]\n" "$0"
|
| 410 |
+
printf "\n"
|
| 411 |
+
printf " ping_url Your HuggingFace Space URL (e.g. https://your-space.hf.space)\n"
|
| 412 |
+
printf " repo_dir Path to your repo (default: current directory)\n"
|
| 413 |
+
exit 1
|
| 414 |
+
fi
|
| 415 |
+
|
| 416 |
+
if ! REPO_DIR="$(cd "$REPO_DIR" 2>/dev/null && pwd)"; then
|
| 417 |
+
printf "Error: directory '%s' not found\n" "${2:-.}"
|
| 418 |
+
exit 1
|
| 419 |
+
fi
|
| 420 |
+
PING_URL="${PING_URL%/}"
|
| 421 |
+
export PING_URL
|
| 422 |
+
PASS=0
|
| 423 |
+
|
| 424 |
+
log() { printf "[%s] %b\n" "$(date -u +%H:%M:%S)" "$*"; }
|
| 425 |
+
pass() { log "${GREEN}PASSED${NC} -- $1"; PASS=$((PASS + 1)); }
|
| 426 |
+
fail() { log "${RED}FAILED${NC} -- $1"; }
|
| 427 |
+
hint() { printf " ${YELLOW}Hint:${NC} %b\n" "$1"; }
|
| 428 |
+
stop_at() {
|
| 429 |
+
printf "\n"
|
| 430 |
+
printf "${RED}${BOLD}Validation stopped at %s.${NC} Fix the above before continuing.\n" "$1"
|
| 431 |
+
exit 1
|
| 432 |
+
}
|
| 433 |
+
|
| 434 |
+
printf "\n"
|
| 435 |
+
printf "${BOLD}========================================${NC}\n"
|
| 436 |
+
printf "${BOLD} OpenEnv Submission Validator${NC}\n"
|
| 437 |
+
printf "${BOLD}========================================${NC}\n"
|
| 438 |
+
log "Repo: $REPO_DIR"
|
| 439 |
+
log "Ping URL: $PING_URL"
|
| 440 |
+
printf "\n"
|
| 441 |
+
|
| 442 |
+
log "${BOLD}Step 1/3: Pinging HF Space${NC} ($PING_URL/reset) ..."
|
| 443 |
+
|
| 444 |
+
CURL_OUTPUT=$(portable_mktemp "validate-curl")
|
| 445 |
+
CLEANUP_FILES+=("$CURL_OUTPUT")
|
| 446 |
+
HTTP_CODE=$(curl -s -o "$CURL_OUTPUT" -w "%{http_code}" -X POST \
|
| 447 |
+
-H "Content-Type: application/json" -d '{}' \
|
| 448 |
+
"$PING_URL/reset" --max-time 30 2>"$CURL_OUTPUT" || printf "000")
|
| 449 |
+
|
| 450 |
+
if [ "$HTTP_CODE" = "200" ]; then
|
| 451 |
+
pass "HF Space is live and responds to /reset"
|
| 452 |
+
elif [ "$HTTP_CODE" = "000" ]; then
|
| 453 |
+
fail "HF Space not reachable (connection failed or timed out)"
|
| 454 |
+
hint "Check your network connection and that the Space is running."
|
| 455 |
+
hint "Try: curl -s -o /dev/null -w '%{http_code}' -X POST $PING_URL/reset"
|
| 456 |
+
stop_at "Step 1"
|
| 457 |
+
else
|
| 458 |
+
fail "HF Space /reset returned HTTP $HTTP_CODE (expected 200)"
|
| 459 |
+
hint "Make sure your Space is running and the URL is correct."
|
| 460 |
+
hint "Try opening $PING_URL in your browser first."
|
| 461 |
+
stop_at "Step 1"
|
| 462 |
+
fi
|
| 463 |
+
|
| 464 |
+
log "${BOLD}Step 2/3: Running docker build${NC} ..."
|
| 465 |
+
|
| 466 |
+
if ! command -v docker &>/dev/null; then
|
| 467 |
+
fail "docker command not found"
|
| 468 |
+
hint "Install Docker: https://docs.docker.com/get-docker/"
|
| 469 |
+
stop_at "Step 2"
|
| 470 |
+
fi
|
| 471 |
+
|
| 472 |
+
if [ -f "$REPO_DIR/Dockerfile" ]; then
|
| 473 |
+
DOCKER_CONTEXT="$REPO_DIR"
|
| 474 |
+
elif [ -f "$REPO_DIR/server/Dockerfile" ]; then
|
| 475 |
+
DOCKER_CONTEXT="$REPO_DIR/server"
|
| 476 |
+
else
|
| 477 |
+
fail "No Dockerfile found in repo root or server/ directory"
|
| 478 |
+
stop_at "Step 2"
|
| 479 |
+
fi
|
| 480 |
+
|
| 481 |
+
log " Found Dockerfile in $DOCKER_CONTEXT"
|
| 482 |
+
|
| 483 |
+
BUILD_OK=false
|
| 484 |
+
BUILD_OUTPUT=$(run_with_timeout "$DOCKER_BUILD_TIMEOUT" docker build "$DOCKER_CONTEXT" 2>&1) && BUILD_OK=true
|
| 485 |
+
|
| 486 |
+
if [ "$BUILD_OK" = true ]; then
|
| 487 |
+
pass "Docker build succeeded"
|
| 488 |
+
else
|
| 489 |
+
fail "Docker build failed (timeout=${DOCKER_BUILD_TIMEOUT}s)"
|
| 490 |
+
printf "%s\n" "$BUILD_OUTPUT" | tail -20
|
| 491 |
+
stop_at "Step 2"
|
| 492 |
+
fi
|
| 493 |
+
|
| 494 |
+
log "${BOLD}Step 3/3: Running openenv validate${NC} ..."
|
| 495 |
+
|
| 496 |
+
if ! command -v openenv &>/dev/null; then
|
| 497 |
+
fail "openenv command not found"
|
| 498 |
+
hint "Install it: pip install openenv-core"
|
| 499 |
+
stop_at "Step 3"
|
| 500 |
+
fi
|
| 501 |
+
|
| 502 |
+
VALIDATE_OK=false
|
| 503 |
+
VALIDATE_OUTPUT=$(cd "$REPO_DIR" && openenv validate 2>&1) && VALIDATE_OK=true
|
| 504 |
+
|
| 505 |
+
if [ "$VALIDATE_OK" = true ]; then
|
| 506 |
+
pass "openenv validate passed"
|
| 507 |
+
[ -n "$VALIDATE_OUTPUT" ] && log " $VALIDATE_OUTPUT"
|
| 508 |
+
else
|
| 509 |
+
fail "openenv validate failed"
|
| 510 |
+
printf "%s\n" "$VALIDATE_OUTPUT"
|
| 511 |
+
stop_at "Step 3"
|
| 512 |
+
fi
|
| 513 |
+
|
| 514 |
+
printf "\n"
|
| 515 |
+
printf "${BOLD}========================================${NC}\n"
|
| 516 |
+
printf "${GREEN}${BOLD} All 3/3 checks passed!${NC}\n"
|
| 517 |
+
printf "${GREEN}${BOLD} Your submission is ready to submit.${NC}\n"
|
| 518 |
+
printf "${BOLD}========================================${NC}\n"
|
| 519 |
+
printf "\n"
|
| 520 |
+
|
| 521 |
+
exit 0
|
TRAINING_INSTRUCTIONS.md
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MolForge Training Instructions
|
| 2 |
+
|
| 3 |
+
This guide is for training a small model against MolForge without teaching it to exploit the environment.
|
| 4 |
+
|
| 5 |
+
## 1. Safety Defaults
|
| 6 |
+
|
| 7 |
+
MolForge now hides true internal molecule properties from public `state()` metadata by default. If you need to debug the environment manually, use:
|
| 8 |
+
|
| 9 |
+
```bash
|
| 10 |
+
MOLFORGE_DEBUG_STATE=1 python inference.py
|
| 11 |
+
```
|
| 12 |
+
|
| 13 |
+
Do not use `MOLFORGE_DEBUG_STATE=1` while collecting SFT data or running RL.
|
| 14 |
+
|
| 15 |
+
The chemistry oracle path uses RDKit descriptors by default and TDC molecule oracles when `pytdc` is available. TDC is kept as an optional extra because current PyTDC releases pull a large platform-sensitive ML stack; install it with `uv sync --extra tdc` on a compatible Python if you want TDC SA/QED oracles active. RDKit remains active in the default Docker/HF deployment, and the environment records the active backend in observation metadata.
|
| 16 |
+
|
| 17 |
+
The default reward mode is `assay_gated`, which gives coarse edit feedback and leaves the strongest quality signal to assays and terminal graders. For early RL warmup, use the curriculum reward mode:
|
| 18 |
+
|
| 19 |
+
```bash
|
| 20 |
+
MOLFORGE_REWARD_MODE=curriculum python inference.py
|
| 21 |
+
```
|
| 22 |
+
|
| 23 |
+
Curriculum mode keeps the official `submission_score` strict, but gives bounded
|
| 24 |
+
training reward for useful evidence collection, evidence-supported submit
|
| 25 |
+
decisions, and non-submitted near-miss episodes. If the model reaches a strong
|
| 26 |
+
evidence package and still fails to submit before the deadline, curriculum mode
|
| 27 |
+
adds a small missed-nomination penalty. This prevents small models from seeing
|
| 28 |
+
only zero terminal scores while they are still learning when to submit, without
|
| 29 |
+
letting endless assay collection become the best behavior. Use this for initial
|
| 30 |
+
GRPO curves, then switch back to `assay_gated` for final evaluation.
|
| 31 |
+
|
| 32 |
+
For curriculum experiments only, you can also restore the older dense edit reward:
|
| 33 |
+
|
| 34 |
+
```bash
|
| 35 |
+
MOLFORGE_REWARD_MODE=dense python inference.py
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
Use randomized training episodes when collecting data or training a policy:
|
| 39 |
+
|
| 40 |
+
```bash
|
| 41 |
+
MOLFORGE_TRAINING_RANDOMIZATION=1 MOLFORGE_RANDOM_SEED=42 python inference.py
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
Keep randomization off for judge-facing baseline runs so scores remain reproducible.
|
| 45 |
+
|
| 46 |
+
## 2. Recommended Training Plan
|
| 47 |
+
|
| 48 |
+
Use a two-stage plan:
|
| 49 |
+
|
| 50 |
+
1. Small SFT warm start
|
| 51 |
+
2. RL with verifiable rewards
|
| 52 |
+
|
| 53 |
+
SFT is only for teaching the model the action schema and basic workflow. RL should do the real environment optimization.
|
| 54 |
+
|
| 55 |
+
## 3. What SFT Should Teach
|
| 56 |
+
|
| 57 |
+
Include these example types:
|
| 58 |
+
|
| 59 |
+
- Valid JSON action formatting
|
| 60 |
+
- Correct `acting_role` for each action
|
| 61 |
+
- Short `rationale` values that explain the decision without chain-of-thought
|
| 62 |
+
- `evidence` lists that cite visible observation facts only
|
| 63 |
+
- `expected_effects` dictionaries with directional predictions, not hidden scores
|
| 64 |
+
- Specialist message bundles with proposal, approval, objection, assay request, or rejection
|
| 65 |
+
- Running cheap/necessary assays before risky submissions
|
| 66 |
+
- Editing toward safer fragments when toxicity risk is visible
|
| 67 |
+
- Restarting early in the hard sunk-cost scenario
|
| 68 |
+
- Submitting only when evidence covers the task constraints
|
| 69 |
+
- Handling noisy assay estimates without undoing a high-confidence final candidate at the last moment
|
| 70 |
+
- Recovering from low budget by choosing small actions or stopping
|
| 71 |
+
|
| 72 |
+
Avoid these example types:
|
| 73 |
+
|
| 74 |
+
- Any example that reads `state.metadata.debug_hidden_properties`
|
| 75 |
+
- Any answer that mentions exact hidden objective deltas
|
| 76 |
+
- Hidden chain-of-thought or long private reasoning transcripts
|
| 77 |
+
- Repetitive message spam just to collect coordination reward
|
| 78 |
+
- Premature submit actions without potency/safety evidence
|
| 79 |
+
- Examples where missing specialist messages are silently repaired by the runner
|
| 80 |
+
|
| 81 |
+
## 4. Generate a Starter SFT Dataset
|
| 82 |
+
|
| 83 |
+
For the first schema warm start, use the strict curriculum dataset. It includes
|
| 84 |
+
explicit JSON `null` fields, only the intended top-level action keys, all action
|
| 85 |
+
types, all assay tools, all edit subtypes, and valid role/message permissions:
|
| 86 |
+
|
| 87 |
+
```bash
|
| 88 |
+
python scripts/generate_sft_schema_strict_dataset.py \
|
| 89 |
+
--episodes 75 \
|
| 90 |
+
--output data/molforge_sft_schema_strict.jsonl
|
| 91 |
+
|
| 92 |
+
python scripts/validate_sft_traces.py data/molforge_sft_schema_strict.jsonl
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
Use this file first for Qwen 2B-class SFT:
|
| 96 |
+
|
| 97 |
+
```text
|
| 98 |
+
data/molforge_sft_schema_strict.jsonl
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
The older trace generator is still useful after the model learns the exact
|
| 102 |
+
schema, because it provides more policy-like trajectories:
|
| 103 |
+
|
| 104 |
+
Run:
|
| 105 |
+
|
| 106 |
+
```bash
|
| 107 |
+
python scripts/generate_sft_traces.py --episodes 80 --output data/molforge_sft_traces.jsonl
|
| 108 |
+
```
|
| 109 |
+
|
| 110 |
+
For a more robust dataset:
|
| 111 |
+
|
| 112 |
+
```bash
|
| 113 |
+
python scripts/generate_sft_traces.py \
|
| 114 |
+
--episodes 200 \
|
| 115 |
+
--randomized \
|
| 116 |
+
--output data/molforge_sft_traces_randomized.jsonl
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
The generated records use chat-style JSONL:
|
| 120 |
+
|
| 121 |
+
```json
|
| 122 |
+
{"messages":[{"role":"system","content":"..."},{"role":"user","content":"..."},{"role":"assistant","content":"..."}]}
|
| 123 |
+
```
|
| 124 |
+
|
| 125 |
+
Before training, spot-check the JSONL:
|
| 126 |
+
|
| 127 |
+
```bash
|
| 128 |
+
python - <<'PY'
|
| 129 |
+
import json
|
| 130 |
+
from pathlib import Path
|
| 131 |
+
|
| 132 |
+
path = Path("data/molforge_sft_traces.jsonl")
|
| 133 |
+
for i, line in zip(range(3), path.open()):
|
| 134 |
+
item = json.loads(line)
|
| 135 |
+
print(i, item.get("metadata"), item["messages"][-1]["content"][:300])
|
| 136 |
+
PY
|
| 137 |
+
```
|
| 138 |
+
|
| 139 |
+
## 5. SFT Settings
|
| 140 |
+
|
| 141 |
+
Start small:
|
| 142 |
+
|
| 143 |
+
- Dataset size: 200 to 1,000 action examples
|
| 144 |
+
- Max sequence length: 2,048 or 4,096
|
| 145 |
+
- LoRA rank: 16 or 32
|
| 146 |
+
- Learning rate: `1e-4` to `2e-4`
|
| 147 |
+
- Epochs: 1 to 3
|
| 148 |
+
- Target modules: attention and MLP projection layers
|
| 149 |
+
- Save LoRA adapters first; test them before merging
|
| 150 |
+
|
| 151 |
+
Stop SFT once the model reliably emits valid `MolForgeAction` JSON. Do not overfit it into copying one fixed heuristic path.
|
| 152 |
+
|
| 153 |
+
## 6. RL Stage
|
| 154 |
+
|
| 155 |
+
After SFT, run RL/GRPO with MolForge as the verifier environment.
|
| 156 |
+
|
| 157 |
+
Use these environment settings:
|
| 158 |
+
|
| 159 |
+
```bash
|
| 160 |
+
export MOLFORGE_TRAINING_RANDOMIZATION=1
|
| 161 |
+
export MOLFORGE_REWARD_MODE=curriculum
|
| 162 |
+
unset MOLFORGE_DEBUG_STATE
|
| 163 |
+
```
|
| 164 |
+
|
| 165 |
+
Once the model starts submitting valid candidates, run a second RL/evaluation
|
| 166 |
+
phase with:
|
| 167 |
+
|
| 168 |
+
```bash
|
| 169 |
+
export MOLFORGE_REWARD_MODE=assay_gated
|
| 170 |
+
```
|
| 171 |
+
|
| 172 |
+
Report both curves if possible:
|
| 173 |
+
|
| 174 |
+
- curriculum reward curve for early learning progress;
|
| 175 |
+
- strict terminal `submission_score` before/after for judge-facing task success.
|
| 176 |
+
|
| 177 |
+
Track these metrics separately:
|
| 178 |
+
|
| 179 |
+
- Average terminal `submission_score`
|
| 180 |
+
- Average terminal `candidate_score`
|
| 181 |
+
- Average terminal `budget_score`
|
| 182 |
+
- Budget remaining at valid submit
|
| 183 |
+
- Invalid action rate
|
| 184 |
+
- Policy veto rate
|
| 185 |
+
- Budget exhaustion rate
|
| 186 |
+
- Repeated assay count
|
| 187 |
+
- Loop penalty count
|
| 188 |
+
- Coordination score
|
| 189 |
+
- Evidence score
|
| 190 |
+
- Submitted-without-evidence count
|
| 191 |
+
- Constraint margin score
|
| 192 |
+
- Number of actions before submit
|
| 193 |
+
|
| 194 |
+
Inspect generations every few hundred updates. A rising reward is not enough if the model learns to spam messages, submit without evidence, or memorize the three default scenarios.
|
| 195 |
+
|
| 196 |
+
## 7. Evaluation Protocol
|
| 197 |
+
|
| 198 |
+
Use three evaluations:
|
| 199 |
+
|
| 200 |
+
1. Deterministic public tasks
|
| 201 |
+
Run with randomization off and compare to `python inference.py`.
|
| 202 |
+
2. Randomized training tasks
|
| 203 |
+
Run with `MOLFORGE_TRAINING_RANDOMIZATION=1`.
|
| 204 |
+
3. Holdout tasks
|
| 205 |
+
Add new scenario configs or fragment perturbations not present in SFT traces.
|
| 206 |
+
|
| 207 |
+
A trained model should improve terminal submission score while keeping invalid actions and evidence-free submissions low.
|
| 208 |
+
|
| 209 |
+
For the full testing protocol, including how to compare curriculum reward
|
| 210 |
+
against strict evaluation, see [EVALUATION_PROTOCOL.md](EVALUATION_PROTOCOL.md).
|
| 211 |
+
|
| 212 |
+
## 8. Model Choice
|
| 213 |
+
|
| 214 |
+
Recommended starting point:
|
| 215 |
+
|
| 216 |
+
- `unsloth/Qwen3.5-2B` for the lightest serious iteration loop
|
| 217 |
+
- `unsloth/Qwen3-4B-Instruct-2507` if you can afford a little more VRAM and want stronger JSON/tool following
|
| 218 |
+
|
| 219 |
+
Why:
|
| 220 |
+
|
| 221 |
+
- Qwen3.5 has 0.8B, 2B, and 4B Unsloth fine-tuning support.
|
| 222 |
+
- The 2B class should be fast enough for repeated MolForge SFT/RL experiments.
|
| 223 |
+
- The 4B class is still lightweight, but should be more reliable for structured action generation.
|
| 224 |
+
|
| 225 |
+
Use `Qwen3.5-0.8B` only for plumbing tests. It is useful for verifying that the training loop runs end to end, but it is likely too weak to meaningfully exercise or evaluate the environment.
|
| 226 |
+
|
| 227 |
+
If you have more GPU budget:
|
| 228 |
+
|
| 229 |
+
- `unsloth/Qwen3-8B` or a current Qwen3/Qwen3.5 8B-class instruct model
|
| 230 |
+
|
| 231 |
+
If you specifically want alternate-family baselines:
|
| 232 |
+
|
| 233 |
+
- `unsloth/Llama-3.1-8B-Instruct`
|
| 234 |
+
- Gemma 3/4 small instruct models can be tested, but prefer Qwen first because the current Unsloth Qwen3.5 fine-tuning path is clearer for 2B/4B RL iteration.
|
| 235 |
+
|
| 236 |
+
For the hackathon, prefer faster iteration over maximum model size. A clean 4B model trained well against this environment is more useful than a larger model that only runs a few noisy experiments.
|
| 237 |
+
|
| 238 |
+
## 9. Honest Inference Reporting
|
| 239 |
+
|
| 240 |
+
`inference.py` has no heuristic fallback. It requires a configured model and exits with an error if the model is missing, times out, or emits unparsable action JSON.
|
| 241 |
+
|
| 242 |
+
`local_inference.py` also has no heuristic policy fallback and does not patch missing team messages into model outputs. If a model omits reviewer communication, that weakness should appear as missing-review penalties and a lower `coordination_score`.
|
| 243 |
+
|
| 244 |
+
For real model evaluation, run:
|
| 245 |
+
|
| 246 |
+
```bash
|
| 247 |
+
API_BASE_URL=https://router.huggingface.co/v1 \
|
| 248 |
+
MODEL_NAME=your-model \
|
| 249 |
+
HF_TOKEN=your-token \
|
| 250 |
+
python inference.py
|
| 251 |
+
```
|
| 252 |
+
|
| 253 |
+
Use the deterministic trace policy only for SFT data generation, not for reporting model scores.
|
__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""MolForge OpenEnv package exports."""
|
| 2 |
+
|
| 3 |
+
from .client import MolForgeEnv
|
| 4 |
+
from .models import MolForgeAction, MolForgeObservation, MolForgeState
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"MolForgeAction",
|
| 8 |
+
"MolForgeEnv",
|
| 9 |
+
"MolForgeObservation",
|
| 10 |
+
"MolForgeState",
|
| 11 |
+
]
|
client.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Synchronous and async client for the MolForge environment."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Dict
|
| 6 |
+
|
| 7 |
+
from openenv.core import EnvClient
|
| 8 |
+
from openenv.core.client_types import StepResult
|
| 9 |
+
|
| 10 |
+
from .models import MolForgeAction, MolForgeObservation, MolForgeState
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class MolForgeEnv(EnvClient[MolForgeAction, MolForgeObservation, MolForgeState]):
    """OpenEnv client for the MolForge environment.

    Translates between the typed MolForge pydantic models and the JSON
    payloads exchanged with the environment server.
    """

    def _step_payload(self, action: MolForgeAction) -> Dict:
        """Serialize an action for the step endpoint.

        ``exclude_none`` keeps the wire payload minimal and avoids sending
        explicit nulls for unused optional action fields.
        """
        return action.model_dump(exclude_none=True)

    def _parse_result(self, payload: Dict) -> StepResult[MolForgeObservation]:
        """Parse a step response into a StepResult.

        The server may wrap the observation under an ``"observation"`` key or
        return it flat, and ``done``/``reward`` may live at the top level or
        inside the observation data; both shapes are accepted.
        """
        obs_data = dict(payload.get("observation", payload))
        obs_data["done"] = payload.get("done", obs_data.get("done", False))
        obs_data["reward"] = payload.get("reward", obs_data.get("reward"))
        observation = MolForgeObservation(**obs_data)
        # Fix: reuse the resolved done/reward values so the StepResult agrees
        # with the observation even when the server nests them inside the
        # observation payload. The previous code read only the top-level keys
        # here, yielding reward=None / done=False in that case while the
        # observation itself carried the real values.
        return StepResult(
            observation=observation,
            reward=obs_data["reward"],
            done=obs_data["done"],
        )

    def _parse_state(self, payload: Dict) -> MolForgeState:
        """Deserialize the state endpoint payload into a MolForgeState."""
        return MolForgeState(**payload)
|
inference.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Judge-facing baseline inference script for MolForge."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
from typing import Any, Optional, cast
|
| 8 |
+
|
| 9 |
+
from openai import OpenAI
|
| 10 |
+
|
| 11 |
+
from inference_common import (
|
| 12 |
+
COMPACT_SYSTEM_PROMPT,
|
| 13 |
+
SYSTEM_PROMPT,
|
| 14 |
+
build_model_payload,
|
| 15 |
+
extract_json,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
from molforge.models import MolForgeAction, MolForgeObservation
|
| 20 |
+
from molforge.server.molforge_environment import MolForgeEnvironment
|
| 21 |
+
except ImportError:
|
| 22 |
+
from models import MolForgeAction, MolForgeObservation
|
| 23 |
+
from server.molforge_environment import MolForgeEnvironment
|
| 24 |
+
|
| 25 |
+
# Remote model configuration. API_KEY falls back to HF_TOKEN so the script
# works against the Hugging Face router without extra setup.
API_BASE_URL = os.getenv("API_BASE_URL")
MODEL_NAME = os.getenv("MODEL_NAME")
API_KEY = os.getenv("API_KEY") or os.getenv("HF_TOKEN")
# Hard cap on model-driven steps per episode.
MAX_TURNS = 10
# Request timeouts in seconds: default per step, extended for hard/late-stage
# steps (see model_timeout_for_step), and a short retry for the compact prompt.
MODEL_TIMEOUT_S = float(os.getenv("MODEL_TIMEOUT_S", "35"))
MODEL_LONG_TIMEOUT_S = float(os.getenv("MODEL_LONG_TIMEOUT_S", "45"))
MODEL_RETRY_TIMEOUT_S = float(os.getenv("MODEL_RETRY_TIMEOUT_S", "15"))
MODEL_MAX_TOKENS = int(os.getenv("MODEL_MAX_TOKENS", "220"))
# Validator-facing scores must lie strictly inside (0, 1); these bounds are
# applied by reportable_score.
MIN_REPORTED_SCORE = 1e-6
MAX_REPORTED_SCORE = 1.0 - 1e-6
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def main() -> None:
    """Run three judge-facing MolForge episodes and print a machine-readable summary.

    Requires API_BASE_URL, MODEL_NAME, and API_KEY (or HF_TOKEN) to be set.
    There is deliberately no heuristic fallback: a missing or failing model
    ends the episode instead of silently substituting scripted behavior.

    Raises:
        RuntimeError: when the model configuration is incomplete.
    """
    env = MolForgeEnvironment()
    if not API_BASE_URL or not MODEL_NAME or not API_KEY:
        raise RuntimeError(
            "API_BASE_URL, MODEL_NAME, and API_KEY or HF_TOKEN are required. "
            "No heuristic fallback is available."
        )
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    # Per-episode score tracks; the lists stay index-aligned by episode.
    scores = []
    raw_final_scores = []
    submission_scores = []
    progress_scores = []
    model_action_count = 0
    for episode_index in range(3):
        observation = env.reset()
        task_name = observation.scenario_id
        episode_error = ""
        print(
            f"[START] task={task_name} difficulty={observation.difficulty} episode={episode_index + 1}",
            flush=True,
        )

        for _ in range(MAX_TURNS):
            if observation.done:
                break
            try:
                action = choose_action(client, observation)
                # Counted before env.step so that a failing environment step
                # still registers as a model-produced action.
                model_action_count += 1
                observation = env.step(action)
            except Exception as exc:
                # Model/parse/step failure ends the episode. The terminal
                # grader scores below then come from the last good observation.
                episode_error = f"{exc.__class__.__name__}:{exc}"
                print(
                    f"[STEP] task={task_name} step={observation.step_index + 1} "
                    f"reward=0.000000 action=model_error status=failed",
                    flush=True,
                )
                break
            print(
                f"[STEP] task={task_name} step={observation.step_index} "
                f"reward={observation.reward:.6f} action={action.action_type} "
                f"actor={action.acting_role} status={observation.governance.status}",
                flush=True,
            )
            if observation.done:
                break

        # Terminal grader scores default to 0.0 when the episode never
        # reached a graded terminal state (e.g. a model error on step one).
        grader_scores = observation.metadata.get("terminal_grader_scores", {})
        raw_final_score = float(grader_scores.get("final_score", grader_scores.get("submission_score", 0.0)))
        final_score = reportable_score(raw_final_score)
        submission_score = float(grader_scores.get("submission_score", 0.0))
        progress_score = float(grader_scores.get("progress_score", 0.0))
        scores.append(final_score)
        raw_final_scores.append(raw_final_score)
        submission_scores.append(submission_score)
        progress_scores.append(progress_score)
        end_line = (
            f"[END] task={task_name} score={final_score:.6f} raw_score={raw_final_score:.6f} "
            f"submission_score={submission_score:.6f} progress_score={progress_score:.6f} "
            f"steps={observation.step_index}"
        )
        if episode_error:
            end_line += f" error={json.dumps(episode_error)}"
        print(end_line, flush=True)
        if observation.report_card:
            print(observation.report_card, flush=True)

    average = sum(scores) / len(scores)
    average_progress = sum(progress_scores) / len(progress_scores)
    summary = {
        "scores": scores,
        "raw_final_scores": raw_final_scores,
        # Average is re-clamped so the reported aggregate also stays in (0, 1).
        "average_final_score": round(reportable_score(average), 6),
        "submission_scores": submission_scores,
        "average_submission_score": round(sum(submission_scores) / len(submission_scores), 4),
        "progress_scores": progress_scores,
        "average_progress_score": round(average_progress, 4),
        "model_action_count": model_action_count,
        "model_name": MODEL_NAME,
        "api_base_url": API_BASE_URL,
        "fallback_enabled": False,
    }
    print("[SUMMARY] " + json.dumps(summary, separators=(",", ":")), flush=True)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def reportable_score(score: float) -> float:
    """Map a raw score into the open interval (0, 1) required by the validator.

    Values at or above 1.0 collapse to MAX_REPORTED_SCORE; values at or below
    0.0 collapse to MIN_REPORTED_SCORE; everything strictly in between passes
    through unchanged.
    """
    if score >= 1.0:
        return MAX_REPORTED_SCORE
    if score <= 0.0:
        return MIN_REPORTED_SCORE
    return score
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def choose_action(client: OpenAI, observation: MolForgeObservation) -> MolForgeAction:
    """Request one action from the model; raise instead of falling back."""

    parsed_action, failure_detail = ask_model(client, observation)
    if parsed_action is not None:
        return parsed_action
    raise RuntimeError(f"Model action failed: {failure_detail}")
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def ask_model(client: OpenAI, observation: MolForgeObservation) -> tuple[Optional[MolForgeAction], str]:
    """Request a structured team action, trying the full prompt then a compact retry.

    Returns ``(action, "")`` on success, or ``(None, details)`` where details
    joins the failure from each attempt with " | ".
    """

    failures: list[str] = []
    # Attempt order matters: full prompt with a step-dependent timeout first,
    # then the compact prompt with the short fixed retry timeout.
    attempts = (
        ("full_prompt", False, SYSTEM_PROMPT, None),
        ("compact_prompt", True, COMPACT_SYSTEM_PROMPT, MODEL_RETRY_TIMEOUT_S),
    )
    for label, use_compact, prompt, fixed_timeout in attempts:
        try:
            payload = build_model_payload(observation, compact=use_compact)
            timeout = fixed_timeout if fixed_timeout is not None else model_timeout_for_step(observation)
            data = request_action_json(
                client=client,
                system_prompt=prompt,
                user_payload=payload,
                timeout_s=timeout,
            )
            return MolForgeAction(**data), ""
        except Exception as exc:
            failures.append(f"{label}:{exc.__class__.__name__}:{exc}")
    return None, " | ".join(failures)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def request_action_json(
    *,
    client: OpenAI,
    system_prompt: str,
    user_payload: dict[str, Any],
    timeout_s: float,
) -> dict[str, Any]:
    """Call the remote model with a bounded timeout and parse a JSON action.

    Args:
        client: configured OpenAI-compatible client.
        system_prompt: system message instructing the action schema.
        user_payload: observation payload serialized into the user message.
        timeout_s: per-request timeout applied via ``with_options``.

    Returns:
        The JSON action object extracted from the model's reply.

    Raises:
        Exception: any API error/timeout, or a parse failure from
        ``extract_json`` when the reply contains no valid JSON object.
    """

    configured_client = client.with_options(timeout=timeout_s)
    completion = configured_client.chat.completions.create(
        model=MODEL_NAME,
        # Deterministic decoding for reproducible judge-facing runs.
        temperature=0.0,
        max_tokens=MODEL_MAX_TOKENS,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": json.dumps(user_payload, indent=2)},
        ],
    )
    message_content = completion.choices[0].message.content
    if isinstance(message_content, list):
        # Some providers return content as a list of typed parts; concatenate
        # the text parts into a single string before JSON extraction.
        text = "".join(part.get("text", "") for part in cast(list[dict[str, Any]], message_content))
    else:
        text = message_content or ""
    return extract_json(text)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def model_timeout_for_step(observation: MolForgeObservation) -> float:
    """Pick the request timeout for this step.

    Hard-difficulty scenarios and the final two steps of an episode get the
    longer timeout; every other step uses the default so routine decisions
    stay bounded.
    """

    needs_long_timeout = (
        observation.difficulty == "hard"
        or observation.step_index >= observation.max_steps - 2
    )
    return MODEL_LONG_TIMEOUT_S if needs_long_timeout else MODEL_TIMEOUT_S
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
if __name__ == "__main__":
|
| 209 |
+
main()
|
inference_common.py
ADDED
|
@@ -0,0 +1,831 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared inference helpers for MolForge judge/local runners."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
from typing import Any, Dict, Optional
|
| 7 |
+
|
| 8 |
+
try:
|
| 9 |
+
from molforge.models import AgentMessage, MolForgeAction, MolForgeObservation
|
| 10 |
+
except ImportError:
|
| 11 |
+
from models import AgentMessage, MolForgeAction, MolForgeObservation
|
| 12 |
+
|
| 13 |
+
# Full system prompt for the LLM policy: spells out the exact MolForgeAction
# JSON schema plus the role/message-permission rules enforced by governance.
# NOTE(review): interior prompt lines are reproduced flush-left as extracted;
# confirm against the original file if the prompt relied on internal indentation.
SYSTEM_PROMPT = """
You control the MolForge specialist team.
Return exactly one JSON object matching this schema.
The top-level "action_type" must be one of exactly:
["edit", "run_assay", "submit", "restart", "defer"].
Never use "proposal", "approval", "objection", "risk_flag", "assay_request",
"rejection", or "submission_recommendation" as the top-level action_type.
Those words are only valid inside messages[].message_type.
{
"action_type": "edit" | "run_assay" | "submit" | "restart" | "defer",
"acting_role": "lead_chemist" | "assay_planner",
"edit_type": "add_fragment" | "substitute" | "remove" | "undo_last_edit" | null,
"slot": "warhead" | "hinge" | "solvent_tail" | "back_pocket" | null,
"fragment": string | null,
"tool_name": "evaluate_properties" | "dock_target" | "assay_toxicity" | "estimate_synthesizability" | "evaluate_novelty" | "search_literature" | "run_md_simulation" | null,
"rationale": string,
"evidence": [string],
"expected_effects": {
"potency": "up" | "down" | "neutral" | "unknown" | "not_applicable",
"toxicity": "up" | "down" | "neutral" | "unknown" | "not_applicable",
"synth": "up" | "down" | "neutral" | "unknown" | "not_applicable",
"novelty": "up" | "down" | "neutral" | "unknown" | "not_applicable",
"budget": "up" | "down" | "neutral" | "unknown" | "not_applicable"
},
"messages": [
{
"sender": "lead_chemist" | "toxicologist" | "assay_planner" | "process_chemist",
"message_type": "proposal" | "approval" | "objection" | "risk_flag" | "assay_request" | "rejection" | "submission_recommendation",
"severity": "low" | "medium" | "high" | "critical",
"summary": string,
"payload": object
}
]
}
Required top-level keys only:
action_type, acting_role, edit_type, slot, fragment, tool_name, rationale,
evidence, expected_effects, messages.
Do not output wrapper keys such as action, role, message_status,
message_payload, sender_role, or explanation_reason.
Use JSON null for unused optional fields.
Use structured specialist messages. Keep rationale short. Evidence must cite only visible observation facts. Expected effects are directional predictions, not hidden scores. Prefer cheap informative assays early, respect safety evidence, and do not submit without adequate support.
Critical role rules:
- lead_chemist may send only proposal, revision_request, or submission_recommendation.
- assay_planner may send proposal, approval, rejection, assay_request, or submission_recommendation.
- toxicologist may send approval, objection, risk_flag, assay_request, or rejection.
- process_chemist may send approval, objection, risk_flag, or assay_request.
- The acting_role should include a proposal message inside messages[].
- Do not use lead_chemist approval messages.
- Do not use toxicologist proposal messages.
- For run_assay, acting_role must be assay_planner. For edit, submit, restart, or defer, acting_role must be lead_chemist.
""".strip()

# Shorter variant of the prompt for token-budget-constrained (compact) runs;
# same action schema and role rules, stated tersely.
COMPACT_SYSTEM_PROMPT = """
Return one concise JSON team action only.
Do not explain.
Top-level action_type must be edit, run_assay, submit, restart, or defer.
Never use proposal as action_type; proposal is only a message_type.
Use only the required MolForgeAction top-level keys.
Prioritize finishing the current task with the smallest valid action bundle.
Respect role/message permissions exactly. Never output string "null"; use JSON null.
""".strip()
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def heuristic_team_action(observation: MolForgeObservation) -> MolForgeAction:
    """Full scripted policy: choose an action, ground it, add team chatter.

    Composes the three pipeline stages — candidate selection, evidence /
    expected-effect attachment, and specialist message attachment — and
    returns the fully populated action.
    """
    action = select_candidate_action(observation)
    action = attach_reasoning_fields(observation, action)
    return attach_team_messages(observation, action)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def attach_reasoning_fields(
    observation: MolForgeObservation,
    action: MolForgeAction,
) -> MolForgeAction:
    """Populate *action*'s evidence and expected_effects in place.

    Returns the same action object so the call can be chained.
    """
    grounded_evidence = build_action_evidence(observation, action)
    directional_effects = build_expected_effects(observation, action)
    action.evidence = grounded_evidence
    action.expected_effects = directional_effects
    return action
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def select_candidate_action(observation: MolForgeObservation) -> MolForgeAction:
    """Pick the next team action via a fixed priority policy.

    Priority order (first match wins): submit when visible evidence covers the
    requirements; escape the known hard-scenario trap via restart; follow the
    per-scenario edit plan; fill evidence gaps with assays (cheapest useful
    first); repair safety / potency / synthesizability shortfalls with edits;
    fall back to literature search; defer only when nothing is affordable.
    Branch order is significant — do not reorder.
    """
    current = current_fragments(observation)
    known_potency = known_estimate(observation, "potency")
    known_toxicity = known_estimate(observation, "toxicity")
    known_synth = known_estimate(observation, "synth")
    potency_threshold = threshold_value(observation, "potency_min")
    toxicity_threshold = threshold_value(observation, "toxicity_max")
    synth_threshold = threshold_value(observation, "synth_min")

    # "synth" evidence is required only when the scenario imposes a floor.
    current_assay_props = current_property_names(observation)
    required_evidence = ["potency", "toxicity"] + (["synth"] if synth_threshold is not None else [])
    has_required_evidence = all(prop in current_assay_props for prop in required_evidence)
    constraints_known_pass = constraints_pass_from_visible_evidence(observation)
    post_shift_potency_ready = hard_post_shift_potency_ready(observation)
    # 1) Submit once evidence coverage holds and either all constraints
    #    visibly pass, we sit on the planned final candidate, or the episode
    #    horizon is about to end.
    if has_required_evidence and post_shift_potency_ready and (
        constraints_known_pass
        or on_planned_final_candidate(observation, current)
        or observation.step_index >= observation.max_steps - 1
    ):
        return MolForgeAction(
            action_type="submit",
            acting_role="lead_chemist",
            rationale="Current assay evidence covers potency, toxicity, and feasibility constraints, so the team should submit before spending more budget.",
        )

    # 2) Hard scenario: the starting (non-nitrile) series is a trap; restart
    #    early while enough budget remains.
    if (
        observation.scenario_id == "level_2_hard"
        and current["warhead"] != "nitrile"
        and observation.remaining_budget >= 350
    ):
        return MolForgeAction(
            action_type="restart",
            acting_role="lead_chemist",
            rationale="The starting series is a known trap under the resistance shift; restart before spending assay budget.",
        )

    # 3) Apply the next step of the scripted per-scenario edit plan, if any.
    target_edit = planned_fragment_edit(observation, current)
    if target_edit is not None:
        slot, fragment, rationale = target_edit
        return MolForgeAction(
            action_type="edit",
            acting_role="lead_chemist",
            edit_type="substitute",
            slot=slot,  # type: ignore[arg-type]
            fragment=fragment,
            rationale=rationale,
        )

    # 4) Hard scenario, pre-shift window (steps 0-2): bank toxicity / synth
    #    evidence before the target mutation invalidates potency readouts.
    if (
        observation.scenario_id == "level_2_hard"
        and not post_shift_potency_ready
        and observation.step_index < 3
    ):
        if known_toxicity is None and observation.remaining_budget >= 2000:
            return MolForgeAction(
                action_type="run_assay",
                acting_role="assay_planner",
                tool_name="assay_toxicity",
                rationale="Use the pre-shift turns to lock down direct toxicity evidence on the restart scaffold.",
            )
        if known_synth is None and observation.remaining_budget >= 120:
            return MolForgeAction(
                action_type="run_assay",
                acting_role="assay_planner",
                tool_name="estimate_synthesizability",
                rationale="Confirm route feasibility before the target mutation changes the potency readout.",
            )

    # 5) Generic evidence gaps, most expensive first: toxicity, then synth
    #    (only when a synth floor exists), then a direct potency readout.
    if known_toxicity is None and observation.remaining_budget >= 2000:
        return MolForgeAction(
            action_type="run_assay",
            acting_role="assay_planner",
            tool_name="assay_toxicity",
            rationale="The current candidate needs direct toxicity evidence before it can be submitted.",
        )

    if (
        synth_threshold is not None
        and known_synth is None
        and observation.remaining_budget >= 120
    ):
        return MolForgeAction(
            action_type="run_assay",
            acting_role="assay_planner",
            tool_name="estimate_synthesizability",
            rationale="The current candidate needs explicit synthesizability evidence before submission.",
        )

    if (
        known_potency is None
        and observation.remaining_budget >= 300
        and can_collect_potency_now(observation)
    ):
        return MolForgeAction(
            action_type="run_assay",
            acting_role="assay_planner",
            tool_name="dock_target",
            rationale="The final decision needs a direct potency readout on the current molecule.",
        )

    # 6) Safety repair: try progressively more invasive safer substitutions,
    #    skipping any slot already carrying the safer fragment.
    if is_safety_risky(current, known_toxicity, toxicity_threshold):
        for slot, fragment, rationale in [
            ("solvent_tail", "morpholine", "Morpholine typically lowers safety risk while keeping the molecule tractable."),
            ("back_pocket", "cyano", "Cyano is a safer back-pocket handle than a strongly lipophilic group."),
            ("warhead", "reversible_cyanoacrylamide", "A softer warhead can preserve potency while reducing reactivity risk."),
            ("hinge", "azaindole", "Azaindole can recover potency after safer peripheral edits."),
        ]:
            if current[slot] != fragment:
                return MolForgeAction(
                    action_type="edit",
                    acting_role="lead_chemist",
                    edit_type="substitute",
                    slot=slot,  # type: ignore[arg-type]
                    fragment=fragment,
                    rationale=rationale,
                )

    # 7) Potency repair when below (or missing against) the scenario floor;
    #    warhead choice depends on the scenario's target context.
    if potency_threshold is not None and (known_potency is None or known_potency < potency_threshold):
        preferred_warhead = "nitrile" if observation.scenario_id == "level_2_hard" else "acrylamide"
        for slot, fragment, rationale in [
            ("hinge", "azaindole", "Azaindole is the strongest potency-oriented hinge in this library."),
            ("back_pocket", "cyano", "Cyano improves potency more safely than heavy lipophilic groups."),
            ("warhead", preferred_warhead, "The warhead should align with the current target context."),
        ]:
            if current[slot] != fragment:
                return MolForgeAction(
                    action_type="edit",
                    acting_role="lead_chemist",
                    edit_type="substitute",
                    slot=slot,  # type: ignore[arg-type]
                    fragment=fragment,
                    rationale=rationale,
                )

    # 8) Cheap fallback potency evidence: property panel first (once), then
    #    the pricier docking run.
    if (
        known_potency is None
        and observation.remaining_budget >= 50
        and not has_assay_tool(observation, "evaluate_properties")
    ):
        return MolForgeAction(
            action_type="run_assay",
            acting_role="assay_planner",
            tool_name="evaluate_properties",
            rationale="Use the cheap property panel to cover any remaining potency evidence gap.",
        )

    if known_potency is None and observation.remaining_budget >= 300:
        return MolForgeAction(
            action_type="run_assay",
            acting_role="assay_planner",
            tool_name="dock_target",
            rationale="Potency is still under-characterized, so the team wants a more direct binding readout.",
        )

    # 9) Hard scenario: re-dock after the shift so the submission carries a
    #    post-mutation potency measurement.
    if (
        observation.scenario_id == "level_2_hard"
        and has_required_evidence
        and not post_shift_potency_ready
        and observation.remaining_budget >= 300
    ):
        return MolForgeAction(
            action_type="run_assay",
            acting_role="assay_planner",
            tool_name="dock_target",
            rationale="The hard scenario requires post-mutation potency evidence for the submitted molecule.",
        )

    # 10) Synthesizability repair when measured below the floor.
    if synth_threshold is not None and known_synth is not None and known_synth < synth_threshold:
        for slot, fragment, rationale in [
            ("hinge", "pyridine", "Simplifying the hinge improves synthetic tractability."),
            ("back_pocket", "methoxy", "A smaller back-pocket group reduces route burden."),
        ]:
            if current[slot] != fragment:
                return MolForgeAction(
                    action_type="edit",
                    acting_role="lead_chemist",
                    edit_type="substitute",
                    slot=slot,  # type: ignore[arg-type]
                    fragment=fragment,
                    rationale=rationale,
                )

    # 11) End-of-horizon submit with whatever evidence is on hand.
    if has_required_evidence and (post_shift_potency_ready or observation.step_index >= observation.max_steps - 1):
        return MolForgeAction(
            action_type="submit",
            acting_role="lead_chemist",
            rationale="The episode horizon is nearly exhausted and current evidence is available, so the team should submit.",
        )

    # 12) Cheap qualitative signal, then give up for this turn.
    if observation.remaining_budget >= 100:
        return MolForgeAction(
            action_type="run_assay",
            acting_role="assay_planner",
            tool_name="search_literature",
            rationale="The team needs additional qualitative signal before making the next irreversible move.",
        )

    return MolForgeAction(
        action_type="defer",
        acting_role="lead_chemist",
        rationale="No high-confidence move remains under the current budget.",
    )
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def attach_team_messages(
    observation: MolForgeObservation,
    action: MolForgeAction,
) -> MolForgeAction:
    """Attach specialist chatter to *action*, respecting role permissions.

    Always starts with the acting role's proposal, then adds per-action-type
    reactions from the toxicologist / assay planner / process chemist; the
    process chemist only speaks when enabled and a slot remains.  The final
    list is capped at four messages (``messages[:4]``).
    """
    messages = [
        AgentMessage(
            sender=action.acting_role,
            message_type="proposal",
            severity="medium",
            summary=proposal_summary(action),
            payload=proposal_payload(action),
        )
    ]

    # Visible state used to decide whether reviewers approve or object.
    current = current_fragments(observation)
    known_potency = known_estimate(observation, "potency")
    known_toxicity = known_estimate(observation, "toxicity")
    known_synth = known_estimate(observation, "synth")
    toxicity_threshold = threshold_value(observation, "toxicity_max")
    synth_threshold = threshold_value(observation, "synth_min")

    if action.action_type == "run_assay":
        messages.append(
            AgentMessage(
                sender="toxicologist",
                message_type="approval",
                severity="medium",
                summary="Fresh assay evidence improves safety oversight.",
            )
        )
        # The planner only seconds the assay when it was not the proposer.
        if action.acting_role != "assay_planner":
            messages.append(
                AgentMessage(
                    sender="assay_planner",
                    message_type="approval",
                    severity="medium",
                    summary="This assay is budget-efficient for the current evidence gap.",
                )
            )
        if "process_chemist" in observation.enabled_roles and len(messages) < 4:
            messages.append(
                AgentMessage(
                    sender="process_chemist",
                    message_type="approval",
                    severity="low",
                    summary="Additional evidence now will reduce late-stage feasibility surprises.",
                )
            )

    elif action.action_type == "restart":
        messages.extend(
            [
                AgentMessage(
                    sender="toxicologist",
                    message_type="approval",
                    severity="high",
                    summary="Restarting moves away from the current scaffold safety liabilities.",
                ),
                AgentMessage(
                    sender="assay_planner",
                    message_type="approval",
                    severity="high",
                    summary="Restarting now is cheaper than polishing a doomed series.",
                ),
            ]
        )
        if "process_chemist" in observation.enabled_roles and len(messages) < 4:
            messages.append(
                AgentMessage(
                    sender="process_chemist",
                    message_type="approval",
                    severity="medium",
                    summary="The alternate scaffold family is more tractable to make.",
                )
            )

    elif action.action_type == "submit":
        # Toxicologist: approve, ask for an assay (no data), or object
        # (visible toxicity over the threshold).
        tox_message_type = "approval"
        tox_summary = "Visible evidence supports a safe-enough submission."
        if known_toxicity is None:
            tox_message_type = "assay_request"
            tox_summary = "Submission should wait until toxicity has been assayed."
        elif toxicity_threshold is not None and known_toxicity > toxicity_threshold:
            tox_message_type = "objection"
            tox_summary = "Visible toxicity evidence is still above the submission threshold."
        messages.append(
            AgentMessage(
                sender="toxicologist",
                message_type=tox_message_type,
                severity="high" if tox_message_type != "approval" else "medium",
                summary=tox_summary,
            )
        )
        # Planner approves only when the toxicologist approves AND potency
        # (plus synth, when a floor exists) is measured.
        messages.append(
            AgentMessage(
                sender="assay_planner",
                message_type=(
                    "approval"
                    if tox_message_type == "approval"
                    and known_potency is not None
                    and (synth_threshold is None or known_synth is not None)
                    else "assay_request"
                ),
                severity="medium",
                summary=(
                    "The team has enough evidence to submit."
                    if tox_message_type == "approval"
                    and known_potency is not None
                    and (synth_threshold is None or known_synth is not None)
                    else "More evidence is needed before budget should be spent on submission."
                ),
            )
        )
        if "process_chemist" in observation.enabled_roles and len(messages) < 4:
            # Process chemist mirrors the synth-evidence state.
            if known_synth is None and synth_threshold is not None:
                process_message_type = "assay_request"
                process_summary = "Submission should wait for explicit route feasibility evidence."
            elif synth_threshold is not None and known_synth is not None and known_synth < synth_threshold:
                process_message_type = "objection"
                process_summary = "Submission is premature because the route still looks too fragile."
            else:
                process_message_type = "approval"
                process_summary = "Current route risk looks acceptable for submission."
            messages.append(
                AgentMessage(
                    sender="process_chemist",
                    message_type=process_message_type,
                    severity="medium",
                    summary=process_summary,
                )
            )

    elif action.action_type == "edit":
        safer_edit = is_safer_edit(current, action, known_toxicity, toxicity_threshold)
        messages.append(
            AgentMessage(
                sender="toxicologist",
                message_type="approval" if safer_edit else "risk_flag",
                severity="medium",
                summary=(
                    "This edit is directionally safer than the current fragment choice."
                    if safer_edit
                    else "This edit could carry additional safety pressure."
                ),
            )
        )
        messages.append(
            AgentMessage(
                sender="assay_planner",
                message_type="approval",
                severity="low",
                summary="The edit is cheap enough to try before another expensive assay.",
            )
        )
        if "process_chemist" in observation.enabled_roles and len(messages) < 4:
            # Quinazoline hinge is the one edit the process chemist objects to.
            route_risk = action.slot == "hinge" and action.fragment == "quinazoline"
            messages.append(
                AgentMessage(
                    sender="process_chemist",
                    message_type="approval" if not route_risk else "objection",
                    severity="low" if not route_risk else "medium",
                    summary=(
                        "The route impact looks manageable."
                        if not route_risk
                        else "This edit worsens route complexity more than I like."
                    ),
                )
            )

    # Governance allows at most four messages per action.
    action.messages = messages[:4]
    return action
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
def proposal_summary(action: MolForgeAction) -> str:
    """Return a one-line human-readable summary of the proposed action."""
    kind = action.action_type
    if kind == "run_assay":
        return f"Propose running {action.tool_name}."
    if kind == "edit":
        return f"Propose {action.edit_type} on {action.slot} to {action.fragment}."
    if kind == "restart":
        return "Propose abandoning the current scaffold and restarting."
    if kind == "submit":
        return "Propose submitting the current candidate."
    # defer (or anything unrecognized) falls through to a hold.
    return "Propose holding the current state."
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
def proposal_payload(action: MolForgeAction) -> Dict[str, Any]:
    """Build the minimal proposal payload: action_type plus any truthy details."""
    payload: Dict[str, Any] = {"action_type": action.action_type}
    for field_name in ("slot", "fragment", "tool_name"):
        value = getattr(action, field_name)
        if value:
            payload[field_name] = value
    return payload
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
def build_action_evidence(
    observation: MolForgeObservation,
    action: MolForgeAction,
) -> list[str]:
    """Compile up to five short evidence strings grounding *action*.

    Always cites scenario, budget, and step; then either the visible metric
    values or the still-unknown constraints; then action-specific facts.
    """
    facts = [
        f"scenario={observation.scenario_id}",
        f"budget={observation.remaining_budget}/{observation.max_budget}",
        f"step={observation.step_index}/{observation.max_steps}",
    ]
    # Slot -> fragment view of the current molecule (inlined helper).
    fragments = {entry.slot: entry.fragment for entry in observation.molecule_slots}
    metric_facts = [
        f"{name}={value:.3f}"
        for name, value in observation.visible_metrics.items()
        if name in {"potency", "toxicity", "synth", "novelty"}
    ]
    if metric_facts:
        facts.append("visible_metrics:" + ",".join(metric_facts[:3]))
    else:
        unresolved = [
            entry.name
            for entry in observation.constraint_status
            if entry.evidence_status == "unknown"
        ]
        if unresolved:
            facts.append("unknown_constraints:" + ",".join(unresolved[:3]))

    kind = action.action_type
    if kind == "edit" and action.slot and action.fragment:
        facts.append(f"current_{action.slot}={fragments[action.slot]}")
        facts.append(f"candidate_{action.slot}={action.fragment}")
    elif kind == "run_assay" and action.tool_name:
        gaps = [
            entry.name
            for entry in observation.constraint_status
            if entry.evidence_status == "unknown"
        ]
        facts.append(f"tool={action.tool_name}")
        if gaps:
            facts.append("evidence_gaps:" + ",".join(gaps[:3]))
    elif kind == "submit":
        resolved = [
            entry.name
            for entry in observation.constraint_status
            if entry.evidence_status == "known"
        ]
        facts.append("known_constraints:" + ",".join(resolved[:3]) if resolved else "known_constraints=none")
    elif kind == "restart":
        facts.append("restart_available=true")
        facts.append(f"current_molecule={observation.current_molecule}")

    return facts[:5]
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
def build_expected_effects(
    observation: MolForgeObservation,
    action: MolForgeAction,
) -> Dict[str, str]:
    """Directional predictions (up/down/neutral/unknown/not_applicable) per axis."""
    effects: Dict[str, str] = {
        "potency": "unknown",
        "toxicity": "unknown",
        "synth": "unknown",
        "novelty": "unknown",
        "budget": "neutral",
    }

    kind = action.action_type
    if kind in {"run_assay", "submit"}:
        # Neither changes the molecule; only assays consume budget.
        for axis in ("potency", "toxicity", "synth", "novelty"):
            effects[axis] = "not_applicable"
        effects["budget"] = "down" if kind == "run_assay" else "neutral"
        return effects

    if kind == "restart":
        effects["toxicity"] = "down"
        effects["synth"] = "up"
        effects["budget"] = "down"
        if observation.scenario_id == "level_2_hard":
            effects["potency"] = "up"
        return effects

    if kind != "edit":
        return effects

    pair = (action.slot or "", action.fragment or "")
    slot, fragment = pair
    if pair == ("hinge", "azaindole"):
        effects["potency"] = "up"
    if pair == ("back_pocket", "cyano"):
        effects["potency"] = "up"
        effects["toxicity"] = "down"
    if slot == "back_pocket" and fragment in {"chloro", "trifluoromethyl"}:
        effects["potency"] = "up"
        effects["toxicity"] = "up"
    if pair == ("solvent_tail", "morpholine"):
        effects["toxicity"] = "down"
        effects["synth"] = "up"
    if pair == ("solvent_tail", "dimethylamino"):
        effects["toxicity"] = "up"
    if pair == ("warhead", "reversible_cyanoacrylamide"):
        effects["toxicity"] = "down"
        effects["novelty"] = "up"
    if pair == ("warhead", "nitrile"):
        effects["toxicity"] = "down"
        # Nitrile only helps potency after the hard-scenario target shift.
        if observation.scenario_id == "level_2_hard":
            effects["potency"] = "up"
    return effects
|
| 611 |
+
|
| 612 |
+
|
| 613 |
+
def current_fragments(observation: MolForgeObservation) -> Dict[str, str]:
    """Map each molecule slot name to the fragment currently occupying it."""
    slot_map: Dict[str, str] = {}
    for slot_entry in observation.molecule_slots:
        slot_map[slot_entry.slot] = slot_entry.fragment
    return slot_map
|
| 615 |
+
|
| 616 |
+
|
| 617 |
+
def known_estimate(observation: MolForgeObservation, property_name: str) -> Optional[float]:
    """Newest assay estimate of *property_name* for the current molecule, or None."""
    signature = observation.current_molecule
    # Scan newest-first so the most recent reading wins.
    recent_matches = (
        reading.estimate
        for reading in reversed(observation.known_assays)
        if reading.molecule_signature == signature
        and reading.property_name == property_name
    )
    return next(recent_matches, None)
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
def current_property_names(observation: MolForgeObservation) -> set[str]:
    """Set of property names already assayed on the current molecule."""
    signature = observation.current_molecule
    assayed: set[str] = set()
    for reading in observation.known_assays:
        if reading.molecule_signature == signature:
            assayed.add(reading.property_name)
    return assayed
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
def has_assay_tool(observation: MolForgeObservation, tool_name: str) -> bool:
    """True when *tool_name* has already been run on the current molecule."""
    signature = observation.current_molecule
    for reading in observation.known_assays:
        if reading.molecule_signature != signature:
            continue
        if reading.tool_name == tool_name:
            return True
    return False
|
| 640 |
+
|
| 641 |
+
|
| 642 |
+
def planned_fragment_edit(
    observation: MolForgeObservation,
    current: Dict[str, str],
) -> Optional[tuple[str, str, str]]:
    """First not-yet-applied step of the scenario's scripted edit plan.

    Returns a (slot, fragment, rationale) tuple, or None when the scenario has
    no plan or every planned substitution is already in place.
    """
    easy_plan = [
        ("solvent_tail", "morpholine", "Morpholine improves safety and keeps synthesis comfortably feasible."),
        ("back_pocket", "cyano", "Cyano repairs the chloro safety liability while preserving potency."),
        ("hinge", "azaindole", "Azaindole is needed to clear the stricter potency floor after safety is stabilized."),
    ]
    medium_plan = [
        ("solvent_tail", "morpholine", "First remove the largest safety liability before paying for assays."),
        ("back_pocket", "cyano", "Cyano keeps potency while avoiding the chloro safety penalty."),
        ("hinge", "azaindole", "Azaindole recovers enough potency for the tighter medium target."),
    ]
    plans = {"level_0_easy": easy_plan, "level_1_medium": medium_plan}
    pending = (
        step
        for step in plans.get(observation.scenario_id, [])
        if current[step[0]] != step[1]
    )
    return next(pending, None)
|
| 662 |
+
|
| 663 |
+
|
| 664 |
+
def on_planned_final_candidate(
    observation: MolForgeObservation,
    current: Dict[str, str],
) -> bool:
    """True when *current* equals the scenario's known-good final fragment set."""
    shared_final = {
        "warhead": "acrylamide",
        "hinge": "azaindole",
        "solvent_tail": "morpholine",
        "back_pocket": "cyano",
    }
    # Easy and medium share one target; hard swaps the warhead to nitrile.
    finals = {
        "level_0_easy": dict(shared_final),
        "level_1_medium": dict(shared_final),
        "level_2_hard": dict(shared_final, warhead="nitrile"),
    }
    return current == finals.get(observation.scenario_id, {})
|
| 689 |
+
|
| 690 |
+
|
| 691 |
+
def can_collect_potency_now(observation: MolForgeObservation) -> bool:
    """Potency readouts are deferred until step 3 in the hard scenario only."""
    if observation.scenario_id == "level_2_hard":
        return observation.step_index >= 3
    return True
|
| 693 |
+
|
| 694 |
+
|
| 695 |
+
def hard_post_shift_potency_ready(observation: MolForgeObservation) -> bool:
    """Whether a post-shift potency reading exists for the current molecule.

    Trivially True outside the hard scenario.  In ``level_2_hard`` it requires
    both that the episode has reached the post-mutation window (step >= 4) and
    that a potency assay exists for the current molecule signature.
    """
    if observation.scenario_id != "level_2_hard":
        return True
    # Fix: ``step_index >= 4`` was evaluated inside the per-reading generator
    # even though it is loop-invariant; hoisting it makes the gate explicit
    # and skips the assay scan entirely before the mutation window.
    if observation.step_index < 4:
        return False
    signature = observation.current_molecule
    return any(
        reading.molecule_signature == signature and reading.property_name == "potency"
        for reading in observation.known_assays
    )
|
| 705 |
+
|
| 706 |
+
|
| 707 |
+
def constraints_pass_from_visible_evidence(observation: MolForgeObservation) -> bool:
    """True when every constraint is both measured and visibly satisfied.

    An empty constraint list counts as failing — no evidence, no pass.
    """
    statuses = observation.constraint_status
    if not statuses:
        return False
    for constraint in statuses:
        if constraint.evidence_status != "known" or constraint.satisfied is not True:
            return False
    return True
|
| 714 |
+
|
| 715 |
+
|
| 716 |
+
def threshold_value(observation: MolForgeObservation, constraint_name: str) -> Optional[float]:
    """Parse the numeric bound out of the named constraint's target string.

    The target string is expected to end with a number (e.g. ``"potency >= 0.6"``);
    the last whitespace-separated token is parsed as a float.  Returns None when
    the constraint is absent or its target does not end in a parseable number.
    """
    for constraint in observation.constraint_status:
        if constraint.name != constraint_name:
            continue
        # Fix: narrowed from a bare ``except Exception`` so unexpected failures
        # (e.g. a malformed constraint object) surface instead of being
        # silently reported as "no threshold".
        try:
            return float(constraint.target.split()[-1])
        except (ValueError, IndexError):
            return None
    return None
|
| 725 |
+
|
| 726 |
+
|
| 727 |
+
def is_safety_risky(
|
| 728 |
+
fragments: Dict[str, str],
|
| 729 |
+
known_toxicity: Optional[float],
|
| 730 |
+
toxicity_threshold: Optional[float],
|
| 731 |
+
) -> bool:
|
| 732 |
+
if known_toxicity is not None and toxicity_threshold is not None and known_toxicity > toxicity_threshold:
|
| 733 |
+
return True
|
| 734 |
+
risky_patterns = [
|
| 735 |
+
fragments["solvent_tail"] == "dimethylamino",
|
| 736 |
+
fragments["back_pocket"] == "trifluoromethyl",
|
| 737 |
+
fragments["hinge"] == "fluorophenyl" and fragments["back_pocket"] == "chloro",
|
| 738 |
+
]
|
| 739 |
+
return any(risky_patterns)
|
| 740 |
+
|
| 741 |
+
|
| 742 |
+
def is_safer_edit(
|
| 743 |
+
current: Dict[str, str],
|
| 744 |
+
action: MolForgeAction,
|
| 745 |
+
known_toxicity: Optional[float],
|
| 746 |
+
toxicity_threshold: Optional[float],
|
| 747 |
+
) -> bool:
|
| 748 |
+
if action.slot == "solvent_tail" and action.fragment == "morpholine":
|
| 749 |
+
return True
|
| 750 |
+
if action.slot == "back_pocket" and action.fragment == "cyano":
|
| 751 |
+
return True
|
| 752 |
+
if action.slot == "warhead" and action.fragment == "reversible_cyanoacrylamide":
|
| 753 |
+
return True
|
| 754 |
+
if known_toxicity is not None and toxicity_threshold is not None:
|
| 755 |
+
return known_toxicity <= toxicity_threshold
|
| 756 |
+
return current["solvent_tail"] != "dimethylamino"
|
| 757 |
+
|
| 758 |
+
|
| 759 |
+
def extract_json(text: str) -> Dict[str, Any]:
|
| 760 |
+
start = text.find("{")
|
| 761 |
+
end = text.rfind("}")
|
| 762 |
+
if start == -1 or end == -1 or start >= end:
|
| 763 |
+
raise ValueError("No JSON object found in model response")
|
| 764 |
+
return json.loads(text[start : end + 1])
|
| 765 |
+
|
| 766 |
+
|
| 767 |
+
def build_model_payload(
|
| 768 |
+
observation: MolForgeObservation,
|
| 769 |
+
*,
|
| 770 |
+
compact: bool,
|
| 771 |
+
) -> Dict[str, Any]:
|
| 772 |
+
base_payload = {
|
| 773 |
+
"valid_top_level_action_types": ["edit", "run_assay", "submit", "restart", "defer"],
|
| 774 |
+
"invalid_top_level_action_types": [
|
| 775 |
+
"proposal",
|
| 776 |
+
"approval",
|
| 777 |
+
"objection",
|
| 778 |
+
"risk_flag",
|
| 779 |
+
"assay_request",
|
| 780 |
+
"rejection",
|
| 781 |
+
"submission_recommendation",
|
| 782 |
+
],
|
| 783 |
+
"scenario_id": observation.scenario_id,
|
| 784 |
+
"difficulty": observation.difficulty,
|
| 785 |
+
"task_brief": observation.task_brief,
|
| 786 |
+
"state_label": observation.state_label,
|
| 787 |
+
"state_path_tail": observation.state_path[-4:],
|
| 788 |
+
"current_molecule": observation.current_molecule,
|
| 789 |
+
"current_smiles": observation.metadata.get("current_smiles", ""),
|
| 790 |
+
"oracle_backend": observation.metadata.get("oracle_backend", {}),
|
| 791 |
+
"visible_metrics": observation.visible_metrics,
|
| 792 |
+
"constraint_status": [constraint.model_dump() for constraint in observation.constraint_status],
|
| 793 |
+
"governance": observation.governance.model_dump(),
|
| 794 |
+
"last_transition_summary": observation.last_transition_summary,
|
| 795 |
+
"allowed_actions": observation.allowed_actions,
|
| 796 |
+
"role_message_rules": {
|
| 797 |
+
"lead_chemist": ["proposal", "revision_request", "submission_recommendation"],
|
| 798 |
+
"assay_planner": ["proposal", "approval", "rejection", "assay_request", "submission_recommendation"],
|
| 799 |
+
"toxicologist": ["approval", "objection", "risk_flag", "assay_request", "rejection"],
|
| 800 |
+
"process_chemist": ["approval", "objection", "risk_flag", "assay_request"],
|
| 801 |
+
},
|
| 802 |
+
"remaining_budget": observation.remaining_budget,
|
| 803 |
+
"step_index": observation.step_index,
|
| 804 |
+
"max_steps": observation.max_steps,
|
| 805 |
+
}
|
| 806 |
+
|
| 807 |
+
if compact:
|
| 808 |
+
base_payload["known_assays"] = [
|
| 809 |
+
{
|
| 810 |
+
"tool_name": reading.tool_name,
|
| 811 |
+
"property_name": reading.property_name,
|
| 812 |
+
"estimate": reading.estimate,
|
| 813 |
+
"confidence_low": reading.confidence_low,
|
| 814 |
+
"confidence_high": reading.confidence_high,
|
| 815 |
+
}
|
| 816 |
+
for reading in observation.known_assays[-6:]
|
| 817 |
+
]
|
| 818 |
+
base_payload["role_summaries"] = [
|
| 819 |
+
{
|
| 820 |
+
"role": role.role,
|
| 821 |
+
"local_objective": role.local_objective,
|
| 822 |
+
"key_fields": list(role.observation.keys())[:5],
|
| 823 |
+
}
|
| 824 |
+
for role in observation.role_observations
|
| 825 |
+
]
|
| 826 |
+
return base_payload
|
| 827 |
+
|
| 828 |
+
base_payload["known_assays"] = [reading.model_dump() for reading in observation.known_assays]
|
| 829 |
+
base_payload["role_observations"] = [role.model_dump() for role in observation.role_observations]
|
| 830 |
+
base_payload["recent_messages"] = [message.model_dump() for message in observation.message_log[-6:]]
|
| 831 |
+
return base_payload
|
local_inference.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Local-only inference runner for Ollama-backed MolForge testing.
|
| 2 |
+
|
| 3 |
+
This script is intentionally separate from `inference.py`.
|
| 4 |
+
Use `inference.py` for the judge-facing OpenAI-client baseline required by the
|
| 5 |
+
hackathon. Use this file for local development against Ollama's native API,
|
| 6 |
+
where reasoning models often behave better when `think` is explicitly disabled.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import json
|
| 12 |
+
import os
|
| 13 |
+
from typing import Any, Dict, Optional, Tuple
|
| 14 |
+
|
| 15 |
+
import requests
|
| 16 |
+
|
| 17 |
+
from inference_common import (
|
| 18 |
+
COMPACT_SYSTEM_PROMPT,
|
| 19 |
+
SYSTEM_PROMPT,
|
| 20 |
+
build_model_payload,
|
| 21 |
+
extract_json,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
from molforge.models import MolForgeAction, MolForgeObservation
|
| 26 |
+
from molforge.server.molforge_environment import MolForgeEnvironment
|
| 27 |
+
except ImportError:
|
| 28 |
+
from models import MolForgeAction, MolForgeObservation
|
| 29 |
+
from server.molforge_environment import MolForgeEnvironment
|
| 30 |
+
|
| 31 |
+
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
|
| 32 |
+
LOCAL_MODEL_NAME = os.getenv("LOCAL_MODEL_NAME", "gemma4:e2b")
|
| 33 |
+
LOCAL_NUM_EPISODES = int(os.getenv("LOCAL_NUM_EPISODES", "3"))
|
| 34 |
+
LOCAL_MAX_TURNS = int(os.getenv("LOCAL_MAX_TURNS", "10"))
|
| 35 |
+
OLLAMA_TIMEOUT_S = float(os.getenv("OLLAMA_TIMEOUT_S", "240"))
|
| 36 |
+
OLLAMA_RETRY_TIMEOUT_S = float(os.getenv("OLLAMA_RETRY_TIMEOUT_S", "120"))
|
| 37 |
+
OLLAMA_MAX_TOKENS = int(os.getenv("OLLAMA_MAX_TOKENS", "768"))
|
| 38 |
+
OLLAMA_THINK = os.getenv("OLLAMA_THINK", "false").lower() == "true"
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def main() -> None:
|
| 42 |
+
env = MolForgeEnvironment()
|
| 43 |
+
scores = []
|
| 44 |
+
submission_scores = []
|
| 45 |
+
progress_scores = []
|
| 46 |
+
|
| 47 |
+
print(f"Using Ollama model: {LOCAL_MODEL_NAME}", flush=True)
|
| 48 |
+
print(f"Ollama base URL: {OLLAMA_BASE_URL}", flush=True)
|
| 49 |
+
print(f"Thinking enabled: {OLLAMA_THINK}", flush=True)
|
| 50 |
+
|
| 51 |
+
for episode_index in range(LOCAL_NUM_EPISODES):
|
| 52 |
+
observation = env.reset()
|
| 53 |
+
print(f"\n=== Episode {episode_index + 1}: {observation.scenario_id} ===", flush=True)
|
| 54 |
+
|
| 55 |
+
for _ in range(LOCAL_MAX_TURNS):
|
| 56 |
+
if observation.done:
|
| 57 |
+
break
|
| 58 |
+
action, source = choose_local_action(observation)
|
| 59 |
+
observation = env.step(action)
|
| 60 |
+
print(
|
| 61 |
+
f"step={observation.step_index:02d} action={action.action_type} actor={action.acting_role} "
|
| 62 |
+
f"source={source} reward={observation.reward:+.3f} budget={observation.remaining_budget} "
|
| 63 |
+
f"governance={observation.governance.status}",
|
| 64 |
+
flush=True,
|
| 65 |
+
)
|
| 66 |
+
print(f" {observation.last_transition_summary}", flush=True)
|
| 67 |
+
if observation.done:
|
| 68 |
+
break
|
| 69 |
+
|
| 70 |
+
grader_scores = observation.metadata.get("terminal_grader_scores", {})
|
| 71 |
+
final_score = float(grader_scores.get("final_score", grader_scores.get("submission_score", 0.0)))
|
| 72 |
+
submission_score = float(grader_scores.get("submission_score", 0.0))
|
| 73 |
+
progress_score = float(grader_scores.get("progress_score", 0.0))
|
| 74 |
+
scores.append(final_score)
|
| 75 |
+
submission_scores.append(submission_score)
|
| 76 |
+
progress_scores.append(progress_score)
|
| 77 |
+
print(f"final_score={final_score:.3f}", flush=True)
|
| 78 |
+
print(f"submission_score={submission_score:.3f}", flush=True)
|
| 79 |
+
print(f"progress_score={progress_score:.3f}", flush=True)
|
| 80 |
+
if observation.report_card:
|
| 81 |
+
print(observation.report_card, flush=True)
|
| 82 |
+
|
| 83 |
+
average = sum(scores) / len(scores)
|
| 84 |
+
average_progress = sum(progress_scores) / len(progress_scores)
|
| 85 |
+
print("\n=== Local Baseline Summary ===", flush=True)
|
| 86 |
+
print(
|
| 87 |
+
json.dumps(
|
| 88 |
+
{
|
| 89 |
+
"model": LOCAL_MODEL_NAME,
|
| 90 |
+
"scores": scores,
|
| 91 |
+
"average_final_score": round(average, 4),
|
| 92 |
+
"submission_scores": submission_scores,
|
| 93 |
+
"average_submission_score": round(sum(submission_scores) / len(submission_scores), 4),
|
| 94 |
+
"progress_scores": progress_scores,
|
| 95 |
+
"average_progress_score": round(average_progress, 4),
|
| 96 |
+
},
|
| 97 |
+
indent=2,
|
| 98 |
+
),
|
| 99 |
+
flush=True,
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def choose_local_action(observation: MolForgeObservation) -> Tuple[MolForgeAction, str]:
|
| 104 |
+
"""Use Ollama output and fail loudly if it cannot produce a valid action."""
|
| 105 |
+
|
| 106 |
+
action, error = ask_ollama_model(observation)
|
| 107 |
+
if action is not None:
|
| 108 |
+
return action, "model"
|
| 109 |
+
raise RuntimeError(f"Local model action failed: {error}")
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def ask_ollama_model(observation: MolForgeObservation) -> Tuple[Optional[MolForgeAction], str]:
|
| 113 |
+
"""Call Ollama's native chat API.
|
| 114 |
+
|
| 115 |
+
Official Ollama docs note that reasoning traces live in `message.thinking`
|
| 116 |
+
while the final answer lives in `message.content`, and that `think: false`
|
| 117 |
+
can disable thinking on the native chat endpoint.
|
| 118 |
+
"""
|
| 119 |
+
|
| 120 |
+
errors = []
|
| 121 |
+
try:
|
| 122 |
+
payload = build_model_payload(observation, compact=False)
|
| 123 |
+
response_json = ollama_chat(
|
| 124 |
+
system_prompt=SYSTEM_PROMPT,
|
| 125 |
+
user_payload=payload,
|
| 126 |
+
timeout_s=OLLAMA_TIMEOUT_S,
|
| 127 |
+
)
|
| 128 |
+
data = parse_ollama_json_response(response_json)
|
| 129 |
+
return MolForgeAction(**data), ""
|
| 130 |
+
except Exception as exc:
|
| 131 |
+
errors.append(f"full_prompt:{exc.__class__.__name__}:{exc}")
|
| 132 |
+
try:
|
| 133 |
+
payload = build_model_payload(observation, compact=True)
|
| 134 |
+
response_json = ollama_chat(
|
| 135 |
+
system_prompt=COMPACT_SYSTEM_PROMPT,
|
| 136 |
+
user_payload=payload,
|
| 137 |
+
timeout_s=OLLAMA_RETRY_TIMEOUT_S,
|
| 138 |
+
)
|
| 139 |
+
data = parse_ollama_json_response(response_json)
|
| 140 |
+
return MolForgeAction(**data), ""
|
| 141 |
+
except Exception as retry_exc:
|
| 142 |
+
errors.append(f"compact_prompt:{retry_exc.__class__.__name__}:{retry_exc}")
|
| 143 |
+
return None, " | ".join(errors)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def ollama_chat(
|
| 147 |
+
*,
|
| 148 |
+
system_prompt: str,
|
| 149 |
+
user_payload: Dict[str, Any],
|
| 150 |
+
timeout_s: float,
|
| 151 |
+
) -> Dict[str, Any]:
|
| 152 |
+
"""Issue a native Ollama chat request."""
|
| 153 |
+
|
| 154 |
+
response = requests.post(
|
| 155 |
+
f"{OLLAMA_BASE_URL.rstrip('/')}/api/chat",
|
| 156 |
+
json={
|
| 157 |
+
"model": LOCAL_MODEL_NAME,
|
| 158 |
+
"stream": False,
|
| 159 |
+
"think": OLLAMA_THINK,
|
| 160 |
+
"format": "json",
|
| 161 |
+
"messages": [
|
| 162 |
+
{"role": "system", "content": system_prompt},
|
| 163 |
+
{"role": "user", "content": json.dumps(user_payload, indent=2)},
|
| 164 |
+
],
|
| 165 |
+
"options": {
|
| 166 |
+
"temperature": 0,
|
| 167 |
+
"num_predict": OLLAMA_MAX_TOKENS,
|
| 168 |
+
},
|
| 169 |
+
},
|
| 170 |
+
timeout=timeout_s,
|
| 171 |
+
)
|
| 172 |
+
response.raise_for_status()
|
| 173 |
+
return response.json()
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def parse_ollama_json_response(response_json: Dict[str, Any]) -> Dict[str, Any]:
|
| 177 |
+
"""Extract a JSON action from a native Ollama response."""
|
| 178 |
+
|
| 179 |
+
message = response_json.get("message", {}) or {}
|
| 180 |
+
content = message.get("content", "") or ""
|
| 181 |
+
thinking = message.get("thinking", "") or ""
|
| 182 |
+
|
| 183 |
+
if content:
|
| 184 |
+
try:
|
| 185 |
+
return extract_json(content)
|
| 186 |
+
except Exception:
|
| 187 |
+
pass
|
| 188 |
+
|
| 189 |
+
if thinking:
|
| 190 |
+
try:
|
| 191 |
+
return extract_json(thinking)
|
| 192 |
+
except Exception:
|
| 193 |
+
pass
|
| 194 |
+
|
| 195 |
+
combined = f"{content}\n{thinking}".strip()
|
| 196 |
+
if combined:
|
| 197 |
+
return extract_json(combined)
|
| 198 |
+
|
| 199 |
+
raise ValueError("No parseable JSON action found in Ollama response")
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
if __name__ == "__main__":
|
| 203 |
+
main()
|
lora_inference.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Local PEFT/LoRA inference runner for MolForge.
|
| 2 |
+
|
| 3 |
+
Use this to test an SFT adapter against the environment before RL. It loads the
|
| 4 |
+
base model named in the adapter config, attaches the LoRA weights, and requires
|
| 5 |
+
the model to emit a valid MolForgeAction JSON object. There is no heuristic
|
| 6 |
+
fallback or schema repair.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import json
|
| 12 |
+
import os
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Any, Dict, Optional, Tuple
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
from peft import PeftConfig, PeftModel
|
| 18 |
+
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, Qwen3_5ForConditionalGeneration
|
| 19 |
+
|
| 20 |
+
from inference_common import (
|
| 21 |
+
COMPACT_SYSTEM_PROMPT,
|
| 22 |
+
SYSTEM_PROMPT,
|
| 23 |
+
build_model_payload,
|
| 24 |
+
extract_json,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
try:
|
| 28 |
+
from molforge.models import MolForgeAction, MolForgeObservation
|
| 29 |
+
from molforge.server.molforge_environment import MolForgeEnvironment
|
| 30 |
+
except ImportError:
|
| 31 |
+
from models import MolForgeAction, MolForgeObservation
|
| 32 |
+
from server.molforge_environment import MolForgeEnvironment
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
ADAPTER_PATH = Path(os.getenv("LORA_ADAPTER_PATH", "qwen3_5_2b_lora_adapters"))
|
| 36 |
+
LOCAL_NUM_EPISODES = int(os.getenv("LOCAL_NUM_EPISODES", "3"))
|
| 37 |
+
LOCAL_MAX_TURNS = int(os.getenv("LOCAL_MAX_TURNS", "10"))
|
| 38 |
+
LORA_MAX_NEW_TOKENS = int(os.getenv("LORA_MAX_NEW_TOKENS", "768"))
|
| 39 |
+
LORA_RETRY_MAX_NEW_TOKENS = int(os.getenv("LORA_RETRY_MAX_NEW_TOKENS", "512"))
|
| 40 |
+
LORA_DEVICE = os.getenv("LORA_DEVICE", "auto")
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def main() -> None:
|
| 44 |
+
adapter_path = ADAPTER_PATH.expanduser().resolve()
|
| 45 |
+
tokenizer, model, base_model_name, device = load_adapter_model(adapter_path)
|
| 46 |
+
env = MolForgeEnvironment()
|
| 47 |
+
scores = []
|
| 48 |
+
submission_scores = []
|
| 49 |
+
progress_scores = []
|
| 50 |
+
|
| 51 |
+
print(f"Using LoRA adapter: {adapter_path}", flush=True)
|
| 52 |
+
print(f"Base model: {base_model_name}", flush=True)
|
| 53 |
+
print(f"Device: {device}", flush=True)
|
| 54 |
+
|
| 55 |
+
for episode_index in range(LOCAL_NUM_EPISODES):
|
| 56 |
+
observation = env.reset()
|
| 57 |
+
print(f"\n=== Episode {episode_index + 1}: {observation.scenario_id} ===", flush=True)
|
| 58 |
+
|
| 59 |
+
for _ in range(LOCAL_MAX_TURNS):
|
| 60 |
+
if observation.done:
|
| 61 |
+
break
|
| 62 |
+
action, source = choose_lora_action(tokenizer, model, observation, device)
|
| 63 |
+
observation = env.step(action)
|
| 64 |
+
print(
|
| 65 |
+
f"step={observation.step_index:02d} action={action.action_type} actor={action.acting_role} "
|
| 66 |
+
f"source={source} reward={observation.reward:+.3f} budget={observation.remaining_budget} "
|
| 67 |
+
f"governance={observation.governance.status}",
|
| 68 |
+
flush=True,
|
| 69 |
+
)
|
| 70 |
+
print(f" {observation.last_transition_summary}", flush=True)
|
| 71 |
+
if observation.done:
|
| 72 |
+
break
|
| 73 |
+
|
| 74 |
+
grader_scores = observation.metadata.get("terminal_grader_scores", {})
|
| 75 |
+
final_score = float(grader_scores.get("final_score", grader_scores.get("submission_score", 0.0)))
|
| 76 |
+
submission_score = float(grader_scores.get("submission_score", 0.0))
|
| 77 |
+
progress_score = float(grader_scores.get("progress_score", 0.0))
|
| 78 |
+
scores.append(final_score)
|
| 79 |
+
submission_scores.append(submission_score)
|
| 80 |
+
progress_scores.append(progress_score)
|
| 81 |
+
print(f"final_score={final_score:.3f}", flush=True)
|
| 82 |
+
print(f"submission_score={submission_score:.3f}", flush=True)
|
| 83 |
+
print(f"progress_score={progress_score:.3f}", flush=True)
|
| 84 |
+
if observation.report_card:
|
| 85 |
+
print(observation.report_card, flush=True)
|
| 86 |
+
|
| 87 |
+
average = sum(scores) / len(scores)
|
| 88 |
+
average_progress = sum(progress_scores) / len(progress_scores)
|
| 89 |
+
print("\n=== LoRA Local Summary ===", flush=True)
|
| 90 |
+
print(
|
| 91 |
+
json.dumps(
|
| 92 |
+
{
|
| 93 |
+
"adapter": str(adapter_path),
|
| 94 |
+
"base_model": base_model_name,
|
| 95 |
+
"scores": scores,
|
| 96 |
+
"average_final_score": round(average, 4),
|
| 97 |
+
"submission_scores": submission_scores,
|
| 98 |
+
"average_submission_score": round(sum(submission_scores) / len(submission_scores), 4),
|
| 99 |
+
"progress_scores": progress_scores,
|
| 100 |
+
"average_progress_score": round(average_progress, 4),
|
| 101 |
+
},
|
| 102 |
+
indent=2,
|
| 103 |
+
),
|
| 104 |
+
flush=True,
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def load_adapter_model(adapter_path: Path):
|
| 109 |
+
config = PeftConfig.from_pretrained(adapter_path)
|
| 110 |
+
base_model_name = config.base_model_name_or_path
|
| 111 |
+
device = resolve_device()
|
| 112 |
+
dtype = torch.float16 if device in {"cuda", "mps"} else torch.float32
|
| 113 |
+
|
| 114 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 115 |
+
adapter_path,
|
| 116 |
+
trust_remote_code=True,
|
| 117 |
+
use_fast=True,
|
| 118 |
+
)
|
| 119 |
+
if tokenizer.pad_token_id is None:
|
| 120 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 121 |
+
|
| 122 |
+
base_config = AutoConfig.from_pretrained(base_model_name, trust_remote_code=True)
|
| 123 |
+
model_class = (
|
| 124 |
+
Qwen3_5ForConditionalGeneration
|
| 125 |
+
if "Qwen3_5ForConditionalGeneration" in (base_config.architectures or [])
|
| 126 |
+
else AutoModelForCausalLM
|
| 127 |
+
)
|
| 128 |
+
base_model = model_class.from_pretrained(
|
| 129 |
+
base_model_name,
|
| 130 |
+
dtype=dtype,
|
| 131 |
+
trust_remote_code=True,
|
| 132 |
+
low_cpu_mem_usage=True,
|
| 133 |
+
)
|
| 134 |
+
model = PeftModel.from_pretrained(base_model, adapter_path)
|
| 135 |
+
model.to(device)
|
| 136 |
+
model.eval()
|
| 137 |
+
return tokenizer, model, base_model_name, device
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def resolve_device() -> str:
|
| 141 |
+
if LORA_DEVICE != "auto":
|
| 142 |
+
return LORA_DEVICE
|
| 143 |
+
if torch.cuda.is_available():
|
| 144 |
+
return "cuda"
|
| 145 |
+
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
| 146 |
+
return "mps"
|
| 147 |
+
return "cpu"
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def choose_lora_action(
|
| 151 |
+
tokenizer,
|
| 152 |
+
model,
|
| 153 |
+
observation: MolForgeObservation,
|
| 154 |
+
device: str,
|
| 155 |
+
) -> Tuple[MolForgeAction, str]:
|
| 156 |
+
action, error = ask_lora_model(
|
| 157 |
+
tokenizer,
|
| 158 |
+
model,
|
| 159 |
+
observation,
|
| 160 |
+
device,
|
| 161 |
+
compact=False,
|
| 162 |
+
max_new_tokens=LORA_MAX_NEW_TOKENS,
|
| 163 |
+
)
|
| 164 |
+
if action is not None:
|
| 165 |
+
return action, "lora_model"
|
| 166 |
+
|
| 167 |
+
retry_action, retry_error = ask_lora_model(
|
| 168 |
+
tokenizer,
|
| 169 |
+
model,
|
| 170 |
+
observation,
|
| 171 |
+
device,
|
| 172 |
+
compact=True,
|
| 173 |
+
max_new_tokens=LORA_RETRY_MAX_NEW_TOKENS,
|
| 174 |
+
)
|
| 175 |
+
if retry_action is not None:
|
| 176 |
+
return retry_action, "lora_model_compact_retry"
|
| 177 |
+
|
| 178 |
+
raise RuntimeError(f"LoRA model action failed: full_prompt:{error} | compact_prompt:{retry_error}")
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def ask_lora_model(
|
| 182 |
+
tokenizer,
|
| 183 |
+
model,
|
| 184 |
+
observation: MolForgeObservation,
|
| 185 |
+
device: str,
|
| 186 |
+
*,
|
| 187 |
+
compact: bool,
|
| 188 |
+
max_new_tokens: int,
|
| 189 |
+
) -> Tuple[Optional[MolForgeAction], str]:
|
| 190 |
+
response_text = ""
|
| 191 |
+
try:
|
| 192 |
+
payload = build_model_payload(observation, compact=compact)
|
| 193 |
+
system_prompt = COMPACT_SYSTEM_PROMPT if compact else SYSTEM_PROMPT
|
| 194 |
+
response_text = generate_response(
|
| 195 |
+
tokenizer,
|
| 196 |
+
model,
|
| 197 |
+
device,
|
| 198 |
+
system_prompt=system_prompt,
|
| 199 |
+
user_payload=payload,
|
| 200 |
+
max_new_tokens=max_new_tokens,
|
| 201 |
+
)
|
| 202 |
+
data = extract_json(response_text)
|
| 203 |
+
return MolForgeAction(**data), ""
|
| 204 |
+
except Exception as exc:
|
| 205 |
+
snippet = response_text[:1200].replace("\n", "\\n")
|
| 206 |
+
return None, f"{exc.__class__.__name__}:{exc}; raw={snippet}"
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def generate_response(
|
| 210 |
+
tokenizer,
|
| 211 |
+
model,
|
| 212 |
+
device: str,
|
| 213 |
+
*,
|
| 214 |
+
system_prompt: str,
|
| 215 |
+
user_payload: Dict[str, Any],
|
| 216 |
+
max_new_tokens: int,
|
| 217 |
+
) -> str:
|
| 218 |
+
messages = [
|
| 219 |
+
{"role": "system", "content": system_prompt},
|
| 220 |
+
{"role": "user", "content": json.dumps(user_payload, separators=(",", ":"))},
|
| 221 |
+
]
|
| 222 |
+
prompt = tokenizer.apply_chat_template(
|
| 223 |
+
messages,
|
| 224 |
+
tokenize=False,
|
| 225 |
+
add_generation_prompt=True,
|
| 226 |
+
enable_thinking=False,
|
| 227 |
+
)
|
| 228 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(device)
|
| 229 |
+
with torch.inference_mode():
|
| 230 |
+
generated = model.generate(
|
| 231 |
+
**inputs,
|
| 232 |
+
do_sample=False,
|
| 233 |
+
temperature=None,
|
| 234 |
+
top_p=None,
|
| 235 |
+
max_new_tokens=max_new_tokens,
|
| 236 |
+
pad_token_id=tokenizer.pad_token_id,
|
| 237 |
+
eos_token_id=tokenizer.eos_token_id,
|
| 238 |
+
)
|
| 239 |
+
new_tokens = generated[0, inputs["input_ids"].shape[-1] :]
|
| 240 |
+
return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
if __name__ == "__main__":
|
| 244 |
+
main()
|
mlx_lora_inference.py
ADDED
|
@@ -0,0 +1,457 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""MLX-backed local LoRA inference runner for MolForge on Apple Silicon."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
import time
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any, Dict, Optional, Tuple
|
| 10 |
+
|
| 11 |
+
from mlx_lm import generate, load
|
| 12 |
+
from mlx_lm.sample_utils import make_sampler
|
| 13 |
+
|
| 14 |
+
from inference_common import (
|
| 15 |
+
COMPACT_SYSTEM_PROMPT,
|
| 16 |
+
SYSTEM_PROMPT,
|
| 17 |
+
attach_team_messages,
|
| 18 |
+
build_model_payload,
|
| 19 |
+
extract_json,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
try:
|
| 23 |
+
from molforge.models import MolForgeAction, MolForgeObservation
|
| 24 |
+
from molforge.server.molforge_environment import MolForgeEnvironment
|
| 25 |
+
except ImportError:
|
| 26 |
+
from models import MolForgeAction, MolForgeObservation
|
| 27 |
+
from server.molforge_environment import MolForgeEnvironment
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
ADAPTER_PATH = Path(os.getenv("LORA_ADAPTER_PATH", "qwen3_5_2b_lora_adapters_strict"))
|
| 31 |
+
BASE_MODEL_NAME = os.getenv("BASE_MODEL_NAME", "unsloth/Qwen3.5-2B")
|
| 32 |
+
LOCAL_NUM_EPISODES = int(os.getenv("LOCAL_NUM_EPISODES", "3"))
|
| 33 |
+
LOCAL_MAX_TURNS = int(os.getenv("LOCAL_MAX_TURNS", "10"))
|
| 34 |
+
MLX_MAX_TOKENS = int(os.getenv("MLX_MAX_TOKENS", "768"))
|
| 35 |
+
MLX_RETRY_MAX_TOKENS = int(os.getenv("MLX_RETRY_MAX_TOKENS", "512"))
|
| 36 |
+
MLX_JSON_PREFILL = os.getenv("MLX_JSON_PREFILL", "true").lower() == "true"
|
| 37 |
+
MLX_COMPACT_ACTION = os.getenv("MLX_COMPACT_ACTION", "false").lower() == "true"
|
| 38 |
+
MLX_COMPACT_REPAIR = os.getenv("MLX_COMPACT_REPAIR", "false").lower() == "true"
|
| 39 |
+
MLX_FORCED_ACTION_TYPES = [
|
| 40 |
+
item.strip()
|
| 41 |
+
for item in os.getenv("MLX_FORCED_ACTION_TYPES", "").split(",")
|
| 42 |
+
if item.strip()
|
| 43 |
+
]
|
| 44 |
+
JSON_PREFILL = '{"action_type":"'
|
| 45 |
+
COMPACT_ACTION_SYSTEM_PROMPT = """
|
| 46 |
+
You control the MolForge action policy.
|
| 47 |
+
Return exactly one JSON object with only these top-level keys:
|
| 48 |
+
action_type, acting_role, edit_type, slot, fragment, tool_name, rationale,
|
| 49 |
+
evidence, expected_effects.
|
| 50 |
+
|
| 51 |
+
Valid action_type values are exactly:
|
| 52 |
+
edit, run_assay, submit, restart, defer.
|
| 53 |
+
|
| 54 |
+
Do not output team messages. Do not output proposal, approval, objection,
|
| 55 |
+
risk_flag, assay_request, rejection, or submission_recommendation as action_type.
|
| 56 |
+
The environment will attach governance messages automatically.
|
| 57 |
+
|
| 58 |
+
Role rules:
|
| 59 |
+
- run_assay uses acting_role "assay_planner" and a valid tool_name.
|
| 60 |
+
- edit, submit, restart, and defer use acting_role "lead_chemist".
|
| 61 |
+
- unused optional fields must be JSON null.
|
| 62 |
+
""".strip()
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def main() -> None:
    """Roll out LOCAL_NUM_EPISODES MolForge episodes with an MLX LoRA policy.

    Loads the base model plus LoRA adapter, runs each episode with greedy
    decoding, prints per-step diagnostics and per-episode grader scores, and
    ends with a JSON summary (final / submission / progress averages).
    """
    adapter_path = ADAPTER_PATH.expanduser().resolve()
    print(f"Using MLX base model: {BASE_MODEL_NAME}", flush=True)
    print(f"Using LoRA adapter: {adapter_path}", flush=True)
    model, tokenizer = load(BASE_MODEL_NAME, adapter_path=str(adapter_path))
    # temp=0.0 -> greedy sampling for reproducible evaluation runs.
    sampler = make_sampler(temp=0.0)

    env = MolForgeEnvironment()
    scores = []
    submission_scores = []
    progress_scores = []

    for episode_index in range(LOCAL_NUM_EPISODES):
        observation = env.reset()
        print(f"\n=== Episode {episode_index + 1}: {observation.scenario_id} ===", flush=True)

        for _ in range(LOCAL_MAX_TURNS):
            if observation.done:
                break
            action, source, elapsed = choose_mlx_action(model, tokenizer, sampler, observation)
            # Compact-action models emit no team messages themselves, so the
            # governance bundle is attached before stepping the environment.
            if MLX_COMPACT_ACTION:
                action = attach_team_messages(observation, action)
            observation = env.step(action)
            print(
                f"step={observation.step_index:02d} action={action.action_type} actor={action.acting_role} "
                f"source={source} gen_s={elapsed:.2f} reward={observation.reward:+.3f} "
                f"budget={observation.remaining_budget} governance={observation.governance.status}",
                flush=True,
            )
            print(f" {observation.last_transition_summary}", flush=True)
            if observation.done:
                break

        # Terminal grader scores live in observation metadata; fall back to
        # submission_score when final_score is absent, then to 0.0.
        grader_scores = observation.metadata.get("terminal_grader_scores", {})
        final_score = float(grader_scores.get("final_score", grader_scores.get("submission_score", 0.0)))
        submission_score = float(grader_scores.get("submission_score", 0.0))
        progress_score = float(grader_scores.get("progress_score", 0.0))
        scores.append(final_score)
        submission_scores.append(submission_score)
        progress_scores.append(progress_score)
        print(f"final_score={final_score:.3f}", flush=True)
        print(f"submission_score={submission_score:.3f}", flush=True)
        print(f"progress_score={progress_score:.3f}", flush=True)
        if observation.report_card:
            print(observation.report_card, flush=True)

    # NOTE(review): assumes LOCAL_NUM_EPISODES >= 1; a value of 0 would
    # divide by zero here.
    average = sum(scores) / len(scores)
    average_progress = sum(progress_scores) / len(progress_scores)
    print("\n=== MLX LoRA Local Summary ===", flush=True)
    print(
        json.dumps(
            {
                "adapter": str(adapter_path),
                "base_model": BASE_MODEL_NAME,
                "scores": scores,
                "average_final_score": round(average, 4),
                "submission_scores": submission_scores,
                "average_submission_score": round(sum(submission_scores) / len(submission_scores), 4),
                "progress_scores": progress_scores,
                "average_progress_score": round(average_progress, 4),
            },
            indent=2,
        ),
        flush=True,
    )
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def choose_mlx_action(
    model,
    tokenizer,
    sampler,
    observation: MolForgeObservation,
) -> Tuple[MolForgeAction, str, float]:
    """Generate one valid action with a three-stage fallback strategy.

    1. Full prompt, full token budget.
    2. Compact prompt with each candidate action type force-prefilled.
    3. Compact prompt with no forced action type.

    Returns ``(action, source_tag, generation_seconds)``; the source tag
    records which stage produced the action. Raises ``RuntimeError`` with the
    accumulated per-stage errors when every attempt fails to parse.
    """
    started = time.perf_counter()
    action, error = ask_mlx_model(
        model,
        tokenizer,
        sampler,
        observation,
        compact=False,
        max_tokens=MLX_MAX_TOKENS,
        forced_action_type=None,
    )
    if action is not None:
        return action, "mlx_lora_model", time.perf_counter() - started

    # Stage 2: retry with each plausible action type forced into the prefill.
    forced_errors = []
    for forced_action_type in forced_action_types(observation):
        forced_action, forced_error = ask_mlx_model(
            model,
            tokenizer,
            sampler,
            observation,
            compact=True,
            max_tokens=MLX_RETRY_MAX_TOKENS,
            forced_action_type=forced_action_type,
        )
        if forced_action is not None:
            return (
                forced_action,
                f"mlx_lora_forced_{forced_action_type}",
                time.perf_counter() - started,
            )
        forced_errors.append(f"{forced_action_type}:{forced_error}")

    # Stage 3: last chance with the compact prompt and no forced type.
    retry_action, retry_error = ask_mlx_model(
        model,
        tokenizer,
        sampler,
        observation,
        compact=True,
        max_tokens=MLX_RETRY_MAX_TOKENS,
        forced_action_type=None,
    )
    if retry_action is not None:
        return retry_action, "mlx_lora_compact_retry", time.perf_counter() - started

    raise RuntimeError(
        "MLX LoRA action failed: "
        f"full_prompt:{error} | forced:{' || '.join(forced_errors)} | compact_prompt:{retry_error}"
    )
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def ask_mlx_model(
    model,
    tokenizer,
    sampler,
    observation: MolForgeObservation,
    *,
    compact: bool,
    max_tokens: int,
    forced_action_type: Optional[str],
) -> Tuple[Optional[MolForgeAction], str]:
    """Run one generation attempt and parse it into a MolForgeAction.

    Returns ``(action, "")`` on success or ``(None, error_description)`` on
    any failure — the broad except is deliberate so callers can retry with a
    different strategy instead of crashing the rollout.
    """
    response_text = ""
    try:
        # Compact-action mode overrides both the payload and the prompt.
        payload = (
            compact_action_payload(observation)
            if MLX_COMPACT_ACTION
            else build_model_payload(observation, compact=compact)
        )
        system_prompt = (
            COMPACT_ACTION_SYSTEM_PROMPT
            if MLX_COMPACT_ACTION
            else (COMPACT_SYSTEM_PROMPT if compact else SYSTEM_PROMPT)
        )
        response_text = generate_response(
            model,
            tokenizer,
            sampler,
            system_prompt=system_prompt,
            user_payload=payload,
            max_tokens=max_tokens,
            use_json_prefill=MLX_JSON_PREFILL,
            forced_action_type=forced_action_type,
        )
        # The prefill prefix was part of the prompt, not the completion, so
        # it must be re-attached before JSON extraction.
        if MLX_JSON_PREFILL:
            response_text = json_prefill(forced_action_type) + response_text
        data = extract_json(response_text)
        repair_notes: list[str] = []
        if MLX_COMPACT_ACTION and MLX_COMPACT_REPAIR:
            data, repair_notes = repair_compact_action(data)
        # Compact mode forbids model-authored team messages by contract.
        if MLX_COMPACT_ACTION and "messages" in data:
            raise ValueError("compact action output must not include messages")
        action = MolForgeAction(**data)
        if repair_notes:
            action.metadata["compact_repair_notes"] = repair_notes
        return action, ""
    except Exception as exc:
        # Include a flattened snippet of the raw output for debugging.
        snippet = response_text[:1200].replace("\n", "\\n")
        return None, f"{exc.__class__.__name__}:{exc}; raw={snippet}"
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def generate_response(
    model,
    tokenizer,
    sampler,
    *,
    system_prompt: str,
    user_payload: Dict[str, Any],
    max_tokens: int,
    use_json_prefill: bool,
    forced_action_type: Optional[str],
) -> str:
    """Render the chat prompt and run one MLX generation pass.

    The user payload is serialized as minified JSON. When prefill is enabled
    the JSON prefix is appended to the prompt itself, so the returned
    completion does NOT contain that prefix — the caller re-attaches it.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": json.dumps(user_payload, separators=(",", ":"))},
    ]
    # enable_thinking=False suppresses the model's reasoning channel so the
    # completion starts directly with the JSON object.
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    if use_json_prefill:
        prompt += json_prefill(forced_action_type)
    return generate(
        model,
        tokenizer,
        prompt,
        verbose=False,
        max_tokens=max_tokens,
        sampler=sampler,
    ).strip()
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def json_prefill(forced_action_type: Optional[str]) -> str:
    """Return the assistant-prefill prefix used to coerce JSON output.

    With a forced action type the prefix pins the ``action_type`` value as
    well; otherwise only the opening key is pre-filled.
    """
    return (
        f'{{"action_type":"{forced_action_type}",'
        if forced_action_type
        else JSON_PREFILL
    )
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def forced_action_types(observation: MolForgeObservation) -> list[str]:
    """Return the ordered list of action types to force on retry.

    An explicit MLX_FORCED_ACTION_TYPES env override wins; otherwise the
    order is heuristic: on the very first step submission makes no sense, and
    level_2_hard appears to favor restarting first — TODO confirm against the
    scenario definitions.
    """
    if MLX_FORCED_ACTION_TYPES:
        return MLX_FORCED_ACTION_TYPES
    if observation.step_index == 0:
        if observation.scenario_id == "level_2_hard":
            return ["restart", "edit", "run_assay", "defer"]
        return ["edit", "run_assay", "defer"]
    return ["run_assay", "edit", "submit", "restart", "defer"]
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def compact_action_payload(observation: MolForgeObservation) -> dict[str, Any]:
    """Build the reduced observation payload for compact-action prompting.

    Merges the global observation with the lead-chemist and assay-planner
    role slices; candidate edits and known assays are truncated (12 and the
    last 8 respectively) to keep the prompt inside the token budget.
    """
    # Role slices default to {} if the role is absent from this scenario.
    lead_view = next(
        (role.observation for role in observation.role_observations if role.role == "lead_chemist"),
        {},
    )
    assay_view = next(
        (role.observation for role in observation.role_observations if role.role == "assay_planner"),
        {},
    )
    return {
        "valid_action_types": ["edit", "run_assay", "submit", "restart", "defer"],
        "scenario_id": observation.scenario_id,
        "difficulty": observation.difficulty,
        "task_brief": observation.task_brief,
        "current_molecule": observation.current_molecule,
        "current_smiles": observation.metadata.get("current_smiles", ""),
        "visible_metrics": observation.visible_metrics,
        "constraint_status": [constraint.model_dump() for constraint in observation.constraint_status],
        "remaining_budget": observation.remaining_budget,
        "max_budget": observation.max_budget,
        "step_index": observation.step_index,
        "max_steps": observation.max_steps,
        "molecule_slots": lead_view.get("molecule_slots", {}),
        "candidate_edits": lead_view.get("candidate_edits", [])[:12],
        "open_questions": lead_view.get("open_questions", []),
        "known_assays": [
            {
                "tool_name": reading.tool_name,
                "property_name": reading.property_name,
                "estimate": reading.estimate,
                "confidence_low": reading.confidence_low,
                "confidence_high": reading.confidence_high,
                "molecule_signature": reading.molecule_signature,
            }
            for reading in observation.known_assays[-8:]
        ],
        "tool_costs": assay_view.get("tool_costs", {}),
        "evidence_gaps": assay_view.get("evidence_gaps", []),
        "estimated_information_value": assay_view.get("estimated_information_value", {}),
    }
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
def repair_compact_action(data: Dict[str, Any]) -> tuple[Dict[str, Any], list[str]]:
    """Bounded normalization for compact-action models.

    This repairs only schema-near-misses. It does not invent an action from a
    non-action wrapper and it still rejects invalid top-level action types.
    Returns the repaired dict plus a list of human-readable repair notes.
    """

    repaired = dict(data)
    notes: list[str] = []

    # Common near-miss: models emit "role" instead of "acting_role".
    if "role" in repaired and "acting_role" not in repaired:
        repaired["acting_role"] = repaired.pop("role")
        notes.append("role->acting_role")

    # Invalid action types are returned unrepaired so downstream validation
    # still rejects them (no action is invented here).
    action_type = repaired.get("action_type")
    if action_type not in {"edit", "run_assay", "submit", "restart", "defer"}:
        return repaired, notes

    if repaired.get("edit_type") == "replace":
        repaired["edit_type"] = "substitute"
        notes.append("edit_type:replace->substitute")

    # Evidence must be a list of strings; wrap a bare string.
    if isinstance(repaired.get("evidence"), str):
        repaired["evidence"] = [repaired["evidence"]]
        notes.append("evidence:string->list")

    repaired["expected_effects"] = repair_effects(repaired.get("expected_effects"), notes)

    # Enforce the role rules per action type, nulling fields that do not
    # apply to the chosen action.
    if action_type == "run_assay":
        repaired["acting_role"] = "assay_planner"
        repaired["edit_type"] = None
        repaired["slot"] = None
        repaired["fragment"] = None
        if repaired.get("tool_name") not in {
            "evaluate_properties",
            "dock_target",
            "assay_toxicity",
            "estimate_synthesizability",
            "evaluate_novelty",
            "search_literature",
            "run_md_simulation",
        }:
            repaired["tool_name"] = "evaluate_properties"
            notes.append("tool_name:invalid->evaluate_properties")
    else:
        repaired["acting_role"] = "lead_chemist"
        if action_type == "edit":
            if repaired.get("edit_type") not in {"add_fragment", "substitute", "remove", "undo_last_edit"}:
                repaired["edit_type"] = "substitute"
                notes.append("edit_type:invalid->substitute")
            if repaired.get("tool_name") is not None:
                repaired["tool_name"] = None
                notes.append("tool_name:edit->null")
        else:
            # submit / restart / defer carry no edit or tool fields.
            for key in ("edit_type", "slot", "fragment", "tool_name"):
                if repaired.get(key) is not None:
                    repaired[key] = None
                    notes.append(f"{key}:{action_type}->null")

    # Drop any keys outside the compact-action schema.
    allowed_keys = {
        "action_type",
        "acting_role",
        "edit_type",
        "slot",
        "fragment",
        "tool_name",
        "rationale",
        "evidence",
        "expected_effects",
    }
    for key in list(repaired):
        if key not in allowed_keys:
            repaired.pop(key)
            notes.append(f"drop_extra:{key}")

    # Backfill required/optional fields with safe defaults.
    repaired.setdefault("rationale", "Choose the next compact MolForge action.")
    repaired.setdefault("evidence", [])
    for key in ("edit_type", "slot", "fragment", "tool_name"):
        repaired.setdefault(key, None)

    return repaired, notes
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
def repair_effects(value: Any, notes: list[str]) -> dict[str, str]:
    """Coerce an arbitrary expected_effects payload into the canonical dict.

    Non-dict inputs yield the default map. Known aliases are folded onto
    canonical keys, unknown keys are dropped, and each kept value is run
    through normalize_effect_value. Repairs are appended to *notes*.
    """
    normalized = {
        "potency": "unknown",
        "toxicity": "unknown",
        "synth": "unknown",
        "novelty": "unknown",
        "budget": "neutral",
    }
    if not isinstance(value, dict):
        notes.append("expected_effects:non_dict->defaults")
        return normalized

    alias_map = {"synthesizability": "synth", "synthesis": "synth"}
    for original_key, original_value in value.items():
        canonical_key = alias_map.get(original_key, original_key)
        if canonical_key in normalized:
            normalized[canonical_key] = normalize_effect_value(original_value, notes, canonical_key)
        else:
            notes.append(f"expected_effects:drop_extra:{original_key}")
    return normalized
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
def normalize_effect_value(value: Any, notes: list[str], key: str) -> str:
    """Map a free-form effect value onto the canonical effect vocabulary.

    Already-canonical values pass through untouched. Everything else is
    matched against direction keywords; the applied repair is recorded in
    *notes*. Falls back to "unknown" when nothing matches.

    Bug fix: the old code replaced "-" with "_" before checking for a "-"
    token, so the minus-sign check for "down" was dead code and values like
    "-0.2" fell through to "unknown". Sign characters are now tested against
    the raw lowered text.
    """
    if value in {"up", "down", "neutral", "unknown", "not_applicable"}:
        return value
    raw = str(value).lower().strip()
    # Normalize word separators for keyword matching only.
    text = raw.replace("-", "_").replace(" ", "_")
    if "+" in raw or any(token in text for token in ("increase", "improve", "higher", "upward")):
        notes.append(f"expected_effects:{key}:{value}->up")
        return "up"
    # startswith("-") catches pure negatives like "-0.2" without misreading
    # hyphenated words such as "not-applicable" as "down".
    if raw.startswith("-") or any(token in text for token in ("decrease", "lower", "reduce", "downward")):
        notes.append(f"expected_effects:{key}:{value}->down")
        return "down"
    if any(token in text for token in ("maintain", "stable", "unchanged", "same")):
        notes.append(f"expected_effects:{key}:{value}->neutral")
        return "neutral"
    if "not_applicable" in text or text == "na":
        notes.append(f"expected_effects:{key}:{value}->not_applicable")
        return "not_applicable"
    notes.append(f"expected_effects:{key}:{value}->unknown")
    return "unknown"
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
# Script entry point: run the local MLX LoRA evaluation loop.
if __name__ == "__main__":
    main()
|
models.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Typed models for the MolForge OpenEnv environment."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any, Dict, List, Literal, Optional
|
| 6 |
+
|
| 7 |
+
from openenv.core.env_server.types import Action, Observation, State
|
| 8 |
+
from pydantic import BaseModel, Field
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# Closed vocabularies shared by actions, observations, and state. Declared as
# Literal aliases so pydantic validates incoming payloads against them.
EDIT_TYPES = Literal["add_fragment", "substitute", "remove", "undo_last_edit"]
ACTION_TYPES = Literal["edit", "run_assay", "submit", "restart", "defer"]
# Oracle / tool names accepted by run_assay actions.
TOOL_TYPES = Literal[
    "evaluate_properties",
    "dock_target",
    "assay_toxicity",
    "estimate_synthesizability",
    "evaluate_novelty",
    "search_literature",
    "run_md_simulation",
]
# Editable molecular regions.
SLOT_TYPES = Literal["warhead", "hinge", "solvent_tail", "back_pocket"]
# Team roles; "team" acts as the broadcast receiver rather than a person.
ROLE_TYPES = Literal[
    "lead_chemist",
    "toxicologist",
    "assay_planner",
    "process_chemist",
    "team",
]
# Structured message categories used in multi-agent governance.
MESSAGE_TYPES = Literal[
    "proposal",
    "objection",
    "risk_flag",
    "assay_request",
    "approval",
    "rejection",
    "revision_request",
    "submission_recommendation",
]
SEVERITY_TYPES = Literal["low", "medium", "high", "critical"]
# Directional predictions attached to an action's expected_effects.
EFFECT_TYPES = Literal["up", "down", "neutral", "unknown", "not_applicable"]
COORDINATION_MODES = Literal["single_agent", "multi_agent"]
# Governance review outcome for the last executed turn.
GOVERNANCE_STATES = Literal["ready", "executed", "needs_revision", "policy_veto"]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class MoleculeSlot(BaseModel):
    """Visible fragment assignment for a molecule slot."""

    slot: SLOT_TYPES  # which editable region this entry describes
    fragment: str = Field(..., description="Selected fragment for the slot")
    editable: bool = Field(default=True, description="Whether the slot is editable")
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class AssayReading(BaseModel):
    """Structured oracle result surfaced to the agent."""

    tool_name: str
    property_name: str
    # All readings are normalized to [0, 1]. The confidence bounds bracket
    # the estimate; low <= high ordering is not enforced by this model.
    estimate: float = Field(..., ge=0.0, le=1.0)
    confidence_low: float = Field(..., ge=0.0, le=1.0)
    confidence_high: float = Field(..., ge=0.0, le=1.0)
    runs: int = Field(default=1, ge=1)  # how many times the assay was run
    # Identifies which molecule version the reading belongs to.
    molecule_signature: str
    summary: str = ""
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class RewardComponent(BaseModel):
    """Named reward component used in report cards and debugging."""

    name: str  # stable component identifier
    value: float  # signed contribution to the step reward
    explanation: str  # human-readable justification
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class ConstraintCheck(BaseModel):
    """Constraint status based only on currently visible evidence."""

    name: str
    target: str  # textual target, e.g. a threshold description
    # satisfied/actual stay None until evidence makes the check decidable.
    satisfied: Optional[bool] = None
    actual: Optional[float] = None
    evidence_status: Literal["known", "unknown"] = "unknown"
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class AgentMessage(BaseModel):
    """Structured inter-agent communication message."""

    message_id: str = ""  # assigned by the environment when empty
    sender: ROLE_TYPES
    receiver: str = "team"  # role name or "team" for broadcast
    message_type: MESSAGE_TYPES
    severity: SEVERITY_TYPES = "low"
    # Optionally links the message to the action it comments on.
    reference_action_type: Optional[ACTION_TYPES] = None
    summary: str = Field(default="", max_length=240)
    payload: Dict[str, Any] = Field(default_factory=dict)  # free-form details
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class RoleObservation(BaseModel):
    """Role-specific structured observation slice."""

    role: ROLE_TYPES
    local_objective: str  # what this role is optimizing for
    permissions: List[str] = Field(default_factory=list)
    # Role-private view; keys vary per role (e.g. candidate_edits for the
    # lead chemist, tool_costs for the assay planner).
    observation: Dict[str, Any] = Field(default_factory=dict)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class GovernanceStatus(BaseModel):
    """Outcome of the multi-agent review process for the last turn."""

    status: GOVERNANCE_STATES = "ready"
    explanation: str = ""
    required_roles: List[str] = Field(default_factory=list)  # who must weigh in
    approvals: List[str] = Field(default_factory=list)
    objections: List[str] = Field(default_factory=list)
    vetoes: List[str] = Field(default_factory=list)
    # Whether the proposed action may actually be executed this turn.
    executable: bool = True
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class MolForgeAction(Action):
    """Single team turn action spanning edits, assays, messages, and submission.

    Only the fields relevant to ``action_type`` are expected to be set:
    edit actions use edit_type/slot/fragment, run_assay uses tool_name, and
    the remaining types carry none of these (left as None).
    """

    action_type: ACTION_TYPES = Field(
        ..., description="High-level action type to execute this turn"
    )
    acting_role: ROLE_TYPES = Field(
        default="lead_chemist",
        description="Role claiming ownership of the executable team decision",
    )
    edit_type: Optional[EDIT_TYPES] = Field(
        default=None, description="Edit subtype when action_type is edit"
    )
    slot: Optional[SLOT_TYPES] = Field(
        default=None, description="Editable molecular slot when performing edits"
    )
    fragment: Optional[str] = Field(
        default=None, description="Fragment identifier for edit actions"
    )
    tool_name: Optional[TOOL_TYPES] = Field(
        default=None, description="Oracle or tool name for run_assay actions"
    )
    messages: List[AgentMessage] = Field(
        default_factory=list,
        description="Structured multi-agent communication bundle for this decision turn",
    )
    rationale: str = Field(
        default="",
        description="Short explanation of why the final decision should help",
        max_length=400,
    )
    # max_length here bounds the list length (at most 5 evidence items).
    evidence: List[str] = Field(
        default_factory=list,
        description="Visible observation facts supporting the action; do not include hidden state.",
        max_length=5,
    )
    expected_effects: Dict[str, EFFECT_TYPES] = Field(
        default_factory=dict,
        description="Directional public prediction for potency, toxicity, synth, novelty, or budget.",
    )
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
class MolForgeObservation(Observation):
    """Observation emitted after reset and each step."""

    # Scenario identity and progression through the episode graph.
    scenario_id: str
    difficulty: str
    state_label: str = "[start]"
    state_path: List[str] = Field(default_factory=list)
    coordination_mode: COORDINATION_MODES = "multi_agent"
    enabled_roles: List[str] = Field(default_factory=list)
    # Task description and the molecule being optimized.
    task_brief: str
    target_name: str
    current_molecule: str
    molecule_slots: List[MoleculeSlot] = Field(default_factory=list)
    editable_slots: List[str] = Field(default_factory=list)
    # Turn and budget accounting.
    step_index: int = Field(default=0, ge=0)
    max_steps: int = Field(default=0, ge=1)
    remaining_budget: int = Field(default=0, ge=0)
    budget_used: int = Field(default=0, ge=0)
    max_budget: int = Field(default=0, ge=1)
    # Evidence gathered so far plus role-scoped views and team traffic.
    known_assays: List[AssayReading] = Field(default_factory=list)
    role_observations: List[RoleObservation] = Field(default_factory=list)
    message_log: List[AgentMessage] = Field(default_factory=list)
    governance: GovernanceStatus = Field(default_factory=GovernanceStatus)
    last_transition_summary: str = ""
    # Metrics/constraints computed from visible evidence only.
    visible_metrics: Dict[str, float] = Field(default_factory=dict)
    constraint_status: List[ConstraintCheck] = Field(default_factory=list)
    reward_breakdown: List[RewardComponent] = Field(default_factory=list)
    allowed_actions: List[str] = Field(default_factory=list)
    report_card: str = ""  # terminal summary; empty until the episode ends
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
class MolForgeState(State):
    """Internal environment state surfaced through the state() API."""

    scenario_id: str = ""
    difficulty: str = ""
    state_label: str = "[start]"
    state_path: List[str] = Field(default_factory=list)
    coordination_mode: COORDINATION_MODES = "multi_agent"
    enabled_roles: List[str] = Field(default_factory=list)
    target_name: str = ""
    current_molecule: str = ""
    # Budget accounting mirrored from the observation.
    remaining_budget: int = 0
    budget_used: int = 0
    max_budget: int = 0
    # Episode counters used for diagnostics and grading.
    visited_states: int = 0
    known_assay_count: int = 0
    invalid_action_count: int = 0
    objection_count: int = 0
    oracle_call_count: int = 0
    message_count: int = 0
    decision_count: int = 0
    submitted: bool = False
    last_error_code: str = ""
    reward_total: float = 0.0  # cumulative reward across the episode
    metadata: Dict[str, Any] = Field(default_factory=dict)
|
molforge_grpo_official_submission.ipynb
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# MolForge GRPO Training Pipeline\n",
|
| 8 |
+
"This notebook implements the Reinforcement Learning (GRPO) training pipeline for the MolForge environment.\n",
|
| 9 |
+
"We train the model using a **Proposer-Critic-Selector** architecture and targeted **reward shaping** to overcome local minima."
|
| 10 |
+
]
|
| 11 |
+
},
|
| 12 |
+
{
|
| 13 |
+
"cell_type": "code",
|
| 14 |
+
"execution_count": null,
|
| 15 |
+
"metadata": {},
|
| 16 |
+
"outputs": [],
|
| 17 |
+
"source": [
|
| 18 |
+
"!pip install -U \"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git\"\n",
|
| 19 |
+
"!pip install -U \"trl>=0.21.0\" peft accelerate bitsandbytes datasets matplotlib pandas huggingface_hub \"openenv-core[core]>=0.2.3\" rdkit jmespath xformers"
|
| 20 |
+
]
|
| 21 |
+
},
|
| 22 |
+
{
|
| 23 |
+
"cell_type": "code",
|
| 24 |
+
"execution_count": null,
|
| 25 |
+
"metadata": {},
|
| 26 |
+
"outputs": [],
|
| 27 |
+
"source": [
|
| 28 |
+
"import os\n",
|
| 29 |
+
"import sys\n",
|
| 30 |
+
"from pathlib import Path\n",
|
| 31 |
+
"\n",
|
| 32 |
+
"# Clone the repository\n",
|
| 33 |
+
"if not Path(\"/content/molt_lab\").exists():\n",
|
| 34 |
+
" !git clone https://github.com/Adhitya-Vardhan/molt_lab.git /content/molt_lab\n",
|
| 35 |
+
"\n",
|
| 36 |
+
"# Add project root to path\n",
|
| 37 |
+
"if \"/content/molt_lab\" not in sys.path:\n",
|
| 38 |
+
" sys.path.insert(0, \"/content/molt_lab\")\n",
|
| 39 |
+
" \n",
|
| 40 |
+
"# Change working directory\n",
|
| 41 |
+
"os.chdir(\"/content/molt_lab\")"
|
| 42 |
+
]
|
| 43 |
+
},
|
| 44 |
+
{
|
| 45 |
+
"cell_type": "code",
|
| 46 |
+
"execution_count": null,
|
| 47 |
+
"metadata": {},
|
| 48 |
+
"outputs": [],
|
| 49 |
+
"source": [
|
| 50 |
+
"import time\n",
|
| 51 |
+
"import os\n",
|
| 52 |
+
"\n",
|
| 53 |
+
"# Training Configuration\n",
|
| 54 |
+
"os.environ[\"MOLFORGE_REWARD_MODE\"] = \"curriculum\"\n",
|
| 55 |
+
"os.environ[\"MOLFORGE_TRAINING_RANDOMIZATION\"] = \"1\"\n",
|
| 56 |
+
"\n",
|
| 57 |
+
"RL_MAX_STEPS = 80\n",
|
| 58 |
+
"NUM_GENERATIONS = 2\n",
|
| 59 |
+
"PER_DEVICE_BATCH = 2\n",
|
| 60 |
+
"GRAD_ACCUM = 4\n",
|
| 61 |
+
"LEARNING_RATE = 2e-6\n",
|
| 62 |
+
"MAX_SEQ_LENGTH = 2048\n",
|
| 63 |
+
"MAX_PROMPT_LENGTH = 1536\n",
|
| 64 |
+
"MAX_COMPLETION_LENGTH = 384\n",
|
| 65 |
+
"\n",
|
| 66 |
+
"RUN_NAME = time.strftime(\"molforge_grpo_%Y%m%d_%H%M%S\")\n",
|
| 67 |
+
"OUTPUT_DIR = Path(f\"/content/molforge_rl_runs/{RUN_NAME}\")\n",
|
| 68 |
+
"ADAPTER_SAVE_DIR = OUTPUT_DIR / \"adapters\"\n",
|
| 69 |
+
"PLOT_DIR = OUTPUT_DIR / \"plots\"\n",
|
| 70 |
+
"\n",
|
| 71 |
+
"OUTPUT_DIR.mkdir(parents=True, exist_ok=True)\n",
|
| 72 |
+
"PLOT_DIR.mkdir(parents=True, exist_ok=True)"
|
| 73 |
+
]
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
"cell_type": "markdown",
|
| 77 |
+
"metadata": {},
|
| 78 |
+
"source": [
|
| 79 |
+
"### Reward Function & OpenEnv Integration\n",
|
| 80 |
+
"We implement a custom reward function that wraps the native `MolForgeEnvironment`. \n",
|
| 81 |
+
"To prevent \"reward hacking\" (where the model endlessly farms `run_assay` for safe points), we apply targeted reward shaping."
|
| 82 |
+
]
|
| 83 |
+
},
|
| 84 |
+
{
|
| 85 |
+
"cell_type": "code",
|
| 86 |
+
"execution_count": null,
|
| 87 |
+
"metadata": {},
|
| 88 |
+
"outputs": [],
|
| 89 |
+
"source": [
|
| 90 |
+
"import json\n",
|
| 91 |
+
"from typing import Any, Dict, Tuple\n",
|
| 92 |
+
"from inference_common import (\n",
|
| 93 |
+
" MolForgeAction,\n",
|
| 94 |
+
" attach_reasoning_fields,\n",
|
| 95 |
+
" attach_team_messages,\n",
|
| 96 |
+
" extract_json,\n",
|
| 97 |
+
")\n",
|
| 98 |
+
"from server.molforge_environment import MolForgeEnvironment\n",
|
| 99 |
+
"from models import MolForgeState\n",
|
| 100 |
+
"\n",
|
| 101 |
+
"def replay_to_state(record: dict[str, Any]) -> MolForgeEnvironment:\n",
|
| 102 |
+
" env = MolForgeEnvironment()\n",
|
| 103 |
+
" env._state = MolForgeState(**record[\"state\"])\n",
|
| 104 |
+
" env._molecule = dict(record[\"molecule\"])\n",
|
| 105 |
+
" env._scenario = [s for s in env.SCENARIOS if s.scenario_id == env._state.scenario_id][0]\n",
|
| 106 |
+
" return env\n",
|
| 107 |
+
"\n",
|
| 108 |
+
"def evaluate_completion(prompt_str: str, completion_str: str, record: dict[str, Any]) -> Tuple[float, dict]:\n",
|
| 109 |
+
" diagnostics = {\"valid_json\": False}\n",
|
| 110 |
+
" try:\n",
|
| 111 |
+
" action_dict = extract_json(completion_str)\n",
|
| 112 |
+
" action = MolForgeAction(**action_dict)\n",
|
| 113 |
+
" except Exception:\n",
|
| 114 |
+
" return -1.2, diagnostics\n",
|
| 115 |
+
"\n",
|
| 116 |
+
" diagnostics[\"valid_json\"] = True\n",
|
| 117 |
+
" env = replay_to_state(record)\n",
|
| 118 |
+
" \n",
|
| 119 |
+
" # Create empty observation and attach reasoning\n",
|
| 120 |
+
" observation = env._build_observation(reward=0.0, done=False, reward_components=[])\n",
|
| 121 |
+
" action = attach_team_messages(observation, attach_reasoning_fields(observation, action))\n",
|
| 122 |
+
" \n",
|
| 123 |
+
" # Step the OpenEnv environment\n",
|
| 124 |
+
" next_observation = env.step(action)\n",
|
| 125 |
+
" reward = float(next_observation.reward)\n",
|
| 126 |
+
" grader_scores = next_observation.metadata.get(\"terminal_grader_scores\", {})\n",
|
| 127 |
+
" \n",
|
| 128 |
+
" # --- ANTI-REWARD-HACKING SHAPING ---\n",
|
| 129 |
+
" if action.action_type == \"run_assay\" and reward > 0:\n",
|
| 130 |
+
" reward *= 0.25 # Nerf assay farming\n",
|
| 131 |
+
" elif action.action_type == \"submit\":\n",
|
| 132 |
+
" sub_score = float(grader_scores.get(\"submission_score\", 0.0))\n",
|
| 133 |
+
" if sub_score > 0.0:\n",
|
| 134 |
+
" reward += sub_score * 3.0 # Massive multiplier for submissions\n",
|
| 135 |
+
" elif action.action_type == \"edit\" and reward > 0:\n",
|
| 136 |
+
" reward *= 1.5 # Boost edits\n",
|
| 137 |
+
"\n",
|
| 138 |
+
" diagnostics.update({\n",
|
| 139 |
+
" \"action_type\": action.action_type,\n",
|
| 140 |
+
" \"reward\": reward,\n",
|
| 141 |
+
" \"done\": next_observation.done,\n",
|
| 142 |
+
" })\n",
|
| 143 |
+
" return reward, diagnostics\n",
|
| 144 |
+
"\n",
|
| 145 |
+
"def molforge_reward_func(prompts, completions, **kwargs) -> list[float]:\n",
|
| 146 |
+
" rewards = []\n",
|
| 147 |
+
" dataset_records = kwargs.get(\"record\", [])\n",
|
| 148 |
+
" \n",
|
| 149 |
+
" for prompt_list, completion, record in zip(prompts, completions, dataset_records):\n",
|
| 150 |
+
" prompt_str = prompt_list[-1][\"content\"] if isinstance(prompt_list, list) else str(prompt_list)\n",
|
| 151 |
+
" completion_str = completion[0][\"content\"] if isinstance(completion, list) else str(completion)\n",
|
| 152 |
+
" reward, _ = evaluate_completion(prompt_str, completion_str, record)\n",
|
| 153 |
+
" rewards.append(reward)\n",
|
| 154 |
+
" return rewards"
|
| 155 |
+
]
|
| 156 |
+
},
|
| 157 |
+
{
|
| 158 |
+
"cell_type": "markdown",
|
| 159 |
+
"metadata": {},
|
| 160 |
+
"source": [
|
| 161 |
+
"### Model & Tokenizer Loading"
|
| 162 |
+
]
|
| 163 |
+
},
|
| 164 |
+
{
|
| 165 |
+
"cell_type": "code",
|
| 166 |
+
"execution_count": null,
|
| 167 |
+
"metadata": {},
|
| 168 |
+
"outputs": [],
|
| 169 |
+
"source": [
|
| 170 |
+
"from unsloth import FastLanguageModel\n",
|
| 171 |
+
"\n",
|
| 172 |
+
"# Set this to your SFT checkpoint\n",
|
| 173 |
+
"# You can set this to a local path or a Hugging Face repo\n",
|
| 174 |
+
"SFT_ADAPTER_PATH = \"/content/drive/MyDrive/Qwen_3.5_finetune/qwen3_5_2b_lora_adapters_compact_v4\" # <-- Change to your path\n",
|
| 175 |
+
"\n",
|
| 176 |
+
"print(\"Loading model and applying Unsloth optimizations...\")\n",
|
| 177 |
+
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
|
| 178 |
+
" model_name=SFT_ADAPTER_PATH,\n",
|
| 179 |
+
" max_seq_length=MAX_SEQ_LENGTH,\n",
|
| 180 |
+
" dtype=None,\n",
|
| 181 |
+
" load_in_4bit=True,\n",
|
| 182 |
+
")\n",
|
| 183 |
+
"\n",
|
| 184 |
+
"# Enable fast training paths\n",
|
| 185 |
+
"FastLanguageModel.for_training(model)\n",
|
| 186 |
+
"\n",
|
| 187 |
+
"# Extract underlying tokenizer if it is wrapped in a vision processor\n",
|
| 188 |
+
"if hasattr(tokenizer, \"tokenizer\"):\n",
|
| 189 |
+
" tokenizer = tokenizer.tokenizer"
|
| 190 |
+
]
|
| 191 |
+
},
|
| 192 |
+
{
|
| 193 |
+
"cell_type": "markdown",
|
| 194 |
+
"metadata": {},
|
| 195 |
+
"source": [
|
| 196 |
+
"### GRPO Training Loop"
|
| 197 |
+
]
|
| 198 |
+
},
|
| 199 |
+
{
|
| 200 |
+
"cell_type": "code",
|
| 201 |
+
"execution_count": null,
|
| 202 |
+
"metadata": {},
|
| 203 |
+
"outputs": [],
|
| 204 |
+
"source": [
|
| 205 |
+
"from trl import GRPOConfig, GRPOTrainer\n",
|
| 206 |
+
"from datasets import Dataset\n",
|
| 207 |
+
"from scripts.generate_sft_compact_policy_v4_dataset import compact_action_payload, COMPACT_ACTION_SYSTEM_PROMPT\n",
|
| 208 |
+
"\n",
|
| 209 |
+
"# Load dataset\n",
|
| 210 |
+
"def load_prompt_dataset() -> Dataset:\n",
|
| 211 |
+
" import json\n",
|
| 212 |
+
" data = []\n",
|
| 213 |
+
" with open(\"data/molforge_sft_compact_policy_v4.jsonl\", \"r\") as f:\n",
|
| 214 |
+
" for line in f:\n",
|
| 215 |
+
" record = json.loads(line)\n",
|
| 216 |
+
" prompt_text = compact_action_payload(record)\n",
|
| 217 |
+
" data.append({\n",
|
| 218 |
+
" \"prompt\": [\n",
|
| 219 |
+
" {\"role\": \"system\", \"content\": COMPACT_ACTION_SYSTEM_PROMPT},\n",
|
| 220 |
+
" {\"role\": \"user\", \"content\": prompt_text}\n",
|
| 221 |
+
" ],\n",
|
| 222 |
+
" \"record\": record\n",
|
| 223 |
+
" })\n",
|
| 224 |
+
" return Dataset.from_list(data)\n",
|
| 225 |
+
"\n",
|
| 226 |
+
"dataset = load_prompt_dataset()\n",
|
| 227 |
+
"\n",
|
| 228 |
+
"# Configure GRPO\n",
|
| 229 |
+
"training_args = GRPOConfig(\n",
|
| 230 |
+
" output_dir=str(OUTPUT_DIR),\n",
|
| 231 |
+
" learning_rate=LEARNING_RATE,\n",
|
| 232 |
+
" per_device_train_batch_size=PER_DEVICE_BATCH,\n",
|
| 233 |
+
" gradient_accumulation_steps=GRAD_ACCUM,\n",
|
| 234 |
+
" max_prompt_length=MAX_PROMPT_LENGTH,\n",
|
| 235 |
+
" max_completion_length=MAX_COMPLETION_LENGTH,\n",
|
| 236 |
+
" num_generations=NUM_GENERATIONS,\n",
|
| 237 |
+
" max_steps=RL_MAX_STEPS,\n",
|
| 238 |
+
" logging_steps=1,\n",
|
| 239 |
+
" save_steps=25,\n",
|
| 240 |
+
" bf16=True,\n",
|
| 241 |
+
" report_to=\"none\",\n",
|
| 242 |
+
" log_completions=True,\n",
|
| 243 |
+
")\n",
|
| 244 |
+
"\n",
|
| 245 |
+
"# Initialize Trainer\n",
|
| 246 |
+
"trainer = GRPOTrainer(\n",
|
| 247 |
+
" model=model,\n",
|
| 248 |
+
" reward_funcs=molforge_reward_func,\n",
|
| 249 |
+
" args=training_args,\n",
|
| 250 |
+
" train_dataset=dataset,\n",
|
| 251 |
+
" processing_class=tokenizer,\n",
|
| 252 |
+
")\n",
|
| 253 |
+
"\n",
|
| 254 |
+
"print(\"Starting GRPO Training...\")\n",
|
| 255 |
+
"trainer.train()\n",
|
| 256 |
+
"\n",
|
| 257 |
+
"print(f\"Training complete. Saving adapters to {ADAPTER_SAVE_DIR}\")\n",
|
| 258 |
+
"trainer.save_model(str(ADAPTER_SAVE_DIR))\n",
|
| 259 |
+
"tokenizer.save_pretrained(str(ADAPTER_SAVE_DIR))"
|
| 260 |
+
]
|
| 261 |
+
}
|
| 262 |
+
],
|
| 263 |
+
"metadata": {
|
| 264 |
+
"colab": {
|
| 265 |
+
"provenance": []
|
| 266 |
+
},
|
| 267 |
+
"kernelspec": {
|
| 268 |
+
"display_name": "Python 3",
|
| 269 |
+
"name": "python3"
|
| 270 |
+
},
|
| 271 |
+
"language_info": {
|
| 272 |
+
"name": "python"
|
| 273 |
+
}
|
| 274 |
+
},
|
| 275 |
+
"nbformat": 4,
|
| 276 |
+
"nbformat_minor": 0
|
| 277 |
+
}
|
molforge_oracles.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""RDKit/TDC-backed molecular oracle helpers for MolForge."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import math
|
| 6 |
+
from functools import lru_cache
|
| 7 |
+
from typing import Any, Dict, Mapping, Optional
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# Fragment-name -> substituent-SMILES lookup tables. The keys mirror the
# fragment names used by scenarios.FRAGMENT_LIBRARY; the values are spliced
# onto consecutive carbons of an aromatic ring by assemble_surrogate_smiles()
# below to build a single scorable surrogate molecule.

# Electrophilic "warhead" substituents.
WARHEAD_SMILES = {
    "acrylamide": "C(=O)NC=C",
    "reversible_cyanoacrylamide": "C(=O)NC(=C)C#N",
    "nitrile": "C#N",
    "vinyl_sulfonamide": "S(=O)(=O)NC=C",
}

# Hinge-binding ring systems.
HINGE_SMILES = {
    "azaindole": "c1[nH]c2ccccc2n1",
    "pyridine": "c1ccncc1",
    "fluorophenyl": "c1ccc(F)cc1",
    "quinazoline": "c1ncnc2ccccc12",
}

# Solvent-exposed tail groups.
TAIL_SMILES = {
    "morpholine": "N1CCOCC1",
    "piperazine": "N1CCNCC1",
    "cyclopropyl": "C1CC1",
    "dimethylamino": "N(C)C",
}

# Back-pocket substituents.
BACK_POCKET_SMILES = {
    "methoxy": "OC",
    "chloro": "Cl",
    "trifluoromethyl": "C(F)(F)F",
    "cyano": "C#N",
}
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def assemble_surrogate_smiles(molecule: Mapping[str, str]) -> str:
    """Build a valid substituted-aryl SMILES for RDKit/TDC scoring.

    Each of the four fragment slots is resolved through its SMILES table
    and attached to consecutive carbons of one aromatic ring (opened and
    closed via the ``%10`` ring-bond label).
    """
    warhead = WARHEAD_SMILES[molecule["warhead"]]
    hinge = HINGE_SMILES[molecule["hinge"]]
    tail = TAIL_SMILES[molecule["solvent_tail"]]
    back_pocket = BACK_POCKET_SMILES[molecule["back_pocket"]]
    return f"c%10({warhead})c({hinge})c({tail})c({back_pocket})cc%10"
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def oracle_backend_status() -> Dict[str, bool]:
    """Report which external chemistry engines are importable."""
    rdkit_available = _rdkit_modules() is not None
    tdc_available = _tdc_oracle_class() is not None
    return {"rdkit": rdkit_available, "tdc": tdc_available}
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def evaluate_with_rdkit_tdc(
    molecule: Mapping[str, str],
    fallback_properties: Mapping[str, float],
) -> Dict[str, float]:
    """Blend RDKit/TDC medicinal-chemistry signals into MolForge properties.

    Args:
        molecule: Fragment-slot mapping with keys ``warhead``, ``hinge``,
            ``solvent_tail`` and ``back_pocket``.
        fallback_properties: Surrogate scores; must contain ``potency``,
            ``toxicity``, ``synth`` and ``novelty``. Returned unchanged (as a
            copy) when RDKit is unavailable or the surrogate SMILES fails to
            parse; otherwise blended with oracle-derived signals.

    Returns:
        Dict with ``potency``, ``safety``, ``toxicity``, ``synth`` and
        ``novelty``, each rounded to 4 decimals and clamped to [0, 1].
        Note ``safety`` is the complement of the blended toxicity.
    """

    modules = _rdkit_modules()
    if modules is None:
        # RDKit is not importable: pass the fallback scores through untouched.
        return dict(fallback_properties)

    Chem = modules["Chem"]
    Descriptors = modules["Descriptors"]
    Crippen = modules["Crippen"]
    Lipinski = modules["Lipinski"]
    QED = modules["QED"]
    rdFingerprintGenerator = modules["rdFingerprintGenerator"]
    rdMolDescriptors = modules["rdMolDescriptors"]
    DataStructs = modules["DataStructs"]

    smiles = assemble_surrogate_smiles(molecule)
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Surrogate assembly produced an unparsable SMILES; fall back.
        return dict(fallback_properties)
    canonical = Chem.MolToSmiles(mol)

    # Drug-likeness: prefer the TDC QED oracle, fall back to RDKit's QED.
    qed_value = _tdc_oracle_score("QED", canonical)
    if qed_value is None:
        qed_value = float(QED.qed(mol))
    qed_score = _clamp01(qed_value)

    # Synthesizability: TDC SA oracle (normalized so higher = easier),
    # else a cheap RDKit descriptor proxy.
    sa_value = _tdc_oracle_score("SA", canonical)
    synth_score = _normalize_sa(sa_value)
    if synth_score is None:
        synth_score = _rdkit_synth_proxy(mol, Descriptors, Lipinski, rdMolDescriptors)

    logp = float(Crippen.MolLogP(mol))
    tpsa = float(Descriptors.TPSA(mol))
    mol_wt = float(Descriptors.MolWt(mol))
    rotatable = float(Lipinski.NumRotatableBonds(mol))
    aromatic_rings = float(rdMolDescriptors.CalcNumAromaticRings(mol))

    # Toxicity estimate = 55% physicochemical-property risk +
    # 45% fragment-based structural-alert risk.
    property_risk = _property_risk(logp=logp, tpsa=tpsa, mol_wt=mol_wt, rotatable=rotatable)
    structural_risk = _structural_alert_risk(molecule)
    rdkit_toxicity = _clamp01(0.55 * property_risk + 0.45 * structural_risk)

    target_fit = _target_fit_proxy(
        molecule,
        qed_score=qed_score,
        logp=logp,
        tpsa=tpsa,
        aromatic_rings=aromatic_rings,
    )
    novelty = _novelty_proxy(mol, Chem, rdFingerprintGenerator, DataStructs)

    # Blend each oracle signal with its fallback; the third _blend argument
    # is the oracle's weight (e.g. potency keeps 65% of the fallback value).
    return {
        "potency": round(_blend(fallback_properties["potency"], target_fit, 0.35), 4),
        "safety": round(_clamp01(1.0 - _blend(fallback_properties["toxicity"], rdkit_toxicity, 0.25)), 4),
        "toxicity": round(_blend(fallback_properties["toxicity"], rdkit_toxicity, 0.25), 4),
        "synth": round(_blend(fallback_properties["synth"], synth_score, 0.55), 4),
        "novelty": round(_blend(fallback_properties["novelty"], novelty, 0.50), 4),
    }
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
@lru_cache(maxsize=1)
def _rdkit_modules() -> Optional[Dict[str, Any]]:
    """Import RDKit once and cache the sub-modules the oracles rely on.

    Returns None when RDKit is not installed, which callers treat as
    "fall back to surrogate scoring".
    """
    try:
        from rdkit import Chem, DataStructs
        from rdkit.Chem import (
            Crippen,
            Descriptors,
            Lipinski,
            QED,
            rdFingerprintGenerator,
            rdMolDescriptors,
        )
    except Exception:
        return None

    return {
        "Chem": Chem,
        "Crippen": Crippen,
        "DataStructs": DataStructs,
        "Descriptors": Descriptors,
        "Lipinski": Lipinski,
        "QED": QED,
        "rdFingerprintGenerator": rdFingerprintGenerator,
        "rdMolDescriptors": rdMolDescriptors,
    }
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
@lru_cache(maxsize=1)
def _tdc_oracle_class() -> Optional[Any]:
    """Locate and cache the TDC ``Oracle`` class.

    Returns None when pytdc is not installed or fails to import.
    """
    try:
        import importlib

        tdc_module = importlib.import_module("tdc")
        return tdc_module.Oracle
    except Exception:
        return None
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
@lru_cache(maxsize=8)
def _tdc_oracle(name: str) -> Optional[Any]:
    """Instantiate (and cache) a TDC oracle by name; None on any failure."""
    factory = _tdc_oracle_class()
    if factory is None:
        return None
    try:
        oracle = factory(name=name)
    except Exception:
        # Oracle construction may download data or raise on unknown names;
        # treat every failure as "oracle unavailable".
        return None
    return oracle
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def _tdc_oracle_score(name: str, smiles: str) -> Optional[float]:
    """Score *smiles* with the named TDC oracle; None on any failure."""
    oracle = _tdc_oracle(name)
    if oracle is None:
        return None
    try:
        raw_score = oracle(smiles)
    except Exception:
        # Oracle evaluation can fail on unusual inputs; report "no signal".
        return None
    try:
        return float(raw_score)
    except (TypeError, ValueError):
        # Only scalar-convertible results are accepted.
        return None
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def _normalize_sa(value: Optional[float]) -> Optional[float]:
    """Normalize a synthetic-accessibility score to [0, 1], higher = easier.

    Values already inside [0, 1] are treated as pre-normalized and clamped;
    anything else is assumed to lie on the SA 1 (easy) .. 10 (hard) scale
    and is inverted accordingly. None passes through.
    """
    if value is None:
        return None
    already_unit_interval = 0.0 <= value <= 1.0
    if already_unit_interval:
        return _clamp01(value)
    # Invert the 1-10 scale so easier synthesis scores higher.
    return _clamp01((10.0 - value) / 9.0)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def _rdkit_synth_proxy(mol: Any, Descriptors: Any, Lipinski: Any, rdMolDescriptors: Any) -> float:
    """Heuristic synthesizability score from cheap RDKit descriptors.

    Used when the TDC SA oracle is unavailable. Larger, more flexible, or
    more stereochemically complex molecules score lower.
    """
    complexity_terms = [
        max(0.0, float(Descriptors.MolWt(mol)) - 350.0) / 260.0,
        float(Lipinski.NumRotatableBonds(mol)) / 12.0,
        float(rdMolDescriptors.CalcNumAtomStereoCenters(mol)) / 4.0,
        max(0.0, float(rdMolDescriptors.CalcNumRings(mol)) - 3.0) / 4.0,
        float(rdMolDescriptors.CalcNumAromaticRings(mol)) / 8.0,
    ]
    complexity = sum(complexity_terms)
    return _clamp01(1.0 - 0.35 * complexity)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def _property_risk(*, logp: float, tpsa: float, mol_wt: float, rotatable: float) -> float:
    """Soft-threshold physicochemical risk in [0, 1].

    Each descriptor is pushed through a sigmoid centered on a rule-of-thumb
    limit (logP 3.5, MW 500, 8 rotatable bonds, TPSA 130), and the four
    risks are combined as a weighted average.
    """
    weighted_risk = (
        0.42 * _sigmoid((logp - 3.5) / 1.15)
        + 0.24 * _sigmoid((mol_wt - 500.0) / 90.0)
        + 0.20 * _sigmoid((rotatable - 8.0) / 2.5)
        + 0.14 * _sigmoid((tpsa - 130.0) / 32.0)
    )
    return _clamp01(weighted_risk)
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def _structural_alert_risk(molecule: Mapping[str, str]) -> float:
    """Rule-based structural-alert risk in [0, 1] from fragment choices.

    Starts from a 0.18 baseline; reactive or lipophilic fragments add risk
    while benign solubilizing groups subtract it.
    """
    warhead = molecule["warhead"]
    tail = molecule["solvent_tail"]
    back_pocket = molecule["back_pocket"]
    halogenated_pair = molecule["hinge"] == "fluorophenyl" and back_pocket in {"chloro", "trifluoromethyl"}

    adjustments = (
        (warhead == "acrylamide", 0.12),
        (warhead == "vinyl_sulfonamide", 0.22),
        (tail == "dimethylamino", 0.24),
        (back_pocket == "trifluoromethyl", 0.20),
        (halogenated_pair, 0.12),
        (tail in {"morpholine", "piperazine"}, -0.08),
        (warhead == "nitrile", -0.08),
    )
    risk = 0.18
    for triggered, delta in adjustments:
        if triggered:
            risk += delta
    return _clamp01(risk)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def _target_fit_proxy(
    molecule: Mapping[str, str],
    *,
    qed_score: float,
    logp: float,
    tpsa: float,
    aromatic_rings: float,
) -> float:
    """Proxy for target engagement in [0, 1].

    Combines closeness to an ideal lipophilicity (logP ~3) and polarity
    (TPSA ~85) with fixed bonuses for pocket-friendly fragment choices and
    overall drug-likeness.
    """
    lipophilic_match = 1.0 - min(abs(logp - 3.0) / 4.0, 1.0)
    polarity_match = 1.0 - min(abs(tpsa - 85.0) / 110.0, 1.0)

    pocket_bonuses = (
        (molecule["hinge"] in {"azaindole", "quinazoline"}, 0.18),
        (molecule["back_pocket"] in {"cyano", "chloro", "trifluoromethyl"}, 0.14),
        (molecule["warhead"] in {"acrylamide", "reversible_cyanoacrylamide", "nitrile"}, 0.12),
        (aromatic_rings >= 2, 0.08),
    )
    pocket_match = 0.0
    for matched, bonus in pocket_bonuses:
        if matched:
            pocket_match += bonus

    return _clamp01(0.20 + 0.30 * lipophilic_match + 0.22 * polarity_match + 0.18 * qed_score + pocket_match)
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def _novelty_proxy(mol: Any, Chem: Any, rdFingerprintGenerator: Any, DataStructs: Any) -> float:
    """Novelty in [0, 1] as fingerprint distance from reference scaffolds.

    Uses Morgan (radius 2, 1024-bit) fingerprints and returns
    ``1 - max Tanimoto similarity`` to the parseable references, or 0.5 when
    none of the references parses.
    """
    reference_smiles = (
        "c%10(C(=O)NC=C)c(c1ccncc1)c(C1CC1)c(OC)cc%10",
        "c%10(C#N)c(c1ccncc1)c(N1CCOCC1)c(C#N)cc%10",
        "c%10(C(=O)NC=C)c(c1ccc(F)cc1)c(N(C)C)c(Cl)cc%10",
    )
    generator = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=1024)
    query_fp = generator.GetFingerprint(mol)
    similarities = [
        float(DataStructs.TanimotoSimilarity(query_fp, generator.GetFingerprint(ref_mol)))
        for ref_mol in (Chem.MolFromSmiles(ref) for ref in reference_smiles)
        if ref_mol is not None
    ]
    if not similarities:
        return 0.5
    return _clamp01(1.0 - max(similarities))
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def _blend(fallback_value: float, oracle_value: float, oracle_weight: float) -> float:
    """Linearly interpolate from fallback to oracle value, clamped to [0, 1]."""
    fallback_weight = 1.0 - oracle_weight
    return _clamp01(fallback_weight * fallback_value + oracle_weight * oracle_value)
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def _sigmoid(value: float) -> float:
    """Numerically stable logistic function ``1 / (1 + e**-value)``.

    The naive form ``1.0 / (1.0 + math.exp(-value))`` raises OverflowError
    once ``-value`` exceeds roughly 709 (value below about -709). Splitting
    on the sign keeps the argument to ``math.exp`` non-positive, so the
    exponential can only underflow to 0.0, never overflow.
    """
    if value >= 0.0:
        # Identical to the naive form for the non-negative half.
        return 1.0 / (1.0 + math.exp(-value))
    # Algebraically equivalent rewrite: e^x / (1 + e^x) with x < 0.
    exp_value = math.exp(value)
    return exp_value / (1.0 + exp_value)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def _clamp01(value: float) -> float:
    """Clamp *value* (coerced to float) into the inclusive range [0.0, 1.0]."""
    as_float = float(value)
    if as_float < 0.0:
        return 0.0
    if as_float > 1.0:
        return 1.0
    return as_float
|
openenv.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: molforge
|
| 3 |
+
type: space
|
| 4 |
+
runtime: fastapi
|
| 5 |
+
app: server.app:app
|
| 6 |
+
port: 8000
|
openenv_shim.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""Lightweight openenv-core shim for environments that only need the base types.
|
| 3 |
+
|
| 4 |
+
Import this module **before** any ``from openenv.core...`` imports when the
|
| 5 |
+
full ``openenv-core`` package is not installed (e.g. Colab RL training). It
|
| 6 |
+
registers minimal stubs into ``sys.modules`` so that the following imports
|
| 7 |
+
work identically to the real package:
|
| 8 |
+
|
| 9 |
+
from openenv.core.env_server.types import Action, Observation, State
|
| 10 |
+
from openenv.core.env_server.interfaces import Environment
|
| 11 |
+
|
| 12 |
+
Usage::
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
import openenv # real package available
|
| 16 |
+
except ImportError:
|
| 17 |
+
import openenv_shim # registers lightweight stubs
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
import sys
|
| 23 |
+
from abc import ABC, abstractmethod
|
| 24 |
+
from types import ModuleType
|
| 25 |
+
from typing import Any, Dict, Optional
|
| 26 |
+
|
| 27 |
+
from pydantic import BaseModel, Field
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# ── Base types (mirror openenv.core.env_server.types) ────────────────────
|
| 31 |
+
|
| 32 |
+
class Action(BaseModel):
    """Minimal action base matching openenv-core's Action."""

    # Free-form per-action payload; default_factory gives each instance its
    # own empty dict (never a shared mutable default).
    metadata: Dict[str, Any] = Field(default_factory=dict)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class Observation(BaseModel):
    """Minimal observation base matching openenv-core's Observation."""

    # True once the episode has terminated.
    done: bool = False
    # Scalar reward for the step that produced this observation.
    reward: float = 0.0
    # Extra structured data attached by the environment.
    metadata: Dict[str, Any] = Field(default_factory=dict)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class State(BaseModel):
    """Minimal state base matching openenv-core's State."""

    # Identifier of the current episode.
    episode_id: str = ""
    # Number of steps taken so far in the current episode.
    step_count: int = 0
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# ── Environment ABC (mirror openenv.core.env_server.interfaces) ──────────
|
| 54 |
+
|
| 55 |
+
class Environment(ABC):
    """Minimal environment ABC matching openenv-core's Environment."""

    # Mirrors the openenv-core capability flag; the shim defaults to
    # advertising single-session behavior.
    SUPPORTS_CONCURRENT_SESSIONS: bool = False

    def __init__(self, **_kwargs: Any):
        # Accept and ignore keyword arguments so subclasses written against
        # the real openenv-core constructor still initialize cleanly.
        pass

    @abstractmethod
    def reset(self, **kwargs: Any) -> Any:
        """Start a new episode (see openenv-core for the full contract)."""
        ...

    @abstractmethod
    def step(self, action: Any, **kwargs: Any) -> Any:
        """Apply *action* and return the resulting observation."""
        ...

    @property
    @abstractmethod
    def state(self) -> Any:
        """Current environment state."""
        ...
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# ── Register shim modules into sys.modules ───────────────────────────────
|
| 78 |
+
|
| 79 |
+
def _register() -> None:
    """Inject stub modules so ``from openenv.core...`` imports resolve."""
    # Leaf modules carrying the public shim classes.
    types_mod = ModuleType("openenv.core.env_server.types")
    types_mod.Action = Action  # type: ignore[attr-defined]
    types_mod.Observation = Observation  # type: ignore[attr-defined]
    types_mod.State = State  # type: ignore[attr-defined]

    interfaces_mod = ModuleType("openenv.core.env_server.interfaces")
    interfaces_mod.Environment = Environment  # type: ignore[attr-defined]

    # Parent packages, wired so attribute access mirrors the real layout.
    env_server_mod = ModuleType("openenv.core.env_server")
    env_server_mod.types = types_mod  # type: ignore[attr-defined]
    env_server_mod.interfaces = interfaces_mod  # type: ignore[attr-defined]

    core_mod = ModuleType("openenv.core")
    core_mod.env_server = env_server_mod  # type: ignore[attr-defined]

    openenv_mod = ModuleType("openenv")
    openenv_mod.core = core_mod  # type: ignore[attr-defined]

    # setdefault leaves any already-imported real package untouched.
    shim_modules = {
        "openenv": openenv_mod,
        "openenv.core": core_mod,
        "openenv.core.env_server": env_server_mod,
        "openenv.core.env_server.types": types_mod,
        "openenv.core.env_server.interfaces": interfaces_mod,
    }
    for module_name, module in shim_modules.items():
        sys.modules.setdefault(module_name, module)


_register()
|
pyproject.toml
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=69", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "openenv-molforge"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "MolForge: a deterministic medicinal-chemistry OpenEnv environment."
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
requires-python = ">=3.10"
|
| 11 |
+
dependencies = [
|
| 12 |
+
"openenv-core[core]>=0.2.3",
|
| 13 |
+
"pydantic>=2.8.0",
|
| 14 |
+
"rdkit>=2023.9.5,<2024.3.1; python_version < '3.13'",
|
| 15 |
+
"rdkit>=2026.3.1; python_version >= '3.13'",
|
| 16 |
+
]
|
| 17 |
+
|
| 18 |
+
[project.optional-dependencies]
|
| 19 |
+
dev = [
|
| 20 |
+
"pytest>=8.0.0",
|
| 21 |
+
]
|
| 22 |
+
tdc = [
|
| 23 |
+
"pytdc>=1.1.0; python_version < '3.13'",
|
| 24 |
+
]
|
| 25 |
+
|
| 26 |
+
[project.scripts]
|
| 27 |
+
server = "molforge.server.app:main"
|
| 28 |
+
|
| 29 |
+
[tool.setuptools]
|
| 30 |
+
include-package-data = true
|
| 31 |
+
packages = ["molforge", "molforge.server"]
|
| 32 |
+
package-dir = { "molforge" = ".", "molforge.server" = "server" }
|
scenarios.py
ADDED
|
@@ -0,0 +1,504 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Scenario configs and RDKit/TDC-backed surrogate chemistry for MolForge."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from typing import Dict, Iterable, List, Mapping
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# Canonical ordering of the four fragment slots that define a molecule.
SLOT_ORDER = ["warhead", "hinge", "solvent_tail", "back_pocket"]
# Slots that edit actions may modify; currently identical to SLOT_ORDER.
EDITABLE_SLOTS = ["warhead", "hinge", "solvent_tail", "back_pocket"]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass(frozen=True)
class FragmentSpec:
    """Per-fragment surrogate property contributions."""

    # Fragment name; matches its key in FRAGMENT_LIBRARY.
    name: str
    # Signed contributions to the surrogate property scores (consumer code
    # is outside this view; presumably combined additively — confirm).
    potency: float
    safety: float
    synth: float
    novelty: float
    # One-sentence literature-style hint about the fragment's trade-offs.
    literature_hint: str
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass(frozen=True)
class ScenarioConfig:
    """Single evaluation scenario."""

    # Stable identifier used to look scenarios up (e.g. at replay time).
    scenario_id: str
    # Difficulty tier label.
    difficulty: str
    # Name of the biological target for the campaign.
    target_name: str
    # Prose description of the design task.
    task_brief: str
    # NOTE(review): presumably the number of assay/oracle calls allowed per
    # episode — confirm against the environment's action handling.
    oracle_budget: int
    # Cap on environment steps for the episode.
    max_steps: int
    # Fragment assignment the episode starts from.
    starting_scaffold: Mapping[str, str]
    # Fragment assignment used on restart.
    restart_scaffold: Mapping[str, str]
    # Relative weights of the scored objectives.
    objective_weights: Mapping[str, float]
    # Named thresholds the design must respect (name -> limit).
    hard_constraints: Mapping[str, float]
    # Step at which the target/objective shifts; None means no shift.
    target_shift_step: int | None = None
    # Whether this scenario includes a trap penalty (semantics defined by
    # the environment, not visible here).
    trap_penalty: bool = False
    # Tool names enabled for the agent in this scenario.
    enabled_tools: List[str] = field(default_factory=list)
    # Action types enabled for the agent in this scenario.
    enabled_actions: List[str] = field(default_factory=list)
    # Coordination setting; defaults to multi-agent operation.
    coordination_mode: str = "multi_agent"
    # Team roles active in the scenario.
    enabled_roles: List[str] = field(default_factory=list)
    # Roles whose review is required (presumably before submission — confirm
    # against the governance logic).
    required_review_roles: List[str] = field(default_factory=list)
    # Cap on team messages exchanged per turn.
    max_messages_per_turn: int = 4
    # Reference score the agent is expected to beat.
    baseline_to_beat: float = 0.5
|
| 49 |
+
|
| 50 |
+
# Editable fragment vocabulary, keyed first by slot then by fragment name.
# The numeric deltas are hand-tuned surrogate contributions consumed additively
# by evaluate_molecule; each fragment also carries a deterministic literature hint.
FRAGMENT_LIBRARY: Dict[str, Dict[str, FragmentSpec]] = {
    # Covalent/electrophilic warhead options: the main potency lever, usually traded against safety.
    "warhead": {
        "acrylamide": FragmentSpec(
            "acrylamide",
            potency=0.18,
            safety=-0.03,
            synth=0.02,
            novelty=0.03,
            literature_hint="Covalent warheads often boost KRAS potency but can increase reactivity risk.",
        ),
        "reversible_cyanoacrylamide": FragmentSpec(
            "reversible_cyanoacrylamide",
            potency=0.16,
            safety=0.06,
            synth=-0.04,
            novelty=0.08,
            literature_hint="Reversible covalent warheads can preserve potency while softening safety liabilities.",
        ),
        "nitrile": FragmentSpec(
            "nitrile",
            potency=0.11,
            safety=0.09,
            synth=0.05,
            novelty=0.04,
            literature_hint="Nitrile warheads are safer but may need stronger pocket complementarity to keep potency.",
        ),
        "vinyl_sulfonamide": FragmentSpec(
            "vinyl_sulfonamide",
            potency=0.13,
            safety=-0.07,
            synth=-0.05,
            novelty=0.10,
            literature_hint="Sulfonamide warheads can be potent but often pressure synthesis and safety.",
        ),
    },
    # Hinge-binding core options: potency vs. synthesis-complexity trade.
    "hinge": {
        "azaindole": FragmentSpec(
            "azaindole",
            potency=0.17,
            safety=0.01,
            synth=-0.03,
            novelty=0.06,
            literature_hint="Azaindoles are strong binders in KRAS-like pockets when the warhead is well aligned.",
        ),
        "pyridine": FragmentSpec(
            "pyridine",
            potency=0.10,
            safety=0.04,
            synth=0.05,
            novelty=0.02,
            literature_hint="Simple heteroaryl hinges improve tractability and keep synthesis accessible.",
        ),
        "fluorophenyl": FragmentSpec(
            "fluorophenyl",
            potency=0.12,
            safety=-0.08,
            synth=0.04,
            novelty=0.03,
            literature_hint="Hydrophobic hinge binders can lift affinity while increasing lipophilic liability.",
        ),
        "quinazoline": FragmentSpec(
            "quinazoline",
            potency=0.15,
            safety=-0.04,
            synth=-0.06,
            novelty=0.05,
            literature_hint="Quinazolines are potent but can create a heavy, synthesis-taxing scaffold.",
        ),
    },
    # Solvent-exposed tail options: primarily a safety/solubility lever.
    "solvent_tail": {
        "morpholine": FragmentSpec(
            "morpholine",
            potency=0.06,
            safety=0.16,
            synth=0.07,
            novelty=0.02,
            literature_hint="Morpholine tails frequently de-risk hERG and improve solubility.",
        ),
        "piperazine": FragmentSpec(
            "piperazine",
            potency=0.05,
            safety=0.10,
            synth=0.03,
            novelty=0.03,
            literature_hint="Basic cyclic tails improve polarity but can trigger clearance concerns if overused.",
        ),
        "cyclopropyl": FragmentSpec(
            "cyclopropyl",
            potency=0.08,
            safety=-0.03,
            synth=0.04,
            novelty=0.04,
            literature_hint="Compact hydrophobes sometimes improve fit but rarely help safety.",
        ),
        "dimethylamino": FragmentSpec(
            "dimethylamino",
            potency=0.04,
            safety=-0.13,
            synth=0.02,
            novelty=0.04,
            literature_hint="Strongly basic tails can quickly create cardiac and CNS liabilities.",
        ),
    },
    # Back-pocket substituent options: small groups trading potency against lipophilic risk.
    "back_pocket": {
        "methoxy": FragmentSpec(
            "methoxy",
            potency=0.07,
            safety=0.08,
            synth=0.06,
            novelty=0.02,
            literature_hint="Small polar back-pocket groups often stabilize potency without blowing up toxicity.",
        ),
        "chloro": FragmentSpec(
            "chloro",
            potency=0.12,
            safety=-0.12,
            synth=0.04,
            novelty=0.02,
            literature_hint="Halogens often buy potency at the cost of lipophilic risk.",
        ),
        "trifluoromethyl": FragmentSpec(
            "trifluoromethyl",
            potency=0.14,
            safety=-0.15,
            synth=-0.02,
            novelty=0.06,
            literature_hint="CF3 groups can strongly improve affinity but frequently over-shoot safety windows.",
        ),
        "cyano": FragmentSpec(
            "cyano",
            potency=0.10,
            safety=0.03,
            synth=0.01,
            novelty=0.05,
            literature_hint="Cyano groups are efficient potency handles when hydrophobic groups are too risky.",
        ),
    },
}
|
| 188 |
+
|
| 189 |
+
# Oracle-budget cost charged per tool invocation, in the same units as
# ScenarioConfig.oracle_budget. Wet-lab-style tools (toxicity assay, MD
# simulation) are deliberately an order of magnitude pricier than in-silico ones.
DEFAULT_TOOL_COSTS: Dict[str, int] = {
    "evaluate_properties": 50,
    "search_literature": 100,
    "dock_target": 300,
    "estimate_synthesizability": 120,
    "evaluate_novelty": 80,
    "assay_toxicity": 2000,
    "run_md_simulation": 2500,
}
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
# The fixed evaluation roster, ordered easy -> medium -> hard. get_scenario()
# cycles through this list on repeated resets. All numeric values are
# hand-tuned scenario design parameters.
SCENARIOS: List[ScenarioConfig] = [
    # Level 0: near-complete starting series; mainly tests edit ordering + evidence gathering.
    ScenarioConfig(
        scenario_id="level_0_easy",
        difficulty="easy",
        target_name="KRAS G12C",
        task_brief=(
            "Improve target potency while repairing a mild safety liability and keeping synthesis "
            "evidence current. The starting scaffold is close, but a strong submission still needs "
            "the right edit sequence plus assay support."
        ),
        oracle_budget=3600,
        max_steps=7,
        starting_scaffold={
            "warhead": "acrylamide",
            "hinge": "pyridine",
            "solvent_tail": "cyclopropyl",
            "back_pocket": "chloro",
        },
        restart_scaffold={
            "warhead": "reversible_cyanoacrylamide",
            "hinge": "pyridine",
            "solvent_tail": "morpholine",
            "back_pocket": "methoxy",
        },
        objective_weights={
            "potency": 0.55,
            "safety": 0.15,
            "synth": 0.15,
            "novelty": 0.15,
        },
        hard_constraints={"potency_min": 0.84, "toxicity_max": 0.28, "synth_min": 0.62},
        enabled_tools=list(DEFAULT_TOOL_COSTS.keys()),
        enabled_actions=["edit", "run_assay", "submit", "defer", "restart"],
        enabled_roles=[
            "lead_chemist",
            "toxicologist",
            "assay_planner",
            "process_chemist",
        ],
        required_review_roles=["toxicologist", "assay_planner", "process_chemist"],
        baseline_to_beat=0.70,
    ),
    # Level 1: tighter safety weighting and budget pressure; coordinated edits required.
    ScenarioConfig(
        scenario_id="level_1_medium",
        difficulty="medium",
        target_name="KRAS G12C",
        task_brief=(
            "Balance potency, toxicity, and synthesizability under budget pressure. The best "
            "molecules require coordinated safety edits plus current assay evidence."
        ),
        oracle_budget=4300,
        max_steps=8,
        starting_scaffold={
            "warhead": "acrylamide",
            "hinge": "fluorophenyl",
            "solvent_tail": "dimethylamino",
            "back_pocket": "chloro",
        },
        restart_scaffold={
            "warhead": "reversible_cyanoacrylamide",
            "hinge": "azaindole",
            "solvent_tail": "morpholine",
            "back_pocket": "cyano",
        },
        objective_weights={
            "potency": 0.42,
            "safety": 0.33,
            "synth": 0.13,
            "novelty": 0.12,
        },
        hard_constraints={"potency_min": 0.76, "toxicity_max": 0.34, "synth_min": 0.62},
        enabled_tools=list(DEFAULT_TOOL_COSTS.keys()),
        enabled_actions=["edit", "run_assay", "submit", "defer", "restart"],
        enabled_roles=[
            "lead_chemist",
            "toxicologist",
            "assay_planner",
            "process_chemist",
        ],
        required_review_roles=["toxicologist", "assay_planner", "process_chemist"],
        baseline_to_beat=0.64,
    ),
    # Level 2: non-stationary target (shift at step 4) plus a trap-series starting scaffold.
    ScenarioConfig(
        scenario_id="level_2_hard",
        difficulty="hard",
        target_name="KRAS G12C resistance panel",
        task_brief=(
            "Solve a non-stationary design problem with a fixed, problematic core. The starting "
            "series is a sunk-cost trap, and the target pocket shifts late in the episode."
        ),
        oracle_budget=5000,
        max_steps=9,
        starting_scaffold={
            "warhead": "acrylamide",
            "hinge": "quinazoline",
            "solvent_tail": "dimethylamino",
            "back_pocket": "trifluoromethyl",
        },
        restart_scaffold={
            "warhead": "nitrile",
            "hinge": "azaindole",
            "solvent_tail": "morpholine",
            "back_pocket": "cyano",
        },
        objective_weights={
            "potency": 0.38,
            "safety": 0.32,
            "synth": 0.16,
            "novelty": 0.14,
        },
        hard_constraints={"potency_min": 0.78, "toxicity_max": 0.46, "synth_min": 0.62},
        target_shift_step=4,
        trap_penalty=True,
        enabled_tools=list(DEFAULT_TOOL_COSTS.keys()),
        enabled_actions=["edit", "run_assay", "submit", "defer", "restart"],
        enabled_roles=[
            "lead_chemist",
            "toxicologist",
            "assay_planner",
            "process_chemist",
        ],
        required_review_roles=["toxicologist", "assay_planner", "process_chemist"],
        baseline_to_beat=0.66,
    ),
]
|
| 325 |
+
|
| 326 |
+
SCENARIO_BY_ID = {scenario.scenario_id: scenario for scenario in SCENARIOS}
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
def get_scenario(index: int) -> ScenarioConfig:
    """Map an arbitrary episode index onto the fixed scenario roster.

    Indices wrap modulo the roster length, so repeated environment resets
    cycle deterministically through every configured task.
    """
    cycle_position = index % len(SCENARIOS)
    return SCENARIOS[cycle_position]
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def format_molecule(molecule: Mapping[str, str]) -> str:
    """Render the molecule as ``slot=fragment`` pairs joined by ' | ' in canonical slot order."""
    return " | ".join(f"{slot}={molecule[slot]}" for slot in SLOT_ORDER)
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
def fragment_choices(slot: str) -> List[str]:
    """List the fragment names available for *slot*, in alphabetical order."""
    # Iterating the inner dict yields its keys; sorted() materializes them as a list.
    return sorted(FRAGMENT_LIBRARY[slot])
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def evaluate_molecule(
    molecule: Mapping[str, str],
    scenario: ScenarioConfig,
    *,
    target_shift_active: bool = False,
) -> Dict[str, float]:
    """Evaluate a molecule with target logic plus RDKit/TDC medicinal chemistry signals.

    Scoring order matters: baselines -> per-fragment deltas -> pairwise
    interaction bonuses/penalties -> optional target-shift adjustments ->
    trap-series caps -> clamping to [0, 1]. ``toxicity`` is derived as the
    complement of ``safety``. When the optional ``molforge_oracles`` backend
    is importable, the clamped values are passed to it for refinement;
    otherwise they are returned directly.

    Args:
        molecule: slot -> fragment-name mapping (slots must exist in FRAGMENT_LIBRARY).
        scenario: active scenario; only ``trap_penalty`` is read here.
        target_shift_active: apply the late-episode pocket-shift adjustments.

    Returns:
        Dict with keys potency/safety/toxicity/synth/novelty, each rounded to 4 dp
        (fallback path; the oracle backend may post-process these values).
    """

    # Baseline scores before any fragment contributions.
    potency = 0.23
    safety = 0.56
    synth = 0.58
    novelty = 0.18

    # Additive per-fragment contributions from the library.
    for slot, fragment_name in molecule.items():
        fragment = FRAGMENT_LIBRARY[slot][fragment_name]
        potency += fragment.potency
        safety += fragment.safety
        synth += fragment.synth
        novelty += fragment.novelty

    # Pairwise interaction effects (synergies and liabilities between slots).
    if molecule["warhead"] == "acrylamide" and molecule["hinge"] == "azaindole":
        potency += 0.10
    if molecule["solvent_tail"] == "morpholine" and molecule["back_pocket"] == "methoxy":
        safety += 0.08
    if molecule["hinge"] == "fluorophenyl" and molecule["back_pocket"] == "chloro":
        potency += 0.06
        safety -= 0.16
    if molecule["solvent_tail"] == "dimethylamino" and molecule["back_pocket"] == "trifluoromethyl":
        safety -= 0.15
    if molecule["warhead"] == "nitrile" and molecule["back_pocket"] == "cyano":
        potency += 0.04
        novelty += 0.03
    if molecule["warhead"] == "reversible_cyanoacrylamide" and molecule["solvent_tail"] == "morpholine":
        safety += 0.05

    # Non-stationary target: after the pocket shift, acrylamide loses potency
    # while nitrile/cyano chemistry gains.
    if target_shift_active:
        if molecule["warhead"] == "acrylamide":
            potency -= 0.16
        if molecule["warhead"] == "nitrile":
            potency += 0.10
        if molecule["back_pocket"] == "cyano":
            potency += 0.03

    # Sunk-cost trap scenarios cap how good any molecule can look, applied
    # BEFORE clamping so the ceilings are absolute.
    if scenario.trap_penalty:
        potency = min(potency, 0.71)
        safety = min(safety, 0.44)

    # Clamp everything to [0, 1]; toxicity is the complement of safety.
    potency = min(max(potency, 0.0), 1.0)
    safety = min(max(safety, 0.0), 1.0)
    synth = min(max(synth, 0.0), 1.0)
    novelty = min(max(novelty, 0.0), 1.0)
    toxicity = min(max(1.0 - safety, 0.0), 1.0)

    fallback_properties = {
        "potency": round(potency, 4),
        "safety": round(safety, 4),
        "toxicity": round(toxicity, 4),
        "synth": round(synth, 4),
        "novelty": round(novelty, 4),
    }
    # Best-effort hand-off to the RDKit/TDC oracle; any import failure falls
    # back to the deterministic surrogate scores above.
    try:
        from molforge_oracles import evaluate_with_rdkit_tdc
    except Exception:
        return fallback_properties
    return evaluate_with_rdkit_tdc(molecule, fallback_properties)
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def molecule_to_smiles(molecule: Mapping[str, str]) -> str:
    """Return the surrogate SMILES for *molecule*, or "" when the oracle backend is unavailable."""
    try:
        from molforge_oracles import assemble_surrogate_smiles as _assemble
    except Exception:
        # Oracle module missing or broken: degrade to an empty SMILES string.
        return ""
    return _assemble(molecule)
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
def oracle_backend_status() -> Dict[str, bool]:
    """Report whether the RDKit and TDC scoring backends are active."""
    try:
        from molforge_oracles import oracle_backend_status as _status
    except Exception:
        # Without the oracle module, neither backend can be active.
        return {"rdkit": False, "tdc": False}
    return _status()
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
def compute_objective_score(properties: Mapping[str, float], scenario: ScenarioConfig) -> float:
    """Collapse the visible scientific objectives into one clamped 0-1 quality score.

    Safety is scored as the complement of toxicity; each axis is weighted by
    the scenario's objective_weights and the sum is clamped to [0, 1] and
    rounded to 4 decimal places.
    """
    weights = scenario.objective_weights
    weighted_total = (
        weights["potency"] * properties["potency"]
        + weights["safety"] * (1.0 - properties["toxicity"])
        + weights["synth"] * properties["synth"]
        + weights["novelty"] * properties["novelty"]
    )
    clamped = min(max(weighted_total, 0.0), 1.0)
    return round(clamped, 4)
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
def evaluate_constraints(
    properties: Mapping[str, float], scenario: ScenarioConfig
) -> Dict[str, tuple[bool, float]]:
    """Return hard-constraint satisfaction as ``name -> (passed, threshold)``.

    Only constraints present in ``scenario.hard_constraints`` are reported,
    in the fixed order potency_min, toxicity_max, synth_min.
    """
    # (constraint name, property it checks, pass predicate)
    rules = (
        ("potency_min", "potency", lambda value, limit: value >= limit),
        ("toxicity_max", "toxicity", lambda value, limit: value <= limit),
        ("synth_min", "synth", lambda value, limit: value >= limit),
    )
    outcome: Dict[str, tuple[bool, float]] = {}
    for name, prop_key, passes in rules:
        if name in scenario.hard_constraints:
            limit = scenario.hard_constraints[name]
            outcome[name] = (passes(properties[prop_key], limit), limit)
    return outcome
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
def evaluate_constraint_margins(
    properties: Mapping[str, float], scenario: ScenarioConfig
) -> Dict[str, float]:
    """Score each configured hard constraint on a proportional 0-1 scale.

    "At least" constraints (potency_min, synth_min) earn full credit at or
    above the threshold and linear partial credit below it. toxicity_max
    earns full credit at or below the threshold, then loses credit linearly
    with the overshoot relative to the remaining headroom.
    """
    limits = scenario.hard_constraints
    scores: Dict[str, float] = {}

    def floor_score(value: float, limit: float) -> float:
        # Linear credit toward a minimum threshold, clamped to [0, 1].
        return min(1.0, max(0.0, value / max(limit, 1e-6)))

    if "potency_min" in limits:
        scores["potency_min"] = floor_score(properties["potency"], limits["potency_min"])
    if "toxicity_max" in limits:
        ceiling = limits["toxicity_max"]
        toxicity = properties["toxicity"]
        if toxicity <= ceiling:
            scores["toxicity_max"] = 1.0
        else:
            overshoot = toxicity - ceiling
            scores["toxicity_max"] = max(0.0, 1.0 - overshoot / max(1.0 - ceiling, 1e-6))
    if "synth_min" in limits:
        scores["synth_min"] = floor_score(properties["synth"], limits["synth_min"])
    return scores
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
def literature_hints(molecule: Mapping[str, str]) -> List[str]:
    """Gather the canned literature hint for each slot's current fragment, in canonical slot order."""
    return [FRAGMENT_LIBRARY[slot][molecule[slot]].literature_hint for slot in SLOT_ORDER]
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
def enumerate_candidate_edits(molecule: Mapping[str, str]) -> Iterable[tuple[str, str]]:
    """Yield every single-edit candidate as a ``(slot, fragment)`` pair.

    Fragments identical to the current assignment are skipped, so each yield
    represents a genuine one-slot change.
    """
    for slot in SLOT_ORDER:
        current = molecule[slot]
        for candidate in fragment_choices(slot):
            if candidate != current:
                yield slot, candidate
|
scripts/convert_peft_lora_to_mlx.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Convert a PEFT LoRA adapter into the adapter format expected by mlx-lm."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import json
|
| 7 |
+
import re
|
| 8 |
+
import shutil
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
import mlx.core as mx
|
| 12 |
+
from safetensors import safe_open
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# Matches PEFT LoRA weight keys of the form
#   base_model.model.model.<prefix>.layers.<layer>.<module>.lora_A.weight
# capturing the transformer layer index, the module path within the layer,
# and whether the tensor is the LoRA A or B matrix.
KEY_RE = re.compile(
    r"^base_model\.model\.model\.(?P<prefix>.+?)\.layers\."
    r"(?P<layer>\d+)\.(?P<module>.+?)\.lora_(?P<ab>[AB])\.weight$"
)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def main() -> None:
    """Convert a PEFT LoRA adapter directory into an mlx-lm adapter directory.

    Reads ``adapter_config.json`` and ``adapter_model.safetensors`` from the
    input path, remaps/transposes the LoRA tensors into mlx-lm key layout,
    writes ``adapters.safetensors`` plus an mlx-style ``adapter_config.json``
    to the output path, and copies tokenizer sidecar files when present.

    Raises:
        SystemExit: when no LoRA weights matching KEY_RE are found.
    """
    parser = argparse.ArgumentParser(description="Convert PEFT LoRA adapter to MLX LoRA adapter.")
    parser.add_argument("peft_adapter", help="Path containing PEFT adapter_model.safetensors")
    parser.add_argument("mlx_adapter", help="Output path for MLX adapters.safetensors")
    args = parser.parse_args()

    peft_path = Path(args.peft_adapter)
    mlx_path = Path(args.mlx_adapter)
    mlx_path.mkdir(parents=True, exist_ok=True)

    # Standard LoRA scaling: effective weight delta is (alpha / r) * B @ A.
    peft_config = json.loads((peft_path / "adapter_config.json").read_text())
    rank = int(peft_config["r"])
    alpha = float(peft_config["lora_alpha"])
    scale = alpha / rank
    target_modules = list(peft_config["target_modules"])

    weights = {}
    layer_ids = set()
    module_keys = set()
    with safe_open(peft_path / "adapter_model.safetensors", framework="numpy") as handle:
        for key in handle.keys():
            match = KEY_RE.match(key)
            if not match:
                # Skip non-LoRA tensors (e.g. embeddings, modules_to_save).
                continue
            layer = int(match.group("layer"))
            module = match.group("module")
            ab = match.group("ab")
            layer_ids.add(layer)
            module_keys.add(module)
            tensor = handle.get_tensor(key)
            # mlx-lm uses lowercase lora_a / lora_b under language_model.model.layers.
            mlx_key = f"language_model.model.layers.{layer}.{module}.lora_{ab.lower()}"
            # Transpose: PEFT and mlx-lm store LoRA matrices with opposite
            # (in, out) orientation — TODO confirm against the target mlx-lm version.
            weights[mlx_key] = mx.array(tensor.T)

    if not weights:
        raise SystemExit(f"No PEFT LoRA weights found in {peft_path}")

    mx.save_safetensors(str(mlx_path / "adapters.safetensors"), weights)
    config = {
        "fine_tune_type": "lora",
        # Highest layer index seen + 1; assumes adapters cover a contiguous
        # trailing range of layers as mlx-lm expects — verify for sparse adapters.
        "num_layers": max(layer_ids) + 1,
        "lora_parameters": {
            "rank": rank,
            "scale": scale,
            "dropout": float(peft_config.get("lora_dropout", 0.0)),
            "keys": sorted(module_keys),
        },
    }
    (mlx_path / "adapter_config.json").write_text(json.dumps(config, indent=2) + "\n")

    # Copy tokenizer/processor sidecars so the mlx adapter dir is self-contained.
    for filename in [
        "tokenizer.json",
        "tokenizer_config.json",
        "chat_template.jinja",
        "processor_config.json",
        "README.md",
    ]:
        source = peft_path / filename
        if source.exists():
            shutil.copy2(source, mlx_path / filename)

    # Emit a machine-readable conversion summary on stdout.
    print(
        json.dumps(
            {
                "output": str(mlx_path),
                "weights": len(weights),
                "num_layers": config["num_layers"],
                "rank": rank,
                "scale": scale,
                "keys": sorted(module_keys),
                "target_modules": target_modules,
            },
            indent=2,
        )
    )
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# Script entry point: python scripts/convert_peft_lora_to_mlx.py <peft_dir> <mlx_dir>
if __name__ == "__main__":
    main()
|
scripts/generate_sft_all_actions_dataset.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Generate a MolForge SFT JSONL dataset with rare-action coverage.
|
| 2 |
+
|
| 3 |
+
Most records come from the deterministic team policy so the examples are
|
| 4 |
+
grounded in real environment trajectories. A smaller coverage slice is added
|
| 5 |
+
for rare but valid schema variants such as defer, each assay tool, and edit
|
| 6 |
+
subtypes so SFT teaches the model the whole action surface.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
import json
|
| 13 |
+
import os
|
| 14 |
+
import sys
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from typing import Iterable
|
| 17 |
+
|
| 18 |
+
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
| 19 |
+
if str(PROJECT_ROOT) not in sys.path:
|
| 20 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 21 |
+
|
| 22 |
+
from inference_common import ( # noqa: E402
|
| 23 |
+
SYSTEM_PROMPT,
|
| 24 |
+
MolForgeAction,
|
| 25 |
+
MolForgeObservation,
|
| 26 |
+
attach_reasoning_fields,
|
| 27 |
+
attach_team_messages,
|
| 28 |
+
build_model_payload,
|
| 29 |
+
heuristic_team_action,
|
| 30 |
+
)
|
| 31 |
+
from scenarios import DEFAULT_TOOL_COSTS # noqa: E402
|
| 32 |
+
from server.molforge_environment import MolForgeEnvironment # noqa: E402
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def main() -> None:
    """Generate the MolForge all-action SFT dataset as one JSONL file.

    Rolls out ``--episodes`` episodes of the deterministic team policy for
    grounded trajectories, appends the curated rare-action coverage slice,
    and writes every record as a chat-format JSON line. Prints a JSON summary
    (output path, record counts) on stdout.
    """
    parser = argparse.ArgumentParser(description="Generate MolForge all-action SFT JSONL.")
    parser.add_argument("--episodes", type=int, default=90)
    parser.add_argument("--max-turns", type=int, default=10)
    parser.add_argument("--output", default="data/molforge_sft_all_actions.jsonl")
    parser.add_argument(
        "--randomized",
        action="store_true",
        help="Enable MolForge training randomization while collecting policy traces.",
    )
    args = parser.parse_args()

    # The environment reads this env var at construction time, so it must be
    # set before MolForgeEnvironment() below.
    if args.randomized:
        os.environ["MOLFORGE_TRAINING_RANDOMIZATION"] = "1"

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    env = MolForgeEnvironment()
    records = []

    # Grounded slice: record (observation, policy action) pairs from real rollouts.
    for _ in range(args.episodes):
        observation = env.reset()
        for _ in range(args.max_turns):
            if observation.done:
                break
            action = heuristic_team_action(observation)
            records.append(make_record(observation, action, source="policy_trace"))
            observation = env.step(action)

    # Coverage slice: rare but schema-valid actions, enriched with the same
    # reasoning/team-message fields the policy traces carry.
    for observation, action in curated_coverage_examples():
        action = attach_reasoning_fields(observation, action)
        action = attach_team_messages(observation, action)
        records.append(make_record(observation, action, source="coverage_example"))

    with output_path.open("w", encoding="utf-8") as handle:
        for record in records:
            handle.write(json.dumps(record, ensure_ascii=True) + "\n")

    print(
        json.dumps(
            {
                "output": str(output_path),
                "records": len(records),
                "coverage_records": sum(
                    1 for record in records if record["metadata"]["source"] == "coverage_example"
                ),
            },
            indent=2,
        )
    )
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def curated_coverage_examples() -> Iterable[tuple[MolForgeObservation, MolForgeAction]]:
    """Yield hand-written (observation, action) pairs covering rare action variants.

    Uses three fresh resets of a dedicated environment as context observations.
    NOTE(review): assumes three consecutive resets yield the easy, medium, and
    hard scenarios in that order (the scenario cycle) — confirm against the
    environment's reset behavior.
    """
    env = MolForgeEnvironment()
    observations = [env.reset(), env.reset(), env.reset()]

    # One "defer" example per difficulty tier.
    for observation in observations:
        yield observation, MolForgeAction(
            action_type="defer",
            acting_role="lead_chemist",
            rationale="Hold this turn because the team needs a cleaner evidence-backed move.",
        )

    easy, medium, hard = observations

    # One example per edit subtype: add_fragment, remove, undo_last_edit.
    yield easy, MolForgeAction(
        action_type="edit",
        acting_role="lead_chemist",
        edit_type="add_fragment",
        slot="back_pocket",
        fragment="cyano",
        rationale="Add a compact cyano handle to improve potency without large lipophilic risk.",
    )
    yield medium, MolForgeAction(
        action_type="edit",
        acting_role="lead_chemist",
        edit_type="remove",
        slot="back_pocket",
        rationale="Remove the risky back-pocket group and return to a simpler default handle.",
    )
    yield hard, MolForgeAction(
        action_type="edit",
        acting_role="lead_chemist",
        edit_type="undo_last_edit",
        slot="solvent_tail",
        rationale="Undo the last tail change when the visible evidence suggests it raised risk.",
    )

    # One run_assay example for every tool, in every difficulty context.
    for observation in observations:
        for tool_name in DEFAULT_TOOL_COSTS:
            yield observation, MolForgeAction(
                action_type="run_assay",
                acting_role="assay_planner",
                tool_name=tool_name,
                rationale=f"Run {tool_name} to close a visible evidence gap before committing.",
            )

    # Terminal-action coverage: restart (hard/trap scenario) and submit (easy).
    yield hard, MolForgeAction(
        action_type="restart",
        acting_role="lead_chemist",
        rationale="Restart early because the hard scenario starts in a trap series.",
    )
    yield easy, MolForgeAction(
        action_type="submit",
        acting_role="lead_chemist",
        rationale="Submit only when visible evidence is sufficient and budget should be preserved.",
    )
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def make_record(
    observation: MolForgeObservation,
    action: MolForgeAction,
    *,
    source: str,
) -> dict[str, object]:
    """Build one SFT JSONL record: system/user/assistant chat messages plus provenance metadata.

    The user turn carries the compact-serialized observation payload and the
    assistant turn the compact-serialized action (None fields omitted).
    """
    compact = (",", ":")
    user_content = json.dumps(build_model_payload(observation, compact=False), separators=compact)
    assistant_content = json.dumps(action.model_dump(exclude_none=True), separators=compact)
    return {
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": assistant_content},
        ],
        "metadata": {
            "source": source,
            "scenario_id": observation.scenario_id,
            "difficulty": observation.difficulty,
            "step_index": observation.step_index,
            "action_type": action.action_type,
        },
    }
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
# Script entry point: python scripts/generate_sft_all_actions_dataset.py [options]
if __name__ == "__main__":
    main()
|
scripts/generate_sft_compact_policy_v4_dataset.py
ADDED
|
@@ -0,0 +1,446 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Generate MolForge compact-policy SFT data aligned to MLX inference.
|
| 2 |
+
|
| 3 |
+
V4 is designed around the failures seen in the v3 adapter:
|
| 4 |
+
- train on the exact compact prompt/payload shape used at inference time
|
| 5 |
+
- emphasize successful end-to-end expert trajectories
|
| 6 |
+
- include recovery examples after governance vetoes
|
| 7 |
+
- include enough schema coverage for all core action types without making
|
| 8 |
+
unsafe edits or wasteful assays dominate the positive training signal
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
import json
|
| 15 |
+
import os
|
| 16 |
+
import sys
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from typing import Any, Iterable
|
| 19 |
+
|
| 20 |
+
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
| 21 |
+
if str(PROJECT_ROOT) not in sys.path:
|
| 22 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 23 |
+
|
| 24 |
+
from inference_common import ( # noqa: E402
|
| 25 |
+
MolForgeAction,
|
| 26 |
+
MolForgeObservation,
|
| 27 |
+
attach_reasoning_fields,
|
| 28 |
+
attach_team_messages,
|
| 29 |
+
heuristic_team_action,
|
| 30 |
+
)
|
| 31 |
+
from scenarios import DEFAULT_TOOL_COSTS # noqa: E402
|
| 32 |
+
from server.molforge_environment import MolForgeEnvironment # noqa: E402
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
COMPACT_ACTION_SYSTEM_PROMPT = """
|
| 36 |
+
You control the MolForge action policy.
|
| 37 |
+
Return exactly one JSON object with only these top-level keys:
|
| 38 |
+
action_type, acting_role, edit_type, slot, fragment, tool_name, rationale,
|
| 39 |
+
evidence, expected_effects.
|
| 40 |
+
|
| 41 |
+
Valid action_type values are exactly:
|
| 42 |
+
edit, run_assay, submit, restart, defer.
|
| 43 |
+
|
| 44 |
+
Do not output team messages. Do not output proposal, approval, objection,
|
| 45 |
+
risk_flag, assay_request, rejection, or submission_recommendation as action_type.
|
| 46 |
+
The environment will attach governance messages automatically.
|
| 47 |
+
|
| 48 |
+
Role rules:
|
| 49 |
+
- run_assay uses acting_role "assay_planner" and a valid tool_name.
|
| 50 |
+
- edit, submit, restart, and defer use acting_role "lead_chemist".
|
| 51 |
+
- unused optional fields must be JSON null.
|
| 52 |
+
""".strip()
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def main() -> None:
    """CLI entry point: generate the v4 compact-policy SFT dataset as JSONL."""
    parser = argparse.ArgumentParser(description="Generate compact MolForge v4 policy SFT JSONL.")
    parser.add_argument("--episodes", type=int, default=520)
    parser.add_argument("--max-turns", type=int, default=10)
    parser.add_argument("--seed", default="policy-v4")
    parser.add_argument("--output", default="issue/molforge_sft_compact_policy_v4.jsonl")
    args = parser.parse_args()

    records: list[dict[str, Any]] = []
    seen: set[str] = set()

    # Mix canonical and randomized expert rollouts, veto-recovery traces,
    # and broad schema coverage so every action type appears in training.
    add_expert_traces(records, seen, episodes=18, max_turns=args.max_turns, randomized=False, seed=args.seed)
    add_expert_traces(records, seen, episodes=args.episodes, max_turns=args.max_turns, randomized=True, seed=args.seed)
    add_recovery_traces(records, seen, episodes=max(90, args.episodes // 3), seed=args.seed)
    add_schema_coverage(records, seen, episodes=36, seed=args.seed)

    destination = Path(args.output)
    destination.parent.mkdir(parents=True, exist_ok=True)
    with destination.open("w", encoding="utf-8") as sink:
        sink.writelines(json.dumps(record, ensure_ascii=True) + "\n" for record in records)

    print(json.dumps(summarize(records, str(destination)), indent=2))
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def add_expert_traces(
    records: list[dict[str, Any]],
    seen: set[str],
    *,
    episodes: int,
    max_turns: int,
    randomized: bool,
    seed: str,
) -> None:
    """Roll out the heuristic expert policy and record every (obs, action) pair.

    Each episode records up to ``max_turns`` steps, stopping early when the
    environment reports ``done``.
    """
    with_training_randomization(randomized, seed)
    env = MolForgeEnvironment()
    label = "expert_randomized" if randomized else "expert_canonical"

    for _ in range(episodes):
        obs = env.reset()
        turns_left = max_turns
        while turns_left > 0 and not obs.done:
            chosen = heuristic_team_action(obs)
            add_record(records, seen, obs, chosen, source=label)
            obs = env.step(chosen)
            turns_left -= 1
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def add_recovery_traces(records: list[dict[str, Any]], seen: set[str], *, episodes: int, seed: str) -> None:
    """Inject deliberately bad actions and record the expert's recovery move.

    For each episode: optionally advance a few heuristic steps, then, in a
    freshly replayed environment, apply a known-bad edit. Only when
    governance responds with a ``policy_veto`` is the heuristic expert's
    follow-up action recorded as a "recovery_after_veto" training target.
    """
    with_training_randomization(True, f"{seed}-recovery")
    env = MolForgeEnvironment()

    for episode_index in range(episodes):
        observation = env.reset()

        # Move some episodes to a useful intermediate state before injecting a bad decision.
        for _ in range(episode_index % 3):
            if observation.done:
                break
            observation = env.step(heuristic_team_action(observation))
        if observation.done:
            continue

        for bad_action in bad_actions_for(observation):
            # Each bad action gets its own replayed environment so one veto
            # cannot contaminate the next trial's state.
            trial = clone_env_at_observation(env, episode_index)
            trial_obs = advance_like_source(trial, episode_index % 3)
            if trial_obs.done:
                continue
            veto_obs = trial.step(attach_team_messages(trial_obs, attach_reasoning_fields(trial_obs, bad_action)))
            if veto_obs.done:
                continue
            # Keep only states where governance actually vetoed the action.
            if veto_obs.governance.status != "policy_veto":
                continue
            recovery = heuristic_team_action(veto_obs)
            add_record(records, seen, veto_obs, recovery, source="recovery_after_veto")
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def add_schema_coverage(records: list[dict[str, Any]], seen: set[str], *, episodes: int, seed: str) -> None:
    """Emit extra records so every action schema variant appears in training.

    Collects observations from fresh episodes (plus up to two heuristic
    steps each), then for each observation emits: safe substitute edits,
    a remove edit, budget-aware assay runs, and a capped number of defer
    examples.
    """
    with_training_randomization(True, f"{seed}-coverage")
    env = MolForgeEnvironment()
    observations: list[MolForgeObservation] = []
    for _ in range(episodes):
        observation = env.reset()
        observations.append(observation)
        # Also sample slightly-advanced states, not just initial ones.
        for _ in range(2):
            if observation.done:
                break
            observation = env.step(heuristic_team_action(observation))
            observations.append(observation)

    defer_examples = 0
    for observation in observations:
        # Current slot occupancy; skip edits that would be no-ops.
        current = {slot.slot: slot.fragment for slot in observation.molecule_slots}
        safe_edits = [
            ("solvent_tail", "morpholine", "Use morpholine to reduce safety risk."),
            ("back_pocket", "cyano", "Use cyano to preserve potency with lower lipophilic risk."),
            ("warhead", "reversible_cyanoacrylamide", "Use a softer warhead to reduce reactivity."),
            ("hinge", "azaindole", "Use azaindole when potency needs recovery."),
        ]
        for slot, fragment, rationale in safe_edits:
            if current.get(slot) == fragment:
                continue
            add_record(
                records,
                seen,
                observation,
                MolForgeAction(
                    action_type="edit",
                    acting_role="lead_chemist",
                    edit_type="substitute",
                    slot=slot,  # type: ignore[arg-type]
                    fragment=fragment,
                    rationale=rationale,
                ),
                source="schema_safe_edit",
            )

        # Remove-edit coverage only from non-initial states.
        if observation.step_index > 0:
            add_record(
                records,
                seen,
                observation,
                MolForgeAction(
                    action_type="edit",
                    acting_role="lead_chemist",
                    edit_type="remove",
                    slot="back_pocket",
                    rationale="Remove the back-pocket group to simplify risk before reassay.",
                ),
                source="schema_remove",
            )

        # Assay coverage restricted to tools that are useful and affordable.
        for tool_name in useful_tool_subset(observation):
            add_record(
                records,
                seen,
                observation,
                MolForgeAction(
                    action_type="run_assay",
                    acting_role="assay_planner",
                    tool_name=tool_name,  # type: ignore[arg-type]
                    rationale=f"Run {tool_name} to close a visible evidence gap.",
                ),
                source="schema_tool_coverage",
            )

        # Cap defer examples so they never dominate the positive signal.
        if (
            defer_examples < 36
            and observation.step_index >= 1
            and observation.scenario_id != "level_2_hard"
        ):
            add_record(
                records,
                seen,
                observation,
                MolForgeAction(
                    action_type="defer",
                    acting_role="lead_chemist",
                    rationale="Defer because no safe evidence-backed action remains in the current budget window.",
                ),
                source="schema_defer",
            )
            defer_examples += 1
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def useful_tool_subset(observation: MolForgeObservation) -> list[str]:
    """Return assay tools worth running given evidence gaps and remaining budget."""
    budget = observation.remaining_budget

    def affordable(tool: str) -> bool:
        # A tool is only suggested when its cost fits the remaining budget.
        return budget >= DEFAULT_TOOL_COSTS[tool]

    # Constraint names map to gap labels by their first underscore segment,
    # except toxicity_max which maps to "toxicity".
    gaps = {
        "toxicity" if constraint.name == "toxicity_max" else constraint.name.split("_")[0]
        for constraint in observation.constraint_status
        if constraint.evidence_status == "unknown"
    }

    tools: list[str] = []
    if "potency" in gaps and affordable("dock_target"):
        tools += ["evaluate_properties", "dock_target"]
    if "toxicity" in gaps and affordable("assay_toxicity"):
        tools.append("assay_toxicity")
    if "synth" in gaps and affordable("estimate_synthesizability"):
        tools.append("estimate_synthesizability")
    if affordable("evaluate_novelty"):
        tools.append("evaluate_novelty")
    if affordable("search_literature"):
        tools.append("search_literature")
    if observation.scenario_id == "level_2_hard" and affordable("run_md_simulation"):
        tools.append("run_md_simulation")
    return tools
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def bad_actions_for(observation: MolForgeObservation) -> Iterable[MolForgeAction]:
    """Yield deliberately risky edits that governance is expected to veto."""
    occupied = {slot.slot: slot.fragment for slot in observation.molecule_slots}
    risky_edits = (
        ("solvent_tail", "dimethylamino", "This would add a safety liability and should be recovered from."),
        ("back_pocket", "trifluoromethyl", "This would over-shoot lipophilic risk and should be recovered from."),
        ("hinge", "quinazoline", "This can create route pressure and should be recovered from."),
    )
    for slot, fragment, rationale in risky_edits:
        # Skip no-op edits: the fragment is already in place.
        if occupied.get(slot) == fragment:
            continue
        yield MolForgeAction(
            action_type="edit",
            acting_role="lead_chemist",
            edit_type="substitute",
            slot=slot,  # type: ignore[arg-type]
            fragment=fragment,
            rationale=rationale,
        )
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def clone_env_at_observation(source_env: MolForgeEnvironment, episode_index: int) -> MolForgeEnvironment:
    """Build a fresh environment aligned with the source episode's reset count.

    The environment is not deep-copied; instead a new instance replays the
    same number of ``reset()`` calls as the source loop has performed
    (``episode_index + 1``). Presumably scenario selection rotates with the
    reset count so the clone lands on the same scenario — TODO confirm
    against MolForgeEnvironment.reset.

    Args:
        source_env: Documents intent only; it is never read.
        episode_index: Zero-based episode number in the caller's loop.

    Returns:
        A new environment that has been reset ``episode_index + 1`` times.
    """
    del source_env  # intentionally unused: cloning works by replaying resets
    env = MolForgeEnvironment()
    for _ in range(episode_index + 1):
        # Fixed: the returned observation was previously bound to an unused
        # local; only the side effect of reset() matters here.
        env.reset()
    return env
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def advance_like_source(env: MolForgeEnvironment, steps: int) -> MolForgeObservation:
    """Advance *env* by up to *steps* heuristic actions; return the observation.

    Mirrors the step pattern used by the caller's source environment so a
    replayed clone reaches an equivalent state.
    """
    # NOTE(review): reads the current state via a private builder so no step
    # is consumed; acceptable for an offline data-generation script.
    observation = env._build_observation(reward=0.0, done=False, reward_components=[])  # noqa: SLF001
    for _ in range(steps):
        if observation.done:
            return observation
        observation = env.step(heuristic_team_action(observation))
    return observation
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def with_training_randomization(enabled: bool, seed: str) -> None:
    """Toggle the training-randomization env flag and pin the random seed.

    When *enabled*, sets ``MOLFORGE_TRAINING_RANDOMIZATION=1``; otherwise
    removes the flag entirely. The seed is always written.
    """
    flag = "MOLFORGE_TRAINING_RANDOMIZATION"
    if enabled:
        os.environ[flag] = "1"
    else:
        os.environ.pop(flag, None)
    os.environ["MOLFORGE_RANDOM_SEED"] = seed
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def add_record(
    records: list[dict[str, Any]],
    seen: set[str],
    observation: MolForgeObservation,
    action: MolForgeAction,
    *,
    source: str,
) -> None:
    """Append one deduplicated, schema-validated SFT record in place."""
    enriched = attach_reasoning_fields(observation, action)
    record = make_record(observation, enriched, source=source)
    user_text = record["messages"][1]["content"]
    assistant_text = record["messages"][2]["content"]
    # Deduplicate on the exact (user, assistant) pair.
    dedupe_key = json.dumps({"user": user_text, "assistant": assistant_text}, sort_keys=True)
    if dedupe_key in seen:
        return
    validate_target(assistant_text)
    records.append(record)
    seen.add(dedupe_key)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def make_record(observation: MolForgeObservation, action: MolForgeAction, *, source: str) -> dict[str, Any]:
    """Build one compact-policy SFT chat record plus bookkeeping metadata."""
    user_text = json.dumps(compact_action_payload(observation), separators=(",", ":"))
    assistant_text = json.dumps(target_action(action), separators=(",", ":"))
    return {
        "messages": [
            {"role": "system", "content": COMPACT_ACTION_SYSTEM_PROMPT},
            {"role": "user", "content": user_text},
            {"role": "assistant", "content": assistant_text},
        ],
        "metadata": {
            "source": source,
            "scenario_id": observation.scenario_id,
            "difficulty": observation.difficulty,
            "step_index": observation.step_index,
            "action_type": action.action_type,
        },
    }
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def compact_action_payload(observation: MolForgeObservation) -> dict[str, Any]:
    """Flatten an observation into the compact JSON payload shown to the model.

    Merges the lead-chemist and assay-planner role views with shared scenario
    state, truncating list fields so prompts stay short.
    """
    # Role-specific views; fall back to an empty dict when a role is absent.
    lead_view = next(
        (role.observation for role in observation.role_observations if role.role == "lead_chemist"),
        {},
    )
    assay_view = next(
        (role.observation for role in observation.role_observations if role.role == "assay_planner"),
        {},
    )
    return {
        "valid_action_types": ["edit", "run_assay", "submit", "restart", "defer"],
        "scenario_id": observation.scenario_id,
        "difficulty": observation.difficulty,
        "task_brief": observation.task_brief,
        "current_molecule": observation.current_molecule,
        "current_smiles": observation.metadata.get("current_smiles", ""),
        "visible_metrics": observation.visible_metrics,
        "constraint_status": [constraint.model_dump() for constraint in observation.constraint_status],
        "remaining_budget": observation.remaining_budget,
        "max_budget": observation.max_budget,
        "step_index": observation.step_index,
        "max_steps": observation.max_steps,
        "molecule_slots": lead_view.get("molecule_slots", {}),
        # Cap candidate edits to keep the prompt bounded.
        "candidate_edits": lead_view.get("candidate_edits", [])[:12],
        "open_questions": lead_view.get("open_questions", []),
        # Only the eight most recent assay readings are surfaced.
        "known_assays": [
            {
                "tool_name": reading.tool_name,
                "property_name": reading.property_name,
                "estimate": reading.estimate,
                "confidence_low": reading.confidence_low,
                "confidence_high": reading.confidence_high,
                "molecule_signature": reading.molecule_signature,
            }
            for reading in observation.known_assays[-8:]
        ],
        "tool_costs": assay_view.get("tool_costs", {}),
        "evidence_gaps": assay_view.get("evidence_gaps", []),
        "estimated_information_value": assay_view.get("estimated_information_value", {}),
    }
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
def target_action(action: MolForgeAction) -> dict[str, Any]:
    """Serialize an action into the fixed compact target schema.

    ``expected_effects`` always carries exactly the five known keys; values
    from the action override the defaults only for those keys.
    """
    effects: dict[str, Any] = {
        "potency": "unknown",
        "toxicity": "unknown",
        "synth": "unknown",
        "novelty": "unknown",
        "budget": "neutral",
    }
    for key, value in action.expected_effects.items():
        if key in effects:
            effects[key] = value
    return {
        "action_type": action.action_type,
        "acting_role": action.acting_role,
        "edit_type": action.edit_type,
        "slot": action.slot,
        "fragment": action.fragment,
        "tool_name": action.tool_name,
        # Clamp free-text fields so targets stay short and tokenizer-friendly.
        "rationale": action.rationale[:220],
        "evidence": list(action.evidence[:5]),
        "expected_effects": effects,
    }
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
def validate_target(text: str) -> None:
    """Validate a serialized compact action target before it enters the dataset.

    Raises:
        ValueError: when the key set, action_type, edit_type, evidence, or
            expected_effects violate the compact schema.
        pydantic.ValidationError: when the final MolForgeAction check fails.
    """
    data = json.loads(text)
    allowed = {
        "action_type",
        "acting_role",
        "edit_type",
        "slot",
        "fragment",
        "tool_name",
        "rationale",
        "evidence",
        "expected_effects",
    }
    # Exact key-set match also rules out stray fields such as "messages",
    # so no separate checks for those are needed (the original had two
    # unreachable branches for "proposal" and "messages" — removed).
    if set(data) != allowed:
        raise ValueError(f"target keys mismatch: {sorted(data)}")
    if data["action_type"] not in {"edit", "run_assay", "submit", "restart", "defer"}:
        raise ValueError(f"invalid action_type: {data['action_type']}")
    if data["edit_type"] == "replace":
        raise ValueError("replace must never be used; use substitute")
    if not isinstance(data["evidence"], list):
        raise ValueError("evidence must be a list")
    if set(data["expected_effects"]) != {"potency", "toxicity", "synth", "novelty", "budget"}:
        raise ValueError("expected_effects must have exactly five keys")
    # Final structural check against the pydantic model.
    MolForgeAction(**data)
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
def summarize(records: list[dict[str, Any]], output: str) -> dict[str, Any]:
    """Summarize a generated dataset: counts by action type, source, scenario,
    plus unique prompt/target counts."""
    action_counts: dict[str, int] = {}
    source_counts: dict[str, int] = {}
    scenario_counts: dict[str, int] = {}
    unique_users: set[str] = set()
    unique_assistants: set[str] = set()

    for record in records:
        meta = record["metadata"]
        for bucket, key in (
            (action_counts, meta["action_type"]),
            (source_counts, meta["source"]),
            (scenario_counts, meta["scenario_id"]),
        ):
            bucket[key] = bucket.get(key, 0) + 1
        unique_users.add(record["messages"][1]["content"])
        unique_assistants.add(record["messages"][2]["content"])

    return {
        "output": output,
        "records": len(records),
        "unique_user_prompts": len(unique_users),
        "unique_assistant_targets": len(unique_assistants),
        "action_types": action_counts,
        "sources": source_counts,
        "scenario_ids": scenario_counts,
    }
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
if __name__ == "__main__":
|
| 446 |
+
main()
|
scripts/validate_sft_traces.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Validate MolForge SFT JSONL before training."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import json
|
| 7 |
+
import sys
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any
|
| 10 |
+
|
| 11 |
+
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
| 12 |
+
if str(PROJECT_ROOT) not in sys.path:
|
| 13 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 14 |
+
|
| 15 |
+
from models import MolForgeAction
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def main() -> None:
    """Validate every record in an SFT trace JSONL and print a JSON summary.

    Each non-blank line must be a JSON record whose last message parses into
    a MolForgeAction that satisfies the action contract. Exits with status 1
    when any record fails; stops collecting after --max-errors failures.
    """
    parser = argparse.ArgumentParser(description="Validate MolForge SFT trace JSONL.")
    parser.add_argument("path", help="Path to JSONL generated by scripts/generate_sft_traces.py")
    parser.add_argument("--max-errors", type=int, default=20)
    args = parser.parse_args()

    path = Path(args.path)
    errors: list[str] = []
    records = 0
    action_types: dict[str, int] = {}
    scenario_ids: dict[str, int] = {}

    # Fixed: use a context manager so the file handle is closed
    # deterministically (it was previously iterated without being closed).
    with path.open(encoding="utf-8") as handle:
        for line_number, line in enumerate(handle, start=1):
            if not line.strip():
                continue
            records += 1
            try:
                record = json.loads(line)
                messages = record["messages"]
                assistant_content = messages[-1]["content"]
                action_dict = json.loads(assistant_content)
                action = MolForgeAction(**action_dict)
                validation_error = validate_action_contract(action)
                if validation_error:
                    raise ValueError(validation_error)
                metadata = record.get("metadata", {})
                scenario_id = metadata.get("scenario_id", "unknown")
                scenario_ids[scenario_id] = scenario_ids.get(scenario_id, 0) + 1
                action_types[action.action_type] = action_types.get(action.action_type, 0) + 1
            except Exception as exc:  # noqa: BLE001 - report every failure kind per line
                errors.append(f"line {line_number}: {exc}")
                if len(errors) >= args.max_errors:
                    break

    summary: dict[str, Any] = {
        "path": str(path),
        "records_checked": records,
        "valid": not errors,
        "action_types": action_types,
        "scenario_ids": scenario_ids,
        "errors": errors,
    }
    print(json.dumps(summary, indent=2))
    if errors:
        raise SystemExit(1)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def validate_action_contract(action: MolForgeAction) -> str:
    """Return an error description for a contract violation, or "" when valid.

    Checks role/action pairing, required free-text fields, message sender
    uniqueness and type permissions, and that the acting role authored a
    proposal (except for defer).
    """
    if action.action_type == "run_assay" and action.acting_role != "assay_planner":
        return "run_assay must use acting_role=assay_planner"
    lead_only_actions = {"edit", "submit", "restart", "defer"}
    if action.action_type in lead_only_actions and action.acting_role != "lead_chemist":
        return f"{action.action_type} must use acting_role=lead_chemist"
    if not action.rationale.strip():
        return "missing rationale"
    if not action.evidence:
        return "missing evidence"
    if not action.expected_effects:
        return "missing expected_effects"

    # Which message types each role is allowed to emit.
    allowed_message_types = {
        "lead_chemist": {"proposal", "revision_request", "submission_recommendation"},
        "assay_planner": {"proposal", "approval", "rejection", "assay_request", "submission_recommendation"},
        "toxicologist": {"approval", "objection", "risk_flag", "assay_request", "rejection"},
        "process_chemist": {"approval", "objection", "risk_flag", "assay_request"},
    }
    senders_seen: set[str] = set()
    for message in action.messages:
        if message.sender in senders_seen:
            return f"duplicate message sender {message.sender}"
        senders_seen.add(message.sender)
        if message.message_type not in allowed_message_types.get(message.sender, set()):
            return f"{message.sender} cannot emit {message.message_type}"

    actor_message = next(
        (message for message in action.messages if message.sender == action.acting_role),
        None,
    )
    if action.action_type != "defer" and (actor_message is None or actor_message.message_type != "proposal"):
        return "acting_role must include a proposal message"
    return ""
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
if __name__ == "__main__":
|
| 97 |
+
main()
|
server/Dockerfile
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Multi-stage build: resolve dependencies with uv in a builder stage, then
# copy only the virtualenv and sources into a slim runtime image.
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
FROM ${BASE_IMAGE} AS builder
# Optional extra dependency group (opt in with --build-arg INSTALL_TDC=1).
ARG INSTALL_TDC=0

WORKDIR /app

# git lets uv resolve any VCS-pinned dependencies during sync.
RUN apt-get update && \
    apt-get install -y --no-install-recommends git && \
    rm -rf /var/lib/apt/lists/*

COPY . /app/env
WORKDIR /app/env

# Install uv only if the base image does not already provide it.
RUN if ! command -v uv >/dev/null 2>&1; then \
    curl -LsSf https://astral.sh/uv/install.sh | sh && \
    mv /root/.local/bin/uv /usr/local/bin/uv && \
    mv /root/.local/bin/uvx /usr/local/bin/uvx; \
    fi

# Copy (not hardlink) out of the cache mount so the venv survives the build.
ENV UV_LINK_MODE=copy
RUN --mount=type=cache,target=/root/.cache/uv \
    if [ "$INSTALL_TDC" = "1" ]; then \
    uv sync --no-editable --extra tdc; \
    else \
    uv sync --no-editable; \
    fi

# Runtime stage: same base, no build tooling.
FROM ${BASE_IMAGE}

WORKDIR /app

# curl is required by the HEALTHCHECK below.
RUN apt-get update && \
    apt-get install -y --no-install-recommends curl && \
    rm -rf /var/lib/apt/lists/*

COPY --from=builder /app/env/.venv /app/.venv
COPY --from=builder /app/env /app/env

ENV PATH="/app/.venv/bin:$PATH"
ENV PYTHONPATH="/app/env:$PYTHONPATH"

HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
|
server/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Server package for MolForge."""
|
server/actions.py
ADDED
|
@@ -0,0 +1,414 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Action execution mixin for MolForge."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Dict, List, Mapping
|
| 6 |
+
|
| 7 |
+
from .shared import (
|
| 8 |
+
DEFAULT_TOOL_COSTS,
|
| 9 |
+
compute_objective_score,
|
| 10 |
+
evaluate_constraint_margins,
|
| 11 |
+
evaluate_constraints,
|
| 12 |
+
literature_hints,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
from ..models import AssayReading, MolForgeAction, RewardComponent
|
| 17 |
+
except ImportError:
|
| 18 |
+
from models import AssayReading, MolForgeAction, RewardComponent
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class MolForgeActionMixin:
    """Methods that mutate environment state through actions.

    Every handler returns the scalar reward it contributed and appends
    structured ``RewardComponent`` entries to the shared list so the agent
    can inspect exactly why each reward was granted or deducted.
    """

    def _execute_action(
        self,
        action: MolForgeAction,
        reward_components: List[RewardComponent],
        previous_properties: Mapping[str, float],
        previous_score: float,
    ) -> tuple[float, bool]:
        """Dispatch *action* to the matching handler.

        Returns ``(reward, done)``; only a ``submit`` action terminates the
        episode.  ``previous_properties`` is accepted for signature symmetry
        with the governance layer but is not consumed here.
        """
        reward = 0.0
        done = False

        if action.action_type == "edit":
            reward += self._apply_edit(action, reward_components, previous_score)
        elif action.action_type == "run_assay":
            reward += self._run_assay(action, reward_components)
        elif action.action_type == "submit":
            reward, done = self._submit(reward_components)
        elif action.action_type == "restart":
            reward += self._restart(reward_components)
        elif action.action_type == "defer":
            # Deferring is always legal but carries a small time penalty so
            # agents do not stall indefinitely.
            reward -= 0.05
            reward_components.append(
                RewardComponent(
                    name="defer",
                    value=-0.05,
                    explanation="Deferring preserves state but lightly penalizes lost project time.",
                )
            )
            self._last_summary = "The team deferred action to gather its thoughts."

        return reward, done

    def _apply_edit(
        self,
        action: MolForgeAction,
        reward_components: List[RewardComponent],
        previous_score: float,
    ) -> float:
        """Mutate one molecule slot and reward the resulting score change."""
        previous_signature = self._molecule_signature()
        previous_fragment = self._molecule[action.slot]  # type: ignore[index]
        # "remove" swaps in a benign default fragment instead of emptying the
        # slot, so the molecule dictionary always stays fully populated.
        safe_defaults = {
            "warhead": "nitrile",
            "hinge": "pyridine",
            "solvent_tail": "morpholine",
            "back_pocket": "methoxy",
        }

        if action.edit_type == "remove":
            self._molecule[action.slot] = safe_defaults[action.slot]  # type: ignore[index]
        else:
            self._molecule[action.slot] = action.fragment  # type: ignore[index]

        new_signature = self._molecule_signature()
        new_properties = self._true_properties()
        new_score = compute_objective_score(new_properties, self._scenario)
        delta = round(new_score - previous_score, 4)
        if self._reward_mode == "dense":
            # Dense mode leaks the exact objective delta as shaped reward.
            reward = delta * 2.0
            explanation = (
                f"Updated {action.slot} from {previous_fragment} to {self._molecule[action.slot]}, "
                f"changing the internal objective score by {delta:+.3f}."
            )
        else:
            # Sparse/curriculum modes only reveal the sign of the change.
            reward = 0.04 if delta > 0 else (-0.04 if delta < 0 else 0.0)
            explanation = (
                f"Updated {action.slot} from {previous_fragment} to {self._molecule[action.slot]}. "
                "Edit feedback is intentionally coarse; assays and terminal graders provide the main signal."
            )

        reward_components.append(
            RewardComponent(
                name="edit_delta",
                value=round(reward, 4),
                explanation=explanation,
            )
        )

        # Discourage cycling back through previously explored molecules;
        # reward genuinely novel states instead.
        if new_signature in self._visited_states:
            reward -= 0.35
            reward_components.append(
                RewardComponent(
                    name="loop_penalty",
                    value=-0.35,
                    explanation="This edit revisited a previously explored molecular state.",
                )
            )
        else:
            reward += 0.06
            self._visited_states.add(new_signature)

        # Flat per-edit time cost applied regardless of outcome.
        reward -= 0.12
        reward_components.append(
            RewardComponent(
                name="turn_cost",
                value=-0.12,
                explanation="Every chemistry edit consumes simulated project time.",
            )
        )
        self._last_summary = (
            f"Lead Chemist edited {action.slot}; molecule changed from "
            f"{previous_signature} to {new_signature}."
        )
        return reward

    def _run_assay(
        self,
        action: MolForgeAction,
        reward_components: List[RewardComponent],
    ) -> float:
        """Spend oracle budget on a tool and merge its readings into view.

        Budget sufficiency and tool validity are checked upstream in
        ``_validate_action`` before this handler runs.
        """
        tool_name = action.tool_name or ""
        cost = DEFAULT_TOOL_COSTS[tool_name]
        self._state.remaining_budget -= cost
        self._state.budget_used += cost
        self._state.oracle_call_count += 1

        # Runs are tracked per (molecule, tool) pair so that duplicated
        # measurements are penalized while fresh evidence is rewarded.
        key = f"{self._molecule_signature()}::{tool_name}"
        runs = self._assay_runs.get(key, 0) + 1
        self._assay_runs[key] = runs

        reward = 0.02
        if runs == 1:
            reward += 0.10
            explanation = "First assay on this molecule/tool pair increased observability."
        else:
            reward -= 0.08
            explanation = "Repeated assay spent budget on the same molecule/tool pair."

        readings = self._build_assay_readings(tool_name, runs)
        self._merge_assays(readings)
        if tool_name == "search_literature":
            reward += 0.04
        if self._reward_mode == "curriculum" and runs == 1:
            # Curriculum mode pays a bonus the first time decision-critical
            # properties (potency/toxicity, plus synth when constrained) are
            # observed on this molecule.
            required_props = {"potency", "toxicity"}
            if "synth_min" in self._scenario.hard_constraints:
                required_props.add("synth")
            covered_props = {
                reading.property_name
                for reading in readings
                if reading.property_name in required_props
            }
            if covered_props:
                bonus = 0.08 * len(covered_props)
                reward += bonus
                reward_components.append(
                    RewardComponent(
                        name="curriculum_evidence_gate",
                        value=round(bonus, 4),
                        explanation=(
                            "Curriculum reward for collecting first-pass evidence "
                            f"for: {', '.join(sorted(covered_props))}."
                        ),
                    )
                )

        reward_components.append(
            RewardComponent(
                name="assay_information_gain",
                value=round(reward, 4),
                explanation=explanation,
            )
        )
        reward_components.append(
            RewardComponent(
                name="budget_spend",
                value=round(-cost / max(self._scenario.oracle_budget, 1), 4),
                explanation=f"Spent {cost} assay budget on {tool_name}.",
            )
        )
        reward -= cost / max(self._scenario.oracle_budget, 1)

        # Append a structured audit entry; _submit later scans this log for
        # post-shift evidence.
        self._oracle_log.append(
            {
                "step": self._state.step_count,
                "tool_name": tool_name,
                "runs": runs,
                "molecule": self._molecule_signature(),
                "cost": cost,
                "results": [reading.model_dump() for reading in readings],
            }
        )
        self._last_summary = (
            f"Assay Planner executed {tool_name}; {len(readings)} structured assay result(s) are now visible."
        )
        return reward

    def _submit(self, reward_components: List[RewardComponent]) -> tuple[float, bool]:
        """Grade the current molecule as a terminal submission.

        Always returns ``done=True``.  A submission is "valid" only when it
        passes all hard constraints, beats the scenario baseline, and is
        backed by current-molecule evidence (plus post-shift potency
        evidence on target-shift scenarios).
        """
        properties = self._true_properties()
        final_score = compute_objective_score(properties, self._scenario)
        constraint_results = evaluate_constraints(properties, self._scenario)
        constraint_margins = evaluate_constraint_margins(properties, self._scenario)
        margin_score = sum(constraint_margins.values()) / max(len(constraint_margins), 1)
        violation_penalty = round((1.0 - margin_score) * 2.0, 4)
        hard_constraints_met = all(result[0] for result in constraint_results.values())
        budget_efficiency = self._state.remaining_budget / max(self._scenario.oracle_budget, 1)
        beats_baseline = final_score >= self._scenario.baseline_to_beat
        current_signature = self._molecule_signature()
        evidence_requirements = ["potency", "toxicity"]
        if "synth_min" in self._scenario.hard_constraints:
            evidence_requirements.append("synth")
        missing_evidence = [
            prop for prop in evidence_requirements if self._current_property_estimate(prop, current_signature) is None
        ]
        evidence_met = not missing_evidence
        post_shift_evidence_met = True
        if self._scenario.target_shift_step and self._target_shift_active():
            # After a mid-episode target shift, only potency readings taken
            # on the *current* molecule at/after the shift step count.
            post_shift_evidence_met = any(
                entry["step"] >= self._scenario.target_shift_step
                and entry["molecule"] == current_signature
                and any(result["property_name"] == "potency" for result in entry["results"])
                for entry in self._oracle_log
            )
        valid_submission = hard_constraints_met and beats_baseline and evidence_met and post_shift_evidence_met

        reward = final_score * 2.0 if valid_submission else final_score * 0.25
        if valid_submission:
            reward += 3.5
        elif not hard_constraints_met:
            reward -= violation_penalty
        if not beats_baseline:
            reward -= 0.6
        if not evidence_met:
            reward -= 1.2
        if not post_shift_evidence_met:
            reward -= 0.8

        # BUG FIX: submit_bonus was previously assigned only inside the
        # valid_submission branch but read unconditionally below, raising a
        # NameError for curriculum-mode submissions that had full evidence
        # yet failed the constraint or baseline gates.  Initialize it here
        # and gate the component append on valid_submission so the logged
        # component always matches the reward actually granted.
        submit_bonus = 0.0
        if valid_submission:
            reward += max(0.0, budget_efficiency) * 0.7
            if self._reward_mode == "curriculum" and evidence_met and post_shift_evidence_met:
                submit_bonus = 0.35
                if hard_constraints_met:
                    submit_bonus += 0.15
                reward += submit_bonus

        self._state.submitted = True
        self._report_card = self._build_report_card(submitted=True)
        self._last_summary = (
            f"The team submitted a candidate that "
            f"{'passed' if hard_constraints_met else 'failed'} hard constraints."
        )

        reward_components.extend(
            [
                RewardComponent(
                    name="submission_quality",
                    value=round((final_score * 2.0 if valid_submission else final_score * 0.25), 4),
                    explanation=(
                        "Full scientific quality reward because the submission met constraints, baseline, and evidence gates."
                        if valid_submission
                        else "Only a small quality trace is awarded because the submit action missed a gate."
                    ),
                ),
                RewardComponent(
                    name="hard_constraints",
                    value=(
                        3.5
                        if valid_submission
                        else (-violation_penalty if not hard_constraints_met else 0.0)
                    ),
                    explanation=(
                        "Large sparse bonus for beating baseline with required current evidence."
                        if valid_submission
                        else "Submission missed constraints, baseline, or evidence requirements; constraint penalty scales with violation severity."
                    ),
                ),
                RewardComponent(
                    name="constraint_margin",
                    value=round(margin_score, 4),
                    explanation=(
                        "Proportional hard-constraint score: worse potency, toxicity, or synthesis violations produce lower values."
                    ),
                ),
                RewardComponent(
                    name="baseline_gate",
                    value=0.0 if beats_baseline else -0.6,
                    explanation=(
                        "Submitted molecule beat the scenario baseline."
                        if beats_baseline
                        else "Submitted molecule did not beat the scenario baseline."
                    ),
                ),
                RewardComponent(
                    name="submission_evidence",
                    value=0.0 if evidence_met else -1.2,
                    explanation=(
                        "Current-molecule potency/toxicity/synthesis evidence was available."
                        if evidence_met
                        else f"Submission lacked current evidence for: {', '.join(missing_evidence)}."
                    ),
                ),
                RewardComponent(
                    name="post_shift_evidence",
                    value=0.0 if post_shift_evidence_met else -0.8,
                    explanation=(
                        "Post-shift potency evidence was available for the submitted molecule."
                        if post_shift_evidence_met
                        else "Hard scenario submission lacked post-shift potency evidence for the current molecule."
                    ),
                ),
                RewardComponent(
                    name="budget_efficiency",
                    value=round(max(0.0, budget_efficiency) * 0.7, 4) if valid_submission else 0.0,
                    explanation=(
                        "Unused budget is rewarded to discourage wasteful oracle usage."
                        if valid_submission
                        else "Budget efficiency is not awarded to a gated or premature submission."
                    ),
                ),
            ]
        )
        if (
            self._reward_mode == "curriculum"
            and valid_submission
            and evidence_met
            and post_shift_evidence_met
        ):
            reward_components.append(
                RewardComponent(
                    name="curriculum_evidence_supported_submit",
                    value=round(submit_bonus, 4),
                    explanation=(
                        "Curriculum reward for making a formal submit decision after the required "
                        "current evidence package was available."
                    ),
                )
            )
        return reward, True

    def _restart(self, reward_components: List[RewardComponent]) -> float:
        """Reset to the scenario's alternate scaffold; usable once per episode.

        The fixed 350-unit budget cost is validated upstream in
        ``_validate_action``.
        """
        self._molecule = dict(self._scenario.restart_scaffold)
        self._trap_penalty_active = False
        self._known_assays = []
        self._assay_runs = {}
        self._restart_used = True
        self._visited_states.add(self._molecule_signature())
        self._state.remaining_budget -= 350
        self._state.budget_used += 350
        reward_components.append(
            RewardComponent(
                name="restart_penalty",
                value=-0.4,
                explanation="Restarting discards sunk work but switches to a clean scaffold family.",
            )
        )
        self._last_summary = (
            "The team abandoned the original scaffold series and restarted from a cleaner alternative."
        )
        return -0.4

    def _build_assay_readings(self, tool_name: str, runs: int) -> List[AssayReading]:
        """Produce the structured readings *tool_name* reveals for the molecule."""
        properties = self._true_properties()
        signature = self._molecule_signature()

        # Each tool exposes a fixed subset of properties; unknown tools fall
        # back to the full potency/toxicity/synth panel.
        if tool_name == "evaluate_properties":
            property_names = ["potency", "novelty"]
        elif tool_name == "dock_target":
            property_names = ["potency"]
        elif tool_name == "assay_toxicity":
            property_names = ["toxicity"]
        elif tool_name == "estimate_synthesizability":
            property_names = ["synth"]
        elif tool_name == "evaluate_novelty":
            property_names = ["novelty"]
        elif tool_name == "search_literature":
            # Literature search returns a single qualitative signal whose
            # strength grows with repeat runs, capped at 0.95.
            hint_score = min(0.95, 0.45 + 0.08 * runs)
            return [
                AssayReading(
                    tool_name=tool_name,
                    property_name="literature_signal",
                    estimate=round(hint_score, 4),
                    confidence_low=max(0.0, round(hint_score - 0.08, 4)),
                    confidence_high=min(1.0, round(hint_score + 0.08, 4)),
                    runs=runs,
                    molecule_signature=signature,
                    summary=literature_hints(self._molecule)[0],
                )
            ]
        else:
            property_names = ["potency", "toxicity", "synth"]

        readings = []
        for property_name in property_names:
            true_value = properties[property_name]
            estimate = self._assay_estimate(signature, tool_name, property_name, runs, true_value)
            # The confidence interval tightens with repeat runs, floored at
            # +/-0.03, and is clamped to [0, 1].
            width = max(0.03, 0.18 / runs)
            readings.append(
                AssayReading(
                    tool_name=tool_name,
                    property_name=property_name,
                    estimate=estimate,
                    confidence_low=max(0.0, round(estimate - width, 4)),
                    confidence_high=min(1.0, round(estimate + width, 4)),
                    runs=runs,
                    molecule_signature=signature,
                    summary=f"{tool_name} estimated {property_name} with run count {runs}.",
                )
            )
        return readings
|
server/app.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""FastAPI app for MolForge."""

# openenv-core provides the generic HTTP environment server; MolForge only
# supplies the environment class plus its action/observation schemas.
try:
    from openenv.core.env_server.http_server import create_app
except Exception as exc:  # pragma: no cover
    raise ImportError(
        "openenv-core is required to run MolForge. Install dependencies from pyproject.toml."
    ) from exc

# Support both layouts: running as an installed package (relative imports)
# and running flat from the repository/Space root (top-level imports).
try:
    from ..models import MolForgeAction, MolForgeObservation
    from .molforge_environment import MolForgeEnvironment
except ImportError:
    from models import MolForgeAction, MolForgeObservation
    from server.molforge_environment import MolForgeEnvironment


# Module-level ASGI app so `uvicorn server.app:app` works directly.
# max_concurrent_envs=2 caps simultaneous episodes — presumably sized for the
# Space's small container; confirm before raising it.
app = create_app(
    MolForgeEnvironment,
    MolForgeAction,
    MolForgeObservation,
    env_name="molforge",
    max_concurrent_envs=2,
)


def main(host: str = "0.0.0.0", port: int = 8000) -> None:
    """Run the environment locally without Docker."""

    # Imported lazily so importing this module (e.g. under another ASGI
    # server) does not require uvicorn.
    import uvicorn

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()
|
server/governance.py
ADDED
|
@@ -0,0 +1,576 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Governance, validation, and coordination logic for MolForge."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any, Dict, List, Mapping, Optional
|
| 6 |
+
|
| 7 |
+
from .shared import (
|
| 8 |
+
DEFAULT_TOOL_COSTS,
|
| 9 |
+
EDITABLE_SLOTS,
|
| 10 |
+
FRAGMENT_LIBRARY,
|
| 11 |
+
ROLE_MESSAGE_TYPES,
|
| 12 |
+
ROLE_PERMISSIONS,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
from ..models import GovernanceStatus, MolForgeAction, RewardComponent
|
| 17 |
+
except ImportError:
|
| 18 |
+
from models import GovernanceStatus, MolForgeAction, RewardComponent
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class MolForgeGovernanceMixin:
|
| 22 |
+
"""Validation and multi-agent review methods."""
|
| 23 |
+
|
| 24 |
+
def _validate_action(self, action: MolForgeAction) -> Optional[tuple[str, str]]:
|
| 25 |
+
if action.action_type not in self._scenario.enabled_actions:
|
| 26 |
+
return "ACTION_DISABLED", f"{action.action_type} is disabled for this scenario."
|
| 27 |
+
|
| 28 |
+
if action.acting_role not in self._scenario.enabled_roles:
|
| 29 |
+
return "ROLE_DISABLED", f"{action.acting_role} is not enabled for this scenario."
|
| 30 |
+
|
| 31 |
+
allowed_actions = ROLE_PERMISSIONS.get(action.acting_role, [])
|
| 32 |
+
if action.action_type not in allowed_actions:
|
| 33 |
+
return (
|
| 34 |
+
"ROLE_PERMISSION_DENIED",
|
| 35 |
+
f"{action.acting_role} is not permitted to execute {action.action_type}.",
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
if len(action.messages) > self._scenario.max_messages_per_turn:
|
| 39 |
+
return (
|
| 40 |
+
"MESSAGE_LIMIT_EXCEEDED",
|
| 41 |
+
f"At most {self._scenario.max_messages_per_turn} messages may be sent per turn.",
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
seen_senders = set()
|
| 45 |
+
for message in action.messages:
|
| 46 |
+
if message.sender not in self._scenario.enabled_roles:
|
| 47 |
+
return "MESSAGE_ROLE_INVALID", f"{message.sender} is not enabled in this scenario."
|
| 48 |
+
if message.sender in seen_senders:
|
| 49 |
+
return (
|
| 50 |
+
"DUPLICATE_ROLE_MESSAGE",
|
| 51 |
+
f"Each specialist may emit at most one message per turn; duplicate from {message.sender}.",
|
| 52 |
+
)
|
| 53 |
+
seen_senders.add(message.sender)
|
| 54 |
+
if message.message_type not in ROLE_MESSAGE_TYPES.get(message.sender, []):
|
| 55 |
+
return (
|
| 56 |
+
"MESSAGE_PERMISSION_DENIED",
|
| 57 |
+
f"{message.sender} cannot emit message type {message.message_type}.",
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
if action.action_type == "edit":
|
| 61 |
+
if action.slot is None or action.edit_type is None:
|
| 62 |
+
return "MISSING_EDIT_FIELDS", "Edit actions require both slot and edit_type."
|
| 63 |
+
if action.slot not in EDITABLE_SLOTS:
|
| 64 |
+
return "INVALID_SLOT", f"{action.slot} is not editable in MolForge."
|
| 65 |
+
if action.edit_type in {"add_fragment", "substitute"} and not action.fragment:
|
| 66 |
+
return "MISSING_FRAGMENT", "Edit actions require a fragment for add/substitute."
|
| 67 |
+
if action.fragment:
|
| 68 |
+
if action.fragment not in FRAGMENT_LIBRARY[action.slot]:
|
| 69 |
+
return "UNKNOWN_FRAGMENT", f"{action.fragment} is not valid for slot {action.slot}."
|
| 70 |
+
if self._molecule[action.slot] == action.fragment:
|
| 71 |
+
return "NO_STATE_CHANGE", "Edit selected the fragment already present in that slot."
|
| 72 |
+
|
| 73 |
+
if action.action_type == "run_assay":
|
| 74 |
+
if action.tool_name is None:
|
| 75 |
+
return "MISSING_TOOL_NAME", "run_assay actions require a tool_name."
|
| 76 |
+
if action.tool_name not in self._scenario.enabled_tools:
|
| 77 |
+
return "TOOL_DISABLED", f"{action.tool_name} is not enabled for this scenario."
|
| 78 |
+
cost = DEFAULT_TOOL_COSTS[action.tool_name]
|
| 79 |
+
if self._state.remaining_budget < cost:
|
| 80 |
+
return "BUDGET_EXCEEDED", f"{action.tool_name} costs {cost}, exceeding remaining budget."
|
| 81 |
+
|
| 82 |
+
if action.action_type == "restart":
|
| 83 |
+
if self._restart_used:
|
| 84 |
+
return "RESTART_ALREADY_USED", "restart_from_new_scaffold may be used at most once per episode."
|
| 85 |
+
if self._state.remaining_budget < 350:
|
| 86 |
+
return "BUDGET_EXCEEDED", "Not enough budget remains to restart from a new scaffold."
|
| 87 |
+
|
| 88 |
+
return None
|
| 89 |
+
|
| 90 |
+
    def _assess_governance(
        self,
        action: MolForgeAction,
        previous_properties: Mapping[str, float],
    ) -> tuple[GovernanceStatus, List[RewardComponent], bool]:
        """Run the multi-agent review for *action*.

        Compares each required reviewer's actual message against the hidden
        feedback the environment expects (via ``_expected_feedback``) and
        accumulates small coordination rewards/penalties.  Returns the
        ``GovernanceStatus``, the reward components gathered here, and a
        bool flag that is True when a hard veto blocks execution.
        """
        reward_components: List[RewardComponent] = []
        approvals: List[str] = []
        objections: List[str] = []
        vetoes: List[str] = []
        # Defer needs no review; otherwise every required reviewer except
        # the actor itself must weigh in.
        required_roles = (
            []
            if action.action_type == "defer"
            else [role for role in self._scenario.required_review_roles if role != action.acting_role]
        )
        policy_veto = False

        # NOTE(review): current_signature is computed but never read in this
        # method — confirm whether it is vestigial.
        current_signature = self._molecule_signature()
        simulated_properties = self._simulate_action_properties(action)
        # Last message per sender wins if duplicates slipped through.
        sender_map = {message.sender: message for message in action.messages}

        # Reward the actor for logging an explicit proposal before acting.
        actor_message = sender_map.get(action.acting_role)
        if action.action_type != "defer":
            if actor_message and actor_message.message_type == "proposal":
                self._record_message(actor_message)
                reward_components.append(
                    RewardComponent(
                        name="proposal_logged",
                        value=0.05,
                        explanation=f"{action.acting_role} logged a structured proposal before execution.",
                    )
                )
                self._role_metrics[action.acting_role]["correct_messages"] += 1
            else:
                reward_components.append(
                    RewardComponent(
                        name="missing_proposal",
                        value=-0.06,
                        explanation="The acting specialist did not provide an explicit proposal message.",
                    )
                )

        for role in required_roles:
            expected = self._expected_feedback(role, action, previous_properties, simulated_properties)
            actual = sender_map.get(role)
            if actual is None:
                # Missing review is penalized; a hard veto still applies even
                # when the reviewer stayed silent.
                reward_components.append(
                    RewardComponent(
                        name=f"missing_review_{role}",
                        value=-0.08,
                        explanation=f"{role} did not provide the required review for this turn.",
                    )
                )
                if expected["hard_veto"]:
                    policy_veto = True
                    vetoes.append(role)
                continue

            # required_roles already excludes the actor, so this records
            # every reviewer message.
            if role != action.acting_role:
                self._record_message(actual)
            if self._matches_feedback(actual.message_type, expected["type"]):
                # Reviewer agreed with the hidden evaluation.
                reward_components.append(
                    RewardComponent(
                        name=f"coordination_{role}",
                        value=0.12,
                        explanation=expected["reason"],
                    )
                )
                self._role_metrics[role]["correct_messages"] += 1
                if expected["type"] in {"approval", "submission_recommendation"}:
                    approvals.append(role)
                else:
                    objections.append(role)
            elif expected["type"] == "neutral":
                # Reviewer spoke up when no intervention was needed.
                reward_components.append(
                    RewardComponent(
                        name=f"unnecessary_message_{role}",
                        value=-0.02,
                        explanation=f"{role} contributed a message even though no strong intervention was needed.",
                    )
                )
                self._role_metrics[role]["incorrect_messages"] += 1
            else:
                # Reviewer sent the wrong kind of feedback.
                reward_components.append(
                    RewardComponent(
                        name=f"misaligned_review_{role}",
                        value=-0.1,
                        explanation=(
                            f"{role} sent {actual.message_type}, but the hidden environment evaluation "
                            f"expected {expected['type']}."
                        ),
                    )
                )
                self._role_metrics[role]["incorrect_messages"] += 1
                if expected["hard_veto"]:
                    policy_veto = True
                    vetoes.append(role)

            # A correctly-matched objection that carries a hard veto also
            # blocks execution.
            if expected["hard_veto"] and actual and self._matches_feedback(actual.message_type, expected["type"]):
                policy_veto = True
                vetoes.append(role)

        # Non-required senders (other than the actor) earn a small bonus for
        # optional context.
        extra_roles = {
            sender
            for sender in sender_map
            if sender not in required_roles and sender != action.acting_role
        }
        for role in sorted(extra_roles):
            self._record_message(sender_map[role])
            reward_components.append(
                RewardComponent(
                    name=f"optional_review_{role}",
                    value=0.02,
                    explanation=f"{role} added optional context for the current decision.",
                )
            )
            self._role_metrics[role]["correct_messages"] += 1

        if policy_veto:
            reward_components.append(
                RewardComponent(
                    name="policy_veto",
                    value=-0.35,
                    explanation="A specialist raised a valid hard veto, so the action was blocked.",
                )
            )

        explanation = (
            "Team review approved the decision."
            if not policy_veto
            else f"Action blocked after hard veto from {', '.join(vetoes)}."
        )
        status = "policy_veto" if policy_veto else "ready"
        return (
            GovernanceStatus(
                status=status,
                explanation=explanation,
                required_roles=required_roles,
                approvals=approvals,
                objections=objections,
                vetoes=vetoes,
                executable=not policy_veto,
            ),
            reward_components,
            policy_veto,
        )
|
| 235 |
+
|
| 236 |
+
def _expected_feedback(
|
| 237 |
+
self,
|
| 238 |
+
role: str,
|
| 239 |
+
action: MolForgeAction,
|
| 240 |
+
previous_properties: Mapping[str, float],
|
| 241 |
+
simulated_properties: Mapping[str, float],
|
| 242 |
+
) -> Dict[str, Any]:
|
| 243 |
+
tox_threshold = self._scenario.hard_constraints.get("toxicity_max", 0.45)
|
| 244 |
+
synth_threshold = self._scenario.hard_constraints.get("synth_min", 0.56)
|
| 245 |
+
current_signature = self._molecule_signature()
|
| 246 |
+
known_toxicity = self._current_property_estimate("toxicity", current_signature)
|
| 247 |
+
known_synth = self._current_property_estimate("synth", current_signature)
|
| 248 |
+
|
| 249 |
+
neutral = {"type": "neutral", "hard_veto": False, "reason": f"{role} had no strong intervention."}
|
| 250 |
+
|
| 251 |
+
if role == "toxicologist":
|
| 252 |
+
if action.action_type == "run_assay":
|
| 253 |
+
if action.tool_name == "assay_toxicity":
|
| 254 |
+
return {
|
| 255 |
+
"type": "approval",
|
| 256 |
+
"hard_veto": False,
|
| 257 |
+
"reason": "Toxicologist correctly encouraged a direct safety assay.",
|
| 258 |
+
}
|
| 259 |
+
return {
|
| 260 |
+
"type": "approval",
|
| 261 |
+
"hard_veto": False,
|
| 262 |
+
"reason": "Toxicologist correctly allowed an assay that improves decision evidence without increasing chemical risk.",
|
| 263 |
+
}
|
| 264 |
+
if action.action_type == "submit":
|
| 265 |
+
if known_toxicity is not None and known_toxicity > tox_threshold:
|
| 266 |
+
return {
|
| 267 |
+
"type": "objection",
|
| 268 |
+
"hard_veto": simulated_properties["toxicity"] > tox_threshold,
|
| 269 |
+
"reason": "Toxicologist correctly objected to a submission with visible toxicity evidence above threshold.",
|
| 270 |
+
}
|
| 271 |
+
if simulated_properties["toxicity"] > tox_threshold:
|
| 272 |
+
return {
|
| 273 |
+
"type": "objection",
|
| 274 |
+
"hard_veto": True,
|
| 275 |
+
"reason": "Toxicologist correctly blocked an unsafe submission.",
|
| 276 |
+
}
|
| 277 |
+
if known_toxicity is None:
|
| 278 |
+
return {
|
| 279 |
+
"type": "assay_request",
|
| 280 |
+
"hard_veto": True,
|
| 281 |
+
"reason": "Toxicologist correctly demanded explicit toxicity evidence before submission.",
|
| 282 |
+
}
|
| 283 |
+
return {
|
| 284 |
+
"type": "approval",
|
| 285 |
+
"hard_veto": False,
|
| 286 |
+
"reason": "Toxicologist correctly approved a submission with acceptable visible safety evidence.",
|
| 287 |
+
}
|
| 288 |
+
if action.action_type in {"edit", "restart"}:
|
| 289 |
+
toxicity_delta = simulated_properties["toxicity"] - previous_properties["toxicity"]
|
| 290 |
+
if toxicity_delta > 0.08:
|
| 291 |
+
return {
|
| 292 |
+
"type": "objection",
|
| 293 |
+
"hard_veto": True,
|
| 294 |
+
"reason": "Toxicologist correctly raised a hard objection to a major safety regression.",
|
| 295 |
+
}
|
| 296 |
+
if (
|
| 297 |
+
simulated_properties["toxicity"] > tox_threshold + 0.02
|
| 298 |
+
and toxicity_delta >= -0.02
|
| 299 |
+
):
|
| 300 |
+
return {
|
| 301 |
+
"type": "objection",
|
| 302 |
+
"hard_veto": True,
|
| 303 |
+
"reason": "Toxicologist correctly blocked a move that left an unsafe scaffold unimproved.",
|
| 304 |
+
}
|
| 305 |
+
if simulated_properties["toxicity"] > tox_threshold + 0.02:
|
| 306 |
+
return {
|
| 307 |
+
"type": "approval",
|
| 308 |
+
"hard_veto": False,
|
| 309 |
+
"reason": "Toxicologist correctly allowed a risk-reducing move while residual safety work remains.",
|
| 310 |
+
}
|
| 311 |
+
return {
|
| 312 |
+
"type": "approval",
|
| 313 |
+
"hard_veto": False,
|
| 314 |
+
"reason": "Toxicologist correctly approved a safety-compatible move.",
|
| 315 |
+
}
|
| 316 |
+
return neutral
|
| 317 |
+
|
| 318 |
+
if role == "assay_planner":
|
| 319 |
+
if action.action_type == "run_assay":
|
| 320 |
+
info_gain = self._estimate_information_gain(action.tool_name or "")
|
| 321 |
+
prior_runs = self._assay_runs.get(f"{current_signature}::{action.tool_name}", 0)
|
| 322 |
+
if (action.tool_name == "run_md_simulation" and self._state.remaining_budget < 4500) or (
|
| 323 |
+
prior_runs > 0 and info_gain < 0.05
|
| 324 |
+
):
|
| 325 |
+
return {
|
| 326 |
+
"type": "rejection",
|
| 327 |
+
"hard_veto": True,
|
| 328 |
+
"reason": "Assay Planner correctly blocked a wasteful or over-expensive assay.",
|
| 329 |
+
}
|
| 330 |
+
if info_gain < 0.04 and action.tool_name != "search_literature":
|
| 331 |
+
return {
|
| 332 |
+
"type": "rejection",
|
| 333 |
+
"hard_veto": False,
|
| 334 |
+
"reason": "Assay Planner correctly questioned a low-value assay.",
|
| 335 |
+
}
|
| 336 |
+
return {
|
| 337 |
+
"type": "approval",
|
| 338 |
+
"hard_veto": False,
|
| 339 |
+
"reason": "Assay Planner correctly approved an information-efficient assay.",
|
| 340 |
+
}
|
| 341 |
+
if action.action_type == "submit":
|
| 342 |
+
required_props = ["potency", "toxicity"]
|
| 343 |
+
if "synth_min" in self._scenario.hard_constraints:
|
| 344 |
+
required_props.append("synth")
|
| 345 |
+
missing = [
|
| 346 |
+
prop for prop in required_props if self._current_property_estimate(prop, current_signature) is None
|
| 347 |
+
]
|
| 348 |
+
if missing:
|
| 349 |
+
return {
|
| 350 |
+
"type": "assay_request",
|
| 351 |
+
"hard_veto": True,
|
| 352 |
+
"reason": "Assay Planner correctly asked for missing evidence before submission.",
|
| 353 |
+
}
|
| 354 |
+
return {
|
| 355 |
+
"type": "approval",
|
| 356 |
+
"hard_veto": False,
|
| 357 |
+
"reason": "Assay Planner correctly approved a well-supported submission.",
|
| 358 |
+
}
|
| 359 |
+
if action.action_type == "restart":
|
| 360 |
+
potency_threshold = self._scenario.hard_constraints.get("potency_min", 0.72)
|
| 361 |
+
if self._scenario.trap_penalty and previous_properties["potency"] < potency_threshold:
|
| 362 |
+
return {
|
| 363 |
+
"type": "approval",
|
| 364 |
+
"hard_veto": False,
|
| 365 |
+
"reason": "Assay Planner correctly endorsed escaping a low-value scaffold family.",
|
| 366 |
+
}
|
| 367 |
+
return {
|
| 368 |
+
"type": "rejection",
|
| 369 |
+
"hard_veto": False,
|
| 370 |
+
"reason": "Assay Planner correctly questioned an unnecessary restart.",
|
| 371 |
+
}
|
| 372 |
+
if action.action_type == "edit":
|
| 373 |
+
return {
|
| 374 |
+
"type": "approval",
|
| 375 |
+
"hard_veto": False,
|
| 376 |
+
"reason": "Assay Planner correctly approved a low-cost edit before spending assay budget.",
|
| 377 |
+
}
|
| 378 |
+
return neutral
|
| 379 |
+
|
| 380 |
+
if role == "process_chemist":
|
| 381 |
+
if action.action_type == "run_assay":
|
| 382 |
+
if action.tool_name == "estimate_synthesizability":
|
| 383 |
+
return {
|
| 384 |
+
"type": "approval",
|
| 385 |
+
"hard_veto": False,
|
| 386 |
+
"reason": "Process Chemist correctly requested explicit synthesizeability evidence.",
|
| 387 |
+
}
|
| 388 |
+
return {
|
| 389 |
+
"type": "approval",
|
| 390 |
+
"hard_veto": False,
|
| 391 |
+
"reason": "Process Chemist correctly allowed an assay that does not worsen route feasibility.",
|
| 392 |
+
}
|
| 393 |
+
if action.action_type == "submit":
|
| 394 |
+
if known_synth is not None and known_synth < synth_threshold:
|
| 395 |
+
return {
|
| 396 |
+
"type": "objection",
|
| 397 |
+
"hard_veto": simulated_properties["synth"] < synth_threshold,
|
| 398 |
+
"reason": "Process Chemist correctly objected to a submission with visible route evidence below threshold.",
|
| 399 |
+
}
|
| 400 |
+
if simulated_properties["synth"] < synth_threshold:
|
| 401 |
+
return {
|
| 402 |
+
"type": "objection",
|
| 403 |
+
"hard_veto": "synth_min" in self._scenario.hard_constraints,
|
| 404 |
+
"reason": "Process Chemist correctly blocked a submission that looks infeasible to make.",
|
| 405 |
+
}
|
| 406 |
+
if known_synth is None:
|
| 407 |
+
return {
|
| 408 |
+
"type": "assay_request",
|
| 409 |
+
"hard_veto": False,
|
| 410 |
+
"reason": "Process Chemist correctly asked for synthesizeability evidence before submission.",
|
| 411 |
+
}
|
| 412 |
+
return {
|
| 413 |
+
"type": "approval",
|
| 414 |
+
"hard_veto": False,
|
| 415 |
+
"reason": "Process Chemist correctly approved a feasible-looking submission.",
|
| 416 |
+
}
|
| 417 |
+
if action.action_type in {"edit", "restart"}:
|
| 418 |
+
if simulated_properties["synth"] < synth_threshold - 0.03:
|
| 419 |
+
return {
|
| 420 |
+
"type": "objection",
|
| 421 |
+
"hard_veto": False,
|
| 422 |
+
"reason": "Process Chemist correctly flagged a severe feasibility regression.",
|
| 423 |
+
}
|
| 424 |
+
if previous_properties["synth"] - simulated_properties["synth"] > 0.08:
|
| 425 |
+
return {
|
| 426 |
+
"type": "objection",
|
| 427 |
+
"hard_veto": False,
|
| 428 |
+
"reason": "Process Chemist correctly objected to a less tractable route.",
|
| 429 |
+
}
|
| 430 |
+
return {
|
| 431 |
+
"type": "approval",
|
| 432 |
+
"hard_veto": False,
|
| 433 |
+
"reason": "Process Chemist correctly approved a tractable chemistry move.",
|
| 434 |
+
}
|
| 435 |
+
return neutral
|
| 436 |
+
|
| 437 |
+
return neutral
|
| 438 |
+
|
| 439 |
+
@staticmethod
|
| 440 |
+
def _matches_feedback(actual_type: str, expected_type: str) -> bool:
|
| 441 |
+
if expected_type == "neutral":
|
| 442 |
+
return False
|
| 443 |
+
if expected_type == "approval":
|
| 444 |
+
return actual_type in {"approval", "submission_recommendation"}
|
| 445 |
+
if expected_type == "objection":
|
| 446 |
+
return actual_type in {"objection", "risk_flag", "rejection"}
|
| 447 |
+
if expected_type == "rejection":
|
| 448 |
+
return actual_type in {"rejection", "objection"}
|
| 449 |
+
if expected_type == "assay_request":
|
| 450 |
+
return actual_type == "assay_request"
|
| 451 |
+
return actual_type == expected_type
|
| 452 |
+
|
| 453 |
+
def _evaluate_reasoning_consistency(
|
| 454 |
+
self,
|
| 455 |
+
action: MolForgeAction,
|
| 456 |
+
previous_properties: Mapping[str, float],
|
| 457 |
+
current_properties: Mapping[str, float],
|
| 458 |
+
reward_components: List[RewardComponent],
|
| 459 |
+
) -> float:
|
| 460 |
+
del previous_properties, current_properties
|
| 461 |
+
|
| 462 |
+
rationale = action.rationale.lower().strip()
|
| 463 |
+
evidence = [item.lower().strip() for item in action.evidence if item.strip()]
|
| 464 |
+
expected_effects = {key: value for key, value in action.expected_effects.items() if value}
|
| 465 |
+
score = 0.0
|
| 466 |
+
explanations = []
|
| 467 |
+
|
| 468 |
+
if rationale:
|
| 469 |
+
score += 0.02
|
| 470 |
+
explanations.append("short rationale present")
|
| 471 |
+
else:
|
| 472 |
+
score -= 0.03
|
| 473 |
+
explanations.append("missing rationale")
|
| 474 |
+
|
| 475 |
+
if evidence:
|
| 476 |
+
grounded = sum(1 for item in evidence if self._evidence_item_is_visible(item))
|
| 477 |
+
score += min(grounded, 3) * 0.015
|
| 478 |
+
if grounded < len(evidence):
|
| 479 |
+
score -= min(len(evidence) - grounded, 2) * 0.02
|
| 480 |
+
explanations.append(f"{grounded}/{len(evidence)} evidence item(s) matched visible state")
|
| 481 |
+
else:
|
| 482 |
+
score -= 0.03
|
| 483 |
+
explanations.append("missing visible evidence")
|
| 484 |
+
|
| 485 |
+
if expected_effects:
|
| 486 |
+
plausible = sum(
|
| 487 |
+
1
|
| 488 |
+
for metric, direction in expected_effects.items()
|
| 489 |
+
if self._expected_effect_is_plausible(action, metric, direction)
|
| 490 |
+
)
|
| 491 |
+
checked = len(expected_effects)
|
| 492 |
+
score += min(plausible, 3) * 0.01
|
| 493 |
+
if plausible < checked:
|
| 494 |
+
score -= min(checked - plausible, 2) * 0.015
|
| 495 |
+
explanations.append(f"{plausible}/{checked} expected effect(s) were directionally plausible")
|
| 496 |
+
else:
|
| 497 |
+
score -= 0.02
|
| 498 |
+
explanations.append("missing expected effects")
|
| 499 |
+
|
| 500 |
+
score = max(-0.04, min(0.04, score))
|
| 501 |
+
|
| 502 |
+
if score != 0.0:
|
| 503 |
+
reward_components.append(
|
| 504 |
+
RewardComponent(
|
| 505 |
+
name="reasoning_grounding",
|
| 506 |
+
value=round(score, 4),
|
| 507 |
+
explanation="; ".join(explanations),
|
| 508 |
+
)
|
| 509 |
+
)
|
| 510 |
+
return score
|
| 511 |
+
|
| 512 |
+
def _evidence_item_is_visible(self, item: str) -> bool:
|
| 513 |
+
if not item:
|
| 514 |
+
return False
|
| 515 |
+
visible_terms = {
|
| 516 |
+
self._scenario.scenario_id.lower(),
|
| 517 |
+
self._scenario.difficulty.lower(),
|
| 518 |
+
self._molecule_signature().lower(),
|
| 519 |
+
str(self._state.remaining_budget),
|
| 520 |
+
str(self._state.max_budget),
|
| 521 |
+
str(self._state.step_count),
|
| 522 |
+
str(self._scenario.max_steps),
|
| 523 |
+
}
|
| 524 |
+
visible_terms.update(fragment.lower() for fragment in self._molecule.values())
|
| 525 |
+
visible_terms.update(tool.lower() for tool in self._scenario.enabled_tools)
|
| 526 |
+
visible_terms.update(constraint.lower() for constraint in self._scenario.hard_constraints)
|
| 527 |
+
visible_terms.update(reading.property_name.lower() for reading in self._known_assays)
|
| 528 |
+
visible_terms.update(reading.tool_name.lower() for reading in self._known_assays)
|
| 529 |
+
return any(term and term in item for term in visible_terms)
|
| 530 |
+
|
| 531 |
+
def _expected_effect_is_plausible(
|
| 532 |
+
self,
|
| 533 |
+
action: MolForgeAction,
|
| 534 |
+
metric: str,
|
| 535 |
+
direction: str,
|
| 536 |
+
) -> bool:
|
| 537 |
+
if metric not in {"potency", "toxicity", "synth", "novelty", "budget"}:
|
| 538 |
+
return False
|
| 539 |
+
if direction not in {"up", "down", "neutral", "unknown", "not_applicable"}:
|
| 540 |
+
return False
|
| 541 |
+
if direction in {"unknown", "not_applicable"}:
|
| 542 |
+
return True
|
| 543 |
+
if metric == "budget":
|
| 544 |
+
if action.action_type in {"run_assay", "restart"}:
|
| 545 |
+
return direction == "down"
|
| 546 |
+
return direction == "neutral"
|
| 547 |
+
if action.action_type in {"run_assay", "submit", "defer"}:
|
| 548 |
+
return direction in {"neutral", "unknown", "not_applicable"}
|
| 549 |
+
if action.action_type == "restart":
|
| 550 |
+
if metric in {"toxicity", "budget"}:
|
| 551 |
+
return direction == "down"
|
| 552 |
+
if metric == "synth":
|
| 553 |
+
return direction in {"up", "unknown"}
|
| 554 |
+
return direction in {"up", "unknown", "neutral"}
|
| 555 |
+
if action.action_type != "edit" or not action.slot or not action.fragment:
|
| 556 |
+
return direction in {"neutral", "unknown"}
|
| 557 |
+
|
| 558 |
+
fragment = action.fragment
|
| 559 |
+
plausibility = {
|
| 560 |
+
("hinge", "azaindole", "potency", "up"),
|
| 561 |
+
("back_pocket", "cyano", "potency", "up"),
|
| 562 |
+
("back_pocket", "cyano", "toxicity", "down"),
|
| 563 |
+
("back_pocket", "chloro", "potency", "up"),
|
| 564 |
+
("back_pocket", "chloro", "toxicity", "up"),
|
| 565 |
+
("back_pocket", "trifluoromethyl", "potency", "up"),
|
| 566 |
+
("back_pocket", "trifluoromethyl", "toxicity", "up"),
|
| 567 |
+
("solvent_tail", "morpholine", "toxicity", "down"),
|
| 568 |
+
("solvent_tail", "morpholine", "synth", "up"),
|
| 569 |
+
("solvent_tail", "dimethylamino", "toxicity", "up"),
|
| 570 |
+
("warhead", "reversible_cyanoacrylamide", "toxicity", "down"),
|
| 571 |
+
("warhead", "reversible_cyanoacrylamide", "novelty", "up"),
|
| 572 |
+
("warhead", "nitrile", "toxicity", "down"),
|
| 573 |
+
}
|
| 574 |
+
if (action.slot, fragment, metric, direction) in plausibility:
|
| 575 |
+
return True
|
| 576 |
+
return direction in {"neutral", "unknown"}
|
server/molforge_environment.py
ADDED
|
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""MolForge environment implementation."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import random
|
| 7 |
+
from dataclasses import replace
|
| 8 |
+
from typing import Any, Dict, List
|
| 9 |
+
from uuid import uuid4
|
| 10 |
+
|
| 11 |
+
from openenv.core.env_server.interfaces import Environment
|
| 12 |
+
|
| 13 |
+
from .actions import MolForgeActionMixin
|
| 14 |
+
from .governance import MolForgeGovernanceMixin
|
| 15 |
+
from .shared import (
|
| 16 |
+
FRAGMENT_LIBRARY,
|
| 17 |
+
SCENARIOS,
|
| 18 |
+
SLOT_ORDER,
|
| 19 |
+
compute_objective_score,
|
| 20 |
+
get_scenario,
|
| 21 |
+
)
|
| 22 |
+
from .shared import MolForgeSharedMixin
|
| 23 |
+
from .views import MolForgeViewMixin
|
| 24 |
+
|
| 25 |
+
try:
|
| 26 |
+
from ..models import GovernanceStatus, MolForgeAction, MolForgeObservation, MolForgeState, RewardComponent
|
| 27 |
+
except ImportError:
|
| 28 |
+
from models import GovernanceStatus, MolForgeAction, MolForgeObservation, MolForgeState, RewardComponent
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class MolForgeEnvironment(
    MolForgeActionMixin,
    MolForgeGovernanceMixin,
    MolForgeViewMixin,
    MolForgeSharedMixin,
    Environment,
):
    """Deterministic medicinal-chemistry design environment for OpenEnv."""

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        def env_flag(name: str) -> bool:
            # Treat "1"/"true"/"yes" (case-insensitive) as enabled.
            return os.getenv(name, "").lower() in {"1", "true", "yes"}

        self._debug_state_enabled = env_flag("MOLFORGE_DEBUG_STATE")
        self._training_randomization_enabled = env_flag("MOLFORGE_TRAINING_RANDOMIZATION")
        self._reward_mode = os.getenv("MOLFORGE_REWARD_MODE", "assay_gated").lower()
        self._rng = random.Random(os.getenv("MOLFORGE_RANDOM_SEED", "molforge"))
        self._reset_index = -1
        self._state = MolForgeState(episode_id=str(uuid4()), step_count=0)
        self._scenario = SCENARIOS[0]
        self._molecule: Dict[str, str] = {}
        self._assay_runs: Dict[str, int] = {}
        self._known_assays: List = []
        self._message_log: List = []
        self._history: List[Dict[str, Any]] = []
        self._oracle_log: List[Dict[str, Any]] = []
        self._visited_states: set[str] = set()
        self._last_summary = ""
        self._report_card = ""
        self._reward_total = 0.0
        self._restart_used = False
        self._trap_penalty_active = False
        self._role_metrics = self._empty_role_metrics()
        self._state_path: List[str] = ["[start]"]
        self._last_governance = GovernanceStatus(
            status="ready",
            explanation="Awaiting the first coordinated decision.",
            required_roles=[],
            approvals=[],
            objections=[],
            vetoes=[],
            executable=True,
        )

    def reset(self) -> MolForgeObservation:
        """Start a new scenario in a deterministic rotation."""

        self._reset_index += 1
        self._scenario = self._select_reset_scenario()
        self._molecule = dict(self._scenario.starting_scaffold)
        self._assay_runs = {}
        self._known_assays = []
        self._message_log = []
        self._history = []
        self._oracle_log = []
        self._visited_states = {self._molecule_signature()}
        self._last_summary = "Episode initialized with a fresh multi-agent review board."
        self._report_card = ""
        self._reward_total = 0.0
        self._restart_used = False
        self._trap_penalty_active = self._scenario.trap_penalty
        self._role_metrics = self._empty_role_metrics()
        self._state_path = ["[start]"]
        self._last_governance = GovernanceStatus(
            status="ready",
            explanation="Lead Chemist should propose the first coordinated action.",
            required_roles=list(self._scenario.required_review_roles),
            approvals=[],
            objections=[],
            vetoes=[],
            executable=True,
        )

        self._state = MolForgeState(
            episode_id=str(uuid4()),
            step_count=0,
            scenario_id=self._scenario.scenario_id,
            difficulty=self._scenario.difficulty,
            state_label="[start]",
            state_path=list(self._state_path),
            coordination_mode=self._scenario.coordination_mode,  # type: ignore[arg-type]
            enabled_roles=list(self._scenario.enabled_roles),
            target_name=self._scenario.target_name,
            current_molecule=self._molecule_signature(),
            remaining_budget=self._scenario.oracle_budget,
            budget_used=0,
            max_budget=self._scenario.oracle_budget,
            visited_states=1,
            known_assay_count=0,
            invalid_action_count=0,
            objection_count=0,
            oracle_call_count=0,
            message_count=0,
            decision_count=0,
            submitted=False,
            reward_total=0.0,
            metadata={},
        )
        self._sync_state_metadata()
        return self._build_observation(reward=0.0, done=False, reward_components=[])

    def _select_reset_scenario(self):
        """Select a deterministic judge scenario or a randomized training variant."""

        scenario = get_scenario(self._reset_index)
        if not self._training_randomization_enabled:
            # Judge mode: deterministic rotation through the fixed scenarios.
            return scenario

        # Training mode: pick a random scenario and jitter its parameters.
        scenario = self._rng.choice(SCENARIOS)
        budget_scale = self._rng.uniform(0.85, 1.15)
        max_steps_delta = self._rng.choice([-1, 0, 0, 1])
        starting_scaffold = dict(scenario.starting_scaffold)
        if self._rng.random() < 0.35:
            # Occasionally swap one slot fragment for a different one.
            # NOTE(review): assumes every slot has at least two fragments in
            # FRAGMENT_LIBRARY, otherwise rng.choice([]) would raise — confirm.
            slot = self._rng.choice(SLOT_ORDER)
            choices = [
                fragment
                for fragment in FRAGMENT_LIBRARY[slot]
                if fragment != starting_scaffold[slot]
            ]
            starting_scaffold[slot] = self._rng.choice(choices)
        return replace(
            scenario,
            oracle_budget=max(1, int(round(scenario.oracle_budget * budget_scale))),
            max_steps=max(4, scenario.max_steps + max_steps_delta),
            starting_scaffold=starting_scaffold,
        )

    def step(self, action: MolForgeAction) -> MolForgeObservation:  # type: ignore[override]
        """Execute one coordinated environment action."""

        reward_components: List[RewardComponent] = []
        done = False
        error_code = ""
        self._state.step_count += 1
        self._state.decision_count += 1

        previous_properties = self._true_properties()
        previous_score = compute_objective_score(previous_properties, self._scenario)

        validation_error = self._validate_action(action)
        if validation_error:
            # Malformed action: flat penalty, no governance round.
            error_code, message = validation_error
            self._state.invalid_action_count += 1
            self._last_governance = GovernanceStatus(
                status="needs_revision",
                explanation=message,
                required_roles=list(self._scenario.required_review_roles),
                approvals=[],
                objections=[],
                vetoes=[],
                executable=False,
            )
            reward_components.append(
                RewardComponent(
                    name="invalid_action",
                    value=-1.0,
                    explanation=message,
                )
            )
            reward = -1.0
            self._last_summary = message
            self._append_state_label("[invalid]")
        else:
            governance, governance_components, policy_veto = self._assess_governance(
                action, previous_properties
            )
            self._last_governance = governance
            reward_components.extend(governance_components)
            reward = sum(component.value for component in governance_components)

            if policy_veto:
                # A specialist blocked the action; nothing is executed.
                self._last_summary = governance.explanation
                self._append_state_label("[policy_veto]")
            else:
                self._last_governance.status = "executed"
                action_reward, done = self._execute_action(
                    action, reward_components, previous_properties, previous_score
                )
                reward += action_reward
                if not done:
                    reward += self._evaluate_reasoning_consistency(
                        action,
                        previous_properties,
                        self._true_properties(),
                        reward_components,
                    )
                if done and self._state.submitted:
                    self._append_state_label("[submitted]")
                elif not done:
                    self._append_state_label(f"[decision_{self._state.step_count:02d}]")

        # Forced terminations: decision-horizon and oracle-budget limits.
        if not done and self._state.step_count >= self._scenario.max_steps:
            done = True
            reward_components.append(
                RewardComponent(
                    name="step_limit",
                    value=-0.3,
                    explanation="Episode ended because the maximum decision horizon was reached.",
                )
            )
            reward -= 0.3
            self._report_card = self._build_report_card(submitted=False)
            self._last_summary = "Max-step termination triggered."
            self._append_state_label("[terminated:max_steps]")

        if not done and self._state.remaining_budget <= 0:
            done = True
            reward_components.append(
                RewardComponent(
                    name="budget_exhausted",
                    value=-0.5,
                    explanation="Episode terminated because the oracle budget reached zero.",
                )
            )
            reward -= 0.5
            self._report_card = self._build_report_card(submitted=False)
            self._last_summary = "Budget exhausted before a valid submission."
            self._append_state_label("[terminated:budget]")

        if done and not self._report_card:
            self._report_card = self._build_report_card(submitted=self._state.submitted)

        if done and not self._state.submitted and self._reward_mode == "curriculum":
            reward += self._curriculum_terminal_progress_reward(reward_components)

        reward = round(reward, 4)
        self._reward_total = round(self._reward_total + reward, 4)
        self._state.reward_total = self._reward_total
        self._state.current_molecule = self._molecule_signature()
        self._state.state_label = self._state_path[-1]
        self._state.state_path = list(self._state_path)
        self._state.visited_states = len(self._visited_states)
        self._state.known_assay_count = len(self._known_assays)
        self._state.last_error_code = error_code

        self._history.append(
            {
                "step": self._state.step_count,
                "action": action.model_dump(exclude_none=True),
                "reward": reward,
                "done": done,
                "molecule": self._molecule_signature(),
                "state_label": self._state.state_label,
                "summary": self._last_summary,
                "governance": self._last_governance.model_dump(),
            }
        )
        if done:
            # Rebuild once more so the card reflects the final submitted flag.
            self._report_card = self._build_report_card(submitted=self._state.submitted)
        # NOTE(review): metadata is synced every step, mirroring reset();
        # confirm the original did not guard this under `if done:`.
        self._sync_state_metadata()

        return self._build_observation(
            reward=reward,
            done=done,
            reward_components=reward_components,
        )

    def _curriculum_terminal_progress_reward(self, reward_components: List[RewardComponent]) -> float:
        """Give bounded partial credit for near-miss episodes during RL warmup.

        This intentionally does not change the public submission grader. It only
        makes the training reward less sparse when a model builds evidence or a
        chemically plausible candidate but fails to formally submit.
        """

        grader_scores = self._grade_all()
        progress = (
            0.25 * grader_scores["candidate_score"]
            + 0.25 * grader_scores["constraint_margin_score"]
            + 0.25 * grader_scores["evidence_score"]
            + 0.15 * grader_scores["coordination_score"]
            + 0.10 * grader_scores["budget_score"]
        )
        # Clamp so warmup credit can never exceed a real submission's value.
        progress = min(0.75, max(0.0, progress))
        reward_components.append(
            RewardComponent(
                name="curriculum_terminal_progress",
                value=round(progress, 4),
                explanation=(
                    "Bounded warmup reward for non-submitted episodes based on candidate quality, "
                    "constraint margin, evidence coverage, coordination, and budget discipline. "
                    "Official submission_score remains 0.0 without a submit action."
                ),
            )
        )
        missed_nomination_penalty = 0.0
        if (
            grader_scores["evidence_score"] >= 0.99
            and grader_scores["constraint_margin_score"] >= 0.9
            and grader_scores["candidate_score"] >= self._scenario.baseline_to_beat
        ):
            # Strong package but no submit: penalize the missed nomination.
            missed_nomination_penalty = -0.25
            reward_components.append(
                RewardComponent(
                    name="curriculum_missed_nomination",
                    value=missed_nomination_penalty,
                    explanation=(
                        "The candidate had a strong evidence package near the decision deadline, "
                        "but the team failed to make a formal submit decision."
                    ),
                )
            )
        return progress + missed_nomination_penalty

    @property
    def state(self) -> MolForgeState:
        """Return the current environment state."""

        return self._state
|
server/requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv-core[core]>=0.2.3
|
| 2 |
+
fastapi>=0.115.0
|
| 3 |
+
uvicorn>=0.30.0
|
| 4 |
+
pydantic>=2.8.0
|
| 5 |
+
rdkit>=2023.9.5,<2024.3.1; python_version < "3.13"
|
| 6 |
+
rdkit>=2026.3.1; python_version >= "3.13"
|
| 7 |
+
# Optional TDC oracle support:
|
| 8 |
+
# pytdc>=1.1.0; python_version < "3.13"
|
server/shared.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared imports, constants, and utility mixins for MolForge."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import hashlib
|
| 6 |
+
from copy import deepcopy
|
| 7 |
+
from typing import Any, Dict, List, Mapping, Optional
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
from ..models import (
|
| 11 |
+
AgentMessage,
|
| 12 |
+
AssayReading,
|
| 13 |
+
MolForgeAction,
|
| 14 |
+
)
|
| 15 |
+
from ..scenarios import (
|
| 16 |
+
DEFAULT_TOOL_COSTS,
|
| 17 |
+
EDITABLE_SLOTS,
|
| 18 |
+
FRAGMENT_LIBRARY,
|
| 19 |
+
SLOT_ORDER,
|
| 20 |
+
SCENARIOS,
|
| 21 |
+
ScenarioConfig,
|
| 22 |
+
compute_objective_score,
|
| 23 |
+
enumerate_candidate_edits,
|
| 24 |
+
evaluate_constraint_margins,
|
| 25 |
+
evaluate_constraints,
|
| 26 |
+
evaluate_molecule,
|
| 27 |
+
format_molecule,
|
| 28 |
+
get_scenario,
|
| 29 |
+
literature_hints,
|
| 30 |
+
molecule_to_smiles,
|
| 31 |
+
oracle_backend_status,
|
| 32 |
+
)
|
| 33 |
+
except ImportError:
|
| 34 |
+
from models import (
|
| 35 |
+
AgentMessage,
|
| 36 |
+
AssayReading,
|
| 37 |
+
MolForgeAction,
|
| 38 |
+
)
|
| 39 |
+
from scenarios import (
|
| 40 |
+
DEFAULT_TOOL_COSTS,
|
| 41 |
+
EDITABLE_SLOTS,
|
| 42 |
+
FRAGMENT_LIBRARY,
|
| 43 |
+
SLOT_ORDER,
|
| 44 |
+
SCENARIOS,
|
| 45 |
+
ScenarioConfig,
|
| 46 |
+
compute_objective_score,
|
| 47 |
+
enumerate_candidate_edits,
|
| 48 |
+
evaluate_constraint_margins,
|
| 49 |
+
evaluate_constraints,
|
| 50 |
+
evaluate_molecule,
|
| 51 |
+
format_molecule,
|
| 52 |
+
get_scenario,
|
| 53 |
+
literature_hints,
|
| 54 |
+
molecule_to_smiles,
|
| 55 |
+
oracle_backend_status,
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# Environment actions each role may take. An empty list means the role is
# messaging-only (it can comment via ROLE_MESSAGE_TYPES but cannot act).
ROLE_PERMISSIONS: Dict[str, List[str]] = {
    "lead_chemist": ["edit", "submit", "restart", "defer"],
    "toxicologist": [],
    "assay_planner": ["run_assay"],
    "process_chemist": [],
}

# Message types each role is allowed to post on the team channel.
ROLE_MESSAGE_TYPES: Dict[str, List[str]] = {
    "lead_chemist": ["proposal", "revision_request", "submission_recommendation"],
    "toxicologist": ["approval", "objection", "risk_flag", "assay_request", "rejection"],
    "assay_planner": ["proposal", "approval", "rejection", "assay_request", "submission_recommendation"],
    "process_chemist": ["approval", "objection", "risk_flag", "assay_request"],
}
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class MolForgeSharedMixin:
    """Utility methods shared across the environment mixins.

    Assumes the composing environment class provides the attributes these
    helpers read: ``_known_assays``, ``_molecule``, ``_scenario``, ``_state``,
    ``_history``, ``_message_log``, ``_oracle_log``, ``_role_metrics``,
    ``_state_path``, ``_assay_runs``, ``_trap_penalty_active``, and
    ``_debug_state_enabled``.
    """

    def _merge_assays(self, readings: List[AssayReading]) -> None:
        """Merge new readings into the known set, deduplicated by key.

        Key is (tool_name, property_name, molecule_signature); a newer
        reading with the same key replaces the older one.
        """
        keyed = {
            (reading.tool_name, reading.property_name, reading.molecule_signature): reading
            for reading in self._known_assays
        }
        for reading in readings:
            keyed[(reading.tool_name, reading.property_name, reading.molecule_signature)] = reading
        self._known_assays = list(keyed.values())

    def _current_property_estimate(
        self,
        property_name: str,
        molecule_signature: Optional[str] = None,
    ) -> Optional[float]:
        """Return the latest estimate for a property of a molecule, or None.

        Defaults to the current molecule's signature. Scans newest-to-oldest
        so the most recent reading wins.
        """
        signature = molecule_signature or self._molecule_signature()
        for reading in reversed(self._known_assays):
            if reading.molecule_signature == signature and reading.property_name == property_name:
                return reading.estimate
        return None

    def _estimate_information_gain(self, tool_name: str) -> float:
        """Heuristic expected information value of running *tool_name* now.

        Base values reflect how informative each tool is for the scenario
        (toxicity/synthesizability are worth more when the scenario stresses
        them); unknown tools default to 0.25. Repeating a tool on the same
        molecule decays its value geometrically (factor 0.4 per prior run).
        """
        current_signature = self._molecule_signature()
        prior_runs = self._assay_runs.get(f"{current_signature}::{tool_name}", 0)
        base = {
            "evaluate_properties": 0.7,
            "dock_target": 0.62,
            "assay_toxicity": 0.78 if self._scenario.difficulty != "easy" else 0.52,
            "estimate_synthesizability": 0.66 if "synth_min" in self._scenario.hard_constraints else 0.42,
            "evaluate_novelty": 0.38,
            "search_literature": 0.32,
            "run_md_simulation": 0.84,
        }.get(tool_name, 0.25)
        decay = 0.4**prior_runs
        return round(base * decay, 4)

    def _simulate_action_properties(self, action: MolForgeAction) -> Dict[str, float]:
        """Predict the property map that would result from applying *action*.

        Edits are simulated on a copy of the current molecule; restart is
        evaluated on the scenario's restart scaffold (trap penalty cleared);
        any other action leaves the molecule unchanged.
        """
        if action.action_type == "edit" and action.slot:
            molecule = dict(self._molecule)
            if action.edit_type == "remove":
                # "Remove" resets the slot to its neutral default fragment.
                # NOTE(review): defaults[action.slot] raises KeyError for slots
                # outside these four — presumably slot values are validated
                # upstream; confirm against the action-handling mixin.
                defaults = {
                    "warhead": "nitrile",
                    "hinge": "pyridine",
                    "solvent_tail": "morpholine",
                    "back_pocket": "methoxy",
                }
                molecule[action.slot] = defaults[action.slot]
            elif action.fragment:
                molecule[action.slot] = action.fragment
            return self._evaluate_for_molecule(molecule, self._trap_penalty_active)

        if action.action_type == "restart":
            return self._evaluate_for_molecule(dict(self._scenario.restart_scaffold), False)

        return self._true_properties()

    def _record_message(self, message: AgentMessage) -> None:
        """Append a message to the log and update per-role/message counters."""
        if not message.message_id:
            # Deterministic id from step count and current log length.
            message.message_id = f"msg_{self._state.step_count:03d}_{len(self._message_log):03d}"
        self._message_log.append(deepcopy(message))
        self._state.message_count += 1
        self._role_metrics[message.sender]["messages_sent"] += 1
        if message.message_type in {"objection", "risk_flag", "rejection"}:
            self._state.objection_count += 1

    def _sync_state_metadata(self) -> None:
        """Rebuild ``state.metadata`` from the current episode bookkeeping.

        Terminal grader scores are exposed only after a submission; the
        hidden true properties are leaked only when debug mode is enabled.
        """
        self._state.metadata = {
            "state_label": self._state.state_label,
            "state_path": list(self._state_path),
            "trace": deepcopy(self._history),
            "message_log": [message.model_dump() for message in self._message_log],
            "oracle_log": deepcopy(self._oracle_log),
            "role_metrics": deepcopy(self._role_metrics),
            "terminal_grader_scores": self._grade_all() if self._state.submitted else {},
        }
        if self._debug_state_enabled:
            self._state.metadata["debug_hidden_properties"] = self._true_properties()

    def _true_properties(self) -> Dict[str, float]:
        """Ground-truth property map for the current molecule."""
        return self._evaluate_for_molecule(self._molecule, self._trap_penalty_active)

    def _evaluate_for_molecule(
        self,
        molecule: Mapping[str, str],
        trap_penalty_active: bool,
    ) -> Dict[str, float]:
        """Evaluate *molecule* under a scenario copy with the trap flag set.

        # NOTE(review): assumes the scenario class accepts its own __dict__
        # as keyword arguments (dataclass-style ScenarioConfig) — confirm.
        """
        return evaluate_molecule(
            molecule,
            self._scenario.__class__(
                **{**self._scenario.__dict__, "trap_penalty": trap_penalty_active}
            ),
            target_shift_active=self._target_shift_active(),
        )

    def _target_shift_active(self) -> bool:
        """True once the scenario's target-shift step has been reached."""
        return bool(
            self._scenario.target_shift_step
            and self._state.step_count >= self._scenario.target_shift_step
        )

    def _molecule_signature(self) -> str:
        """Canonical string key for the current molecule."""
        return format_molecule(self._molecule)

    def _append_state_label(self, label: str) -> None:
        """Append *label* to the state path, collapsing consecutive duplicates."""
        if not self._state_path or self._state_path[-1] != label:
            self._state_path.append(label)

    def _safety_alerts(self) -> List[str]:
        """Human-readable safety warnings triggered by current fragments."""
        alerts = []
        if self._molecule["solvent_tail"] == "dimethylamino":
            alerts.append("Dimethylamino tail is a recurring liability for cardiac safety.")
        if self._molecule["back_pocket"] == "trifluoromethyl":
            alerts.append("Trifluoromethyl group may overshoot lipophilic safety windows.")
        if self._molecule["hinge"] == "fluorophenyl" and self._molecule["back_pocket"] == "chloro":
            alerts.append("Hydrophobic hinge/back-pocket combination looks safety-negative.")
        return alerts

    def _route_warnings(self) -> List[str]:
        """Synthesis-route warnings triggered by current fragments."""
        warnings = []
        if self._molecule["hinge"] == "quinazoline":
            warnings.append("Quinazoline hinge increases route complexity.")
        if self._molecule["warhead"] == "vinyl_sulfonamide":
            warnings.append("Vinyl sulfonamide warhead is reactive and harder to handle.")
        if self._molecule["back_pocket"] == "trifluoromethyl":
            warnings.append("CF3 substitution raises cost and scale-up complexity.")
        return warnings

    @staticmethod
    def _empty_role_metrics() -> Dict[str, Dict[str, int]]:
        """Fresh zeroed per-role message counters for a new episode."""
        return {
            role: {"messages_sent": 0, "correct_messages": 0, "incorrect_messages": 0}
            for role in ["lead_chemist", "toxicologist", "assay_planner", "process_chemist"]
        }

    @staticmethod
    def _open_unit_interval(value: float, epsilon: float = 1e-4) -> float:
        """Clamp *value* into the open interval (0, 1), rounded to 4 decimals."""
        return round(min(max(value, epsilon), 1.0 - epsilon), 4)

    @staticmethod
    def _assay_estimate(
        signature: str,
        tool_name: str,
        property_name: str,
        runs: int,
        true_value: float,
    ) -> float:
        """Deterministic noisy assay readout around *true_value*.

        Noise is pseudo-random but reproducible: derived from a SHA-256
        digest of (signature, tool, property, run index), centered on zero,
        and shrinking as 0.16 / runs so repeated runs converge. Result is
        clamped to [0, 1] and rounded to 4 decimals.
        """
        digest = hashlib.sha256(
            f"{signature}|{tool_name}|{property_name}|{runs}".encode("utf-8")
        ).hexdigest()
        # Map the first 8 hex digits onto [-0.5, 0.5].
        centered = (int(digest[:8], 16) / 0xFFFFFFFF) - 0.5
        noise = centered * (0.16 / runs)
        return round(min(max(true_value + noise, 0.0), 1.0), 4)
|
server/views.py
ADDED
|
@@ -0,0 +1,436 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Observation building and scoring mixin for MolForge."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from copy import deepcopy
|
| 6 |
+
from typing import Any, Dict, List, Mapping
|
| 7 |
+
|
| 8 |
+
from .shared import (
|
| 9 |
+
DEFAULT_TOOL_COSTS,
|
| 10 |
+
EDITABLE_SLOTS,
|
| 11 |
+
ROLE_MESSAGE_TYPES,
|
| 12 |
+
ROLE_PERMISSIONS,
|
| 13 |
+
SCENARIOS,
|
| 14 |
+
SLOT_ORDER,
|
| 15 |
+
compute_objective_score,
|
| 16 |
+
enumerate_candidate_edits,
|
| 17 |
+
evaluate_constraint_margins,
|
| 18 |
+
evaluate_constraints,
|
| 19 |
+
literature_hints,
|
| 20 |
+
molecule_to_smiles,
|
| 21 |
+
oracle_backend_status,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
from ..models import ConstraintCheck, MolForgeObservation, MoleculeSlot, RoleObservation
|
| 26 |
+
except ImportError:
|
| 27 |
+
from models import ConstraintCheck, MolForgeObservation, MoleculeSlot, RoleObservation
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class MolForgeViewMixin:
|
| 31 |
+
"""Observation, report-card, and grader methods."""
|
| 32 |
+
|
| 33 |
+
def _build_observation(
    self,
    *,
    reward: float,
    done: bool,
    reward_components: List,
) -> MolForgeObservation:
    """Assemble the full observation returned to the agent after a step.

    Only assay-backed estimates are exposed (never the hidden true
    properties); terminal grader scores are attached to metadata only
    when the episode is done.
    """
    current_signature = self._molecule_signature()
    current_assays = [
        reading for reading in self._known_assays if reading.molecule_signature == current_signature
    ]
    visible_metrics = {
        "budget_fraction_remaining": round(
            self._state.remaining_budget / max(self._scenario.oracle_budget, 1), 4
        ),
        "current_molecule_assay_count": float(len(current_assays)),
    }
    # Surface only the property estimates that have actually been measured.
    for property_name in ["potency", "toxicity", "synth", "novelty"]:
        estimate = self._current_property_estimate(property_name, current_signature)
        if estimate is not None:
            visible_metrics[property_name] = estimate

    constraint_status = self._build_visible_constraints(current_signature)
    metadata: Dict[str, Any] = {
        "task_index": self._reset_index % len(SCENARIOS),
        "oracle_budget_costs": deepcopy(DEFAULT_TOOL_COSTS),
        "history_length": len(self._history),
        # Short recap of the last three transitions for the agent prompt.
        "trace_tail": [entry["summary"] for entry in self._history[-3:]],
        "current_smiles": molecule_to_smiles(self._molecule),
        "oracle_backend": oracle_backend_status(),
        # Cap the enumerated edit suggestions at 8 to bound prompt size.
        "candidate_edits": [
            {"slot": slot, "fragment": fragment}
            for slot, fragment in list(enumerate_candidate_edits(self._molecule))[:8]
        ],
        "literature_hints": literature_hints(self._molecule),
        "target_shift_active": self._target_shift_active(),
        # Publicly visible subset of role metrics (incorrect counts withheld).
        "public_role_metrics": {
            role: {
                "messages_sent": metrics["messages_sent"],
                "correct_messages": metrics["correct_messages"],
            }
            for role, metrics in self._role_metrics.items()
        },
    }
    if done:
        metadata["terminal_grader_scores"] = self._grade_all()

    return MolForgeObservation(
        scenario_id=self._scenario.scenario_id,
        difficulty=self._scenario.difficulty,
        state_label=self._state.state_label,
        state_path=list(self._state_path),
        coordination_mode=self._scenario.coordination_mode,  # type: ignore[arg-type]
        enabled_roles=list(self._scenario.enabled_roles),
        task_brief=self._scenario.task_brief,
        target_name=self._scenario.target_name,
        current_molecule=current_signature,
        molecule_slots=[
            MoleculeSlot(slot=slot, fragment=self._molecule[slot], editable=True)
            for slot in SLOT_ORDER
        ],
        editable_slots=list(EDITABLE_SLOTS),
        step_index=self._state.step_count,
        max_steps=self._scenario.max_steps,
        remaining_budget=self._state.remaining_budget,
        budget_used=self._state.budget_used,
        max_budget=self._scenario.oracle_budget,
        known_assays=deepcopy(self._known_assays),
        role_observations=self._build_role_observations(current_signature),
        # Only the last 8 messages are echoed back to keep observations small.
        message_log=[message.model_dump() for message in self._message_log[-8:]],
        governance=deepcopy(self._last_governance),
        last_transition_summary=self._last_summary,
        visible_metrics=visible_metrics,
        constraint_status=constraint_status,
        reward_breakdown=reward_components,
        allowed_actions=[
            "Lead Chemist: edit, submit, restart, defer",
            "Assay Planner: run_assay",
            "Messages: proposal, approval, objection, risk_flag, assay_request, rejection",
        ],
        report_card=self._report_card,
        metadata=metadata,
        done=done,
        reward=reward,
    )
|
| 118 |
+
|
| 119 |
+
def _build_visible_constraints(self, molecule_signature: str) -> List[ConstraintCheck]:
    """Translate each hard constraint into a ConstraintCheck backed only by
    visible assay evidence; unmeasured properties yield an "unknown" check."""
    visible: List[ConstraintCheck] = []
    for constraint_name, limit in self._scenario.hard_constraints.items():
        is_upper_bound = constraint_name.endswith("_max")
        comparator = "<=" if is_upper_bound else ">="
        # "toxicity_max" maps to the "toxicity" property; every other
        # constraint name is "<property>_<suffix>".
        metric = "toxicity" if constraint_name == "toxicity_max" else constraint_name.split("_")[0]
        observed = self._current_property_estimate(metric, molecule_signature)
        if observed is None:
            check = ConstraintCheck(
                name=constraint_name,
                target=f"{comparator} {limit:.2f}",
                satisfied=None,
                actual=None,
                evidence_status="unknown",
            )
        else:
            passed = observed <= limit if is_upper_bound else observed >= limit
            check = ConstraintCheck(
                name=constraint_name,
                target=f"{comparator} {limit:.2f}",
                satisfied=passed,
                actual=round(observed, 4),
                evidence_status="known",
            )
        visible.append(check)
    return visible
|
| 147 |
+
|
| 148 |
+
def _build_role_observations(self, molecule_signature: str) -> List[RoleObservation]:
    """Build the per-role partial views of the shared episode state.

    Each role sees only what its job requires: the lead chemist sees edit
    options, the toxicologist sees toxicity readouts, the assay planner
    sees the budget ledger, and the process chemist sees route feasibility.
    # NOTE(review): lead_chemist/assay_planner expose ROLE_PERMISSIONS while
    # toxicologist/process_chemist expose ROLE_MESSAGE_TYPES as their
    # "permissions" — presumably because the latter two have no action
    # permissions; confirm this asymmetry is intended.
    """
    current_assays = [
        reading.model_dump()
        for reading in self._known_assays
        if reading.molecule_signature == molecule_signature
    ]
    # Core properties that still lack any measurement for this molecule.
    evidence_gaps = [
        prop
        for prop in ["potency", "toxicity", "synth"]
        if self._current_property_estimate(prop, molecule_signature) is None
    ]
    # Last four edit actions, for the lead chemist's context.
    edit_history = [
        entry["action"]
        for entry in self._history
        if entry["action"].get("action_type") == "edit"
    ][-4:]

    return [
        RoleObservation(
            role="lead_chemist",
            local_objective="Propose high-value scaffold edits and decide when the team should submit.",
            permissions=ROLE_PERMISSIONS["lead_chemist"],
            observation={
                "molecule_slots": deepcopy(self._molecule),
                "edit_history": edit_history,
                "visible_assays": current_assays,
                "candidate_edits": [
                    {"slot": slot, "fragment": fragment}
                    for slot, fragment in list(enumerate_candidate_edits(self._molecule))[:8]
                ],
                "open_questions": evidence_gaps,
            },
        ),
        RoleObservation(
            role="toxicologist",
            local_objective="Protect against safety regressions and unsafe submissions.",
            permissions=ROLE_MESSAGE_TYPES["toxicologist"],
            observation={
                "toxicity_readouts": [
                    reading
                    for reading in current_assays
                    if reading["property_name"] == "toxicity"
                ],
                "hard_threshold": self._scenario.hard_constraints.get("toxicity_max"),
                "safety_alerts": self._safety_alerts(),
                # Last four messages this role itself sent.
                "risk_history": [
                    message.model_dump()
                    for message in self._message_log
                    if message.sender == "toxicologist"
                ][-4:],
            },
        ),
        RoleObservation(
            role="assay_planner",
            local_objective="Allocate assay budget where the expected information gain is highest.",
            permissions=ROLE_PERMISSIONS["assay_planner"] + ROLE_MESSAGE_TYPES["assay_planner"],
            observation={
                "budget_ledger": {
                    "remaining_budget": self._state.remaining_budget,
                    "budget_used": self._state.budget_used,
                    "max_budget": self._state.max_budget,
                },
                "tool_costs": deepcopy(DEFAULT_TOOL_COSTS),
                "tool_usage_history": deepcopy(self._assay_runs),
                "evidence_gaps": evidence_gaps,
                "estimated_information_value": {
                    tool_name: round(self._estimate_information_gain(tool_name), 4)
                    for tool_name in self._scenario.enabled_tools
                },
            },
        ),
        RoleObservation(
            role="process_chemist",
            local_objective="Guard tractability and synthetic feasibility before the team commits.",
            permissions=ROLE_MESSAGE_TYPES["process_chemist"],
            observation={
                "synth_readouts": [
                    reading for reading in current_assays if reading["property_name"] == "synth"
                ],
                "route_warnings": self._route_warnings(),
                "feasibility_flags": {
                    "heavy_hinge": self._molecule["hinge"] == "quinazoline",
                    "reactive_warhead": self._molecule["warhead"] == "vinyl_sulfonamide",
                    "lipophilic_tail": self._molecule["back_pocket"] == "trifluoromethyl",
                },
            },
        ),
    ]
|
| 236 |
+
|
| 237 |
+
def _grade_all(self) -> Dict[str, float]:
    """Compute the full grader score dictionary from ground-truth properties.

    All component scores are squeezed into the open unit interval except
    ``submitted_score`` (exact 0/1), ``submission_score`` (0.0 when not
    submitted), ``progress_score``, and ``final_score`` (already clamped by
    their own graders).
    """
    properties = self._true_properties()
    constraints = evaluate_constraints(properties, self._scenario)
    constraint_margins = evaluate_constraint_margins(properties, self._scenario)
    constraint_margin_score = sum(constraint_margins.values()) / max(len(constraint_margins), 1)
    constraint_fraction = sum(1.0 for passed, _ in constraints.values() if passed) / max(len(constraints), 1)
    submitted = self._state.submitted
    coordination_score = self._coordination_score()
    evidence_score = self._evidence_score()
    budget_score = self._open_unit_interval(
        self._state.remaining_budget / max(self._scenario.oracle_budget, 1),
    )
    # Partial-credit progress is always computed; the submission grader only
    # runs when the team made a formal submit decision.
    progress_score = self._grade_progress(
        candidate_score=compute_objective_score(properties, self._scenario),
        constraint_margin_score=constraint_margin_score,
        constraint_fraction=constraint_fraction,
        evidence_score=evidence_score,
        coordination_score=coordination_score,
        budget_score=budget_score,
    )
    submission_score = self._grade_submission(properties) if submitted else 0.0
    final_score = self._grade_final(
        submission_score=submission_score,
        progress_score=progress_score,
        submitted=submitted,
        constraint_fraction=constraint_fraction,
        evidence_score=evidence_score,
    )
    return {
        "final_score": final_score,
        "potency_score": self._open_unit_interval(properties["potency"]),
        "safety_score": self._open_unit_interval(1.0 - properties["toxicity"]),
        "synth_score": self._open_unit_interval(properties["synth"]),
        "novelty_score": self._open_unit_interval(properties["novelty"]),
        "candidate_score": self._open_unit_interval(compute_objective_score(properties, self._scenario)),
        # Same expression as constraint_fraction above, recomputed inline.
        "constraint_score": self._open_unit_interval(
            sum(1.0 for passed, _ in constraints.values() if passed) / max(len(constraints), 1),
        ),
        "constraint_margin_score": self._open_unit_interval(constraint_margin_score),
        "budget_score": budget_score,
        "submitted_score": 1.0 if submitted else 0.0,
        "submission_score": submission_score,
        "progress_score": progress_score,
        "coordination_score": self._open_unit_interval(coordination_score),
        "evidence_score": self._open_unit_interval(evidence_score),
    }
|
| 283 |
+
|
| 284 |
+
def _grade_progress(
    self,
    *,
    candidate_score: float,
    constraint_margin_score: float,
    constraint_fraction: float,
    evidence_score: float,
    coordination_score: float,
    budget_score: float,
) -> float:
    """Score scientific progress even when no formal submission happened."""

    # Weighted blend of the partial-credit signals.
    score = (
        0.45 * candidate_score
        + 0.35 * constraint_margin_score
        + 0.10 * evidence_score
        + 0.05 * coordination_score
        + 0.05 * budget_score
    )

    # Penalize wasted assay repeats and governance vetoes, each capped at 0.20.
    duplicate_runs = sum(max(0, count - 1) for count in self._assay_runs.values())
    veto_count = len(
        [
            entry
            for entry in self._history
            if entry.get("governance", {}).get("status") == "policy_veto"
        ]
    )
    score -= min(0.20, 0.04 * duplicate_runs)
    score -= min(0.20, 0.05 * veto_count)

    # Hard ceilings: failed constraints, thin evidence without a submission,
    # and never escaping the trap series all cap the achievable progress.
    if constraint_fraction < 1.0:
        score = min(score, 0.25 + 0.25 * constraint_fraction)
    if not self._state.submitted and evidence_score < 0.99:
        score = min(score, 0.45)
    if self._scenario.trap_penalty and not self._restart_used:
        score = min(score, 0.30)

    # Small bonus for actually committing to a submission.
    if self._state.submitted:
        score += 0.05
    return self._open_unit_interval(score)
|
| 321 |
+
|
| 322 |
+
def _grade_final(
    self,
    *,
    submission_score: float,
    progress_score: float,
    submitted: bool,
    constraint_fraction: float,
    evidence_score: float,
) -> float:
    """Single conservative scalar for RL/evaluation headline reporting."""

    # A formal submission is graded entirely by the submission grader.
    if submitted:
        return self._open_unit_interval(submission_score)

    # No submission: heavily discounted partial credit, clamped by the
    # weakest aspect of the episode.
    headline = 0.35 * progress_score
    ceilings = []
    if constraint_fraction < 1.0:
        ceilings.append(0.05 + 0.10 * constraint_fraction)
    if evidence_score < 0.99:
        ceilings.append(0.15)
    if self._scenario.trap_penalty and not self._restart_used:
        ceilings.append(0.08)
    for ceiling in ceilings:
        headline = min(headline, ceiling)
    return self._open_unit_interval(headline)
|
| 344 |
+
|
| 345 |
+
def _coordination_score(self) -> float:
    """Fraction of expected coordination messages sent correctly.

    Every non-defer action expects one message plus one per governance-
    required role; the score is correct messages over that expectation,
    capped at 1.0 before clamping.
    """

    expected = sum(
        1 + len(entry.get("governance", {}).get("required_roles", []))
        for entry in self._history
        if entry.get("action", {}).get("action_type") != "defer"
    )
    if expected == 0:
        return self._open_unit_interval(0.0)
    correct = sum(stats["correct_messages"] for stats in self._role_metrics.values())
    return self._open_unit_interval(min(correct, expected) / expected)
|
| 356 |
+
|
| 357 |
+
def _grade_submission(self, properties: Mapping[str, float]) -> float:
    """Grade a formal submission on science, constraints, and process quality."""

    objective = compute_objective_score(properties, self._scenario)
    margins = evaluate_constraint_margins(properties, self._scenario)
    margin_score = sum(margins.values()) / max(len(margins), 1)
    constraint_results = evaluate_constraints(properties, self._scenario)
    passing_fraction = sum(
        1.0 for passed, _ in constraint_results.values() if passed
    ) / max(len(constraint_results), 1)
    evidence = self._evidence_score()

    score = (
        0.60 * objective
        + 0.20 * margin_score
        + 0.10 * self._coordination_score()
        + 0.10 * evidence
    )

    # Budget-efficiency bonus only for fully evidenced, fully passing,
    # baseline-beating candidates.
    if evidence >= 0.99 and passing_fraction >= 1.0 and objective >= self._scenario.baseline_to_beat:
        leftover_fraction = self._state.remaining_budget / max(self._scenario.oracle_budget, 1)
        score += 0.05 * max(0.0, leftover_fraction)

    # Ceilings for incomplete evidence, failed constraints, or a
    # sub-baseline candidate.
    if evidence < 1.0:
        score = min(score, 0.25 + 0.25 * evidence)
    if passing_fraction < 1.0:
        score = min(score, 0.20 + 0.50 * margin_score)
    if objective < self._scenario.baseline_to_beat:
        score = min(score, 0.45)
    return self._open_unit_interval(score)
|
| 380 |
+
|
| 381 |
+
def _evidence_score(self) -> float:
    """Fraction of required assay evidence gathered for the current molecule.

    Potency and toxicity are always required; synthesizability is added when
    the scenario enforces a "synth_min" constraint. After a target mutation,
    potency must be re-measured post-shift or the score is halved.
    """

    signature = self._molecule_signature()
    needed = ["potency", "toxicity"]
    if "synth_min" in self._scenario.hard_constraints:
        needed.append("synth")
    covered = [
        prop
        for prop in needed
        if self._current_property_estimate(prop, signature) is not None
    ]
    fraction = len(covered) / max(len(needed), 1)

    if self._scenario.target_shift_step and self._target_shift_active():
        refreshed = False
        for entry in self._oracle_log:
            if (
                entry["step"] >= self._scenario.target_shift_step
                and entry["molecule"] == signature
                and any(result["property_name"] == "potency" for result in entry["results"])
            ):
                refreshed = True
                break
        fraction = min(fraction, 1.0 if refreshed else 0.5)
    return fraction
|
| 401 |
+
|
| 402 |
+
def _build_report_card(self, *, submitted: bool) -> str:
    """Render the end-of-episode report card as a newline-joined string.

    The report lists the scenario, final molecule, true property values,
    every grader score, a pass/fail line per hard constraint (with the
    actual value and threshold), and activity counters.  Extra notes are
    appended when a target shift fired, when the agent restarted from a
    new scaffold, or when the episode ended without a formal submission.

    :param submitted: whether the episode ended via a formal submit action.
    """
    props = self._true_properties()
    scores = self._grade_all()
    constraint_results = evaluate_constraints(props, self._scenario)

    report = [
        f"Scenario: {self._scenario.scenario_id} ({self._scenario.difficulty})",
        f"Final molecule: {self._molecule_signature()}",
    ]
    # All numeric summary lines share the "Label: x.xxx" shape, so they are
    # emitted from one label/value table instead of hand-written f-strings.
    for label, value in (
        ("Potency", props["potency"]),
        ("Toxicity", props["toxicity"]),
        ("Synthesizability", props["synth"]),
        ("Novelty", props["novelty"]),
        ("Final score", scores["final_score"]),
        ("Candidate scientific score", scores["candidate_score"]),
        ("Constraint margin score", scores["constraint_margin_score"]),
        ("Submission grader", scores["submission_score"]),
        ("Progress score", scores["progress_score"]),
        ("Coordination score", scores["coordination_score"]),
        ("Evidence score", scores["evidence_score"]),
    ):
        report.append(f"{label}: {value:.3f}")

    report.append("Constraints:")
    for name, (passed, threshold) in constraint_results.items():
        # Map the constraint name back to the property it checks, e.g.
        # "synth_min" -> "synth"; "toxicity_max" -> "toxicity".
        metric = "toxicity" if name == "toxicity_max" else name.split("_")[0]
        verdict = "pass" if passed else "fail"
        report.append(
            f"- {name}: {verdict} (actual={props[metric]:.3f}, threshold={threshold:.3f})"
        )

    report.append(
        f"Messages sent: {self._state.message_count}, objections raised: "
        f"{self._state.objection_count}, oracle calls: {self._state.oracle_call_count}"
    )

    if self._scenario.target_shift_step and self._target_shift_active():
        report.append("Target mutation triggered during this episode.")
    if self._restart_used:
        report.append("Agent used restart_from_new_scaffold to escape the original trap series.")
    if not submitted:
        report.append("Episode terminated without a formal submit action.")
    return "\n".join(report)
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|