Spaces:
Sleeping
Sleeping
Update CERNenv Space
Browse files- .dockerignore +9 -0
- .gitignore +21 -0
- .python-version +1 -0
- Dockerfile +31 -0
- README.md +58 -4
- [External] Apr ‘26 OpenEnv Hackathon Themes & Judging Criteria.txt +190 -0
- [External] Meta OpenEnv Hackathon Participant Help Guide.txt +291 -0
- client.py +37 -0
- models.py +600 -0
- openenv.yaml +6 -0
- pyproject.toml +61 -0
- scripts/__init__.py +0 -0
- scripts/_build_spaces.py +135 -0
- scripts/baseline_agents.py +305 -0
- scripts/push_to_hub.py +247 -0
- scripts/run_agent.py +129 -0
- server/Dockerfile +50 -0
- server/__init__.py +1 -0
- server/app.py +52 -0
- server/environment.py +363 -0
- server/requirements.txt +6 -0
- server/rewards/__init__.py +19 -0
- server/rewards/reward_function.py +283 -0
- server/rules/__init__.py +5 -0
- server/rules/engine.py +203 -0
- server/simulator/__init__.py +31 -0
- server/simulator/latent_state.py +171 -0
- server/simulator/noise.py +161 -0
- server/simulator/output_generator.py +586 -0
- server/simulator/transition.py +197 -0
- server/tasks/__init__.py +9 -0
- server/tasks/scenarios.py +422 -0
- space/__init__.py +0 -0
- space/env/Dockerfile +24 -0
- space/env/README.md +51 -0
- space/env/requirements.txt +6 -0
- space/training/Dockerfile +31 -0
- space/training/README.md +64 -0
- space/training/__init__.py +0 -0
- space/training/app.py +412 -0
- space/training/requirements.txt +18 -0
- training/__init__.py +1 -0
- training/colab_train_unsloth.ipynb +260 -0
- training/evaluate.py +152 -0
- training/llm_agent.py +227 -0
- training/plots.py +93 -0
- training/rollouts.py +160 -0
- training/training_script.py +211 -0
- training/training_unsloth.py +130 -0
.dockerignore
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.git
|
| 2 |
+
.venv
|
| 3 |
+
__pycache__
|
| 4 |
+
*.pyc
|
| 5 |
+
.pytest_cache
|
| 6 |
+
training/runs
|
| 7 |
+
training/grpo-output
|
| 8 |
+
training/rollouts
|
| 9 |
+
notebooks/.ipynb_checkpoints
|
.gitignore
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
.venv/
|
| 4 |
+
.env
|
| 5 |
+
.pytest_cache/
|
| 6 |
+
.coverage
|
| 7 |
+
htmlcov/
|
| 8 |
+
.DS_Store
|
| 9 |
+
training/runs/
|
| 10 |
+
training/grpo-output/
|
| 11 |
+
training/rollouts/
|
| 12 |
+
training/plots/
|
| 13 |
+
*.png
|
| 14 |
+
!docs/*.png
|
| 15 |
+
!assets/*.png
|
| 16 |
+
.ipynb_checkpoints/
|
| 17 |
+
.uv/
|
| 18 |
+
uv.lock
|
| 19 |
+
dist/
|
| 20 |
+
build/
|
| 21 |
+
*.egg-info/
|
.python-version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.11
|
Dockerfile
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CERNenv trainer Space (Docker, A100)
|
| 2 |
+
FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04
|
| 3 |
+
|
| 4 |
+
ENV DEBIAN_FRONTEND=noninteractive \
|
| 5 |
+
PYTHONUNBUFFERED=1 \
|
| 6 |
+
PIP_NO_CACHE_DIR=1 \
|
| 7 |
+
HF_HOME=/home/user/.cache/huggingface \
|
| 8 |
+
TRANSFORMERS_CACHE=/home/user/.cache/huggingface/transformers \
|
| 9 |
+
PYTHONPATH=/home/user/app
|
| 10 |
+
|
| 11 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 12 |
+
python3.11 python3.11-venv python3.11-dev python3-pip \
|
| 13 |
+
git curl ca-certificates build-essential \
|
| 14 |
+
&& rm -rf /var/lib/apt/lists/* \
|
| 15 |
+
&& ln -sf /usr/bin/python3.11 /usr/local/bin/python \
|
| 16 |
+
&& ln -sf /usr/bin/python3.11 /usr/local/bin/python3
|
| 17 |
+
|
| 18 |
+
RUN useradd -ms /bin/bash user
|
| 19 |
+
USER user
|
| 20 |
+
ENV PATH="/home/user/.local/bin:${PATH}"
|
| 21 |
+
WORKDIR /home/user/app
|
| 22 |
+
|
| 23 |
+
COPY --chown=user:user space/training/requirements.txt /tmp/requirements.txt
|
| 24 |
+
RUN python -m pip install --upgrade pip && \
|
| 25 |
+
python -m pip install --user -r /tmp/requirements.txt
|
| 26 |
+
|
| 27 |
+
COPY --chown=user:user . /home/user/app
|
| 28 |
+
|
| 29 |
+
EXPOSE 7860
|
| 30 |
+
|
| 31 |
+
CMD ["python", "-m", "uvicorn", "space.training.app:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
CHANGED
|
@@ -1,10 +1,64 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
colorFrom: indigo
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
|
|
|
|
|
|
| 7 |
pinned: false
|
|
|
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: CERNenv Trainer
|
| 3 |
+
emoji: ⚛️
|
| 4 |
colorFrom: indigo
|
| 5 |
+
colorTo: pink
|
| 6 |
sdk: docker
|
| 7 |
+
suggested_hardware: a100-large
|
| 8 |
+
suggested_storage: medium
|
| 9 |
pinned: false
|
| 10 |
+
license: bsd-3-clause
|
| 11 |
+
short_description: GRPO trainer for CERNenv (Unsloth + LoRA, A100)
|
| 12 |
---
|
| 13 |
|
| 14 |
+
# CERNenv Trainer (Hugging Face Space, A100)
|
| 15 |
+
|
| 16 |
+
Fine-tunes a small instruction-tuned LLM (Large Language Model) to act as
|
| 17 |
+
an LHC (Large Hadron Collider) physicist inside the **CERNenv** OpenEnv
|
| 18 |
+
environment using **GRPO** (Group-Relative Policy Optimization),
|
| 19 |
+
**Unsloth**, and **LoRA** (Low-Rank Adaptation).
|
| 20 |
+
|
| 21 |
+
## Hardware
|
| 22 |
+
- Recommended: **A100 large (80 GB)**
|
| 23 |
+
- Minimum: T4 / L4 (will use a smaller model + fewer episodes)
|
| 24 |
+
|
| 25 |
+
## Required Space secrets
|
| 26 |
+
| Secret | Purpose |
|
| 27 |
+
| --- | --- |
|
| 28 |
+
| `HF_TOKEN` | Hugging Face token with `write` access for model push |
|
| 29 |
+
| `HF_USERNAME` | Hub username, used as the default model-repo owner |
|
| 30 |
+
|
| 31 |
+
## Optional environment variables
|
| 32 |
+
| Variable | Default | Notes |
|
| 33 |
+
| --- | --- | --- |
|
| 34 |
+
| `MODEL_NAME` | `unsloth/Qwen2.5-3B-Instruct` | Any chat model Unsloth supports |
|
| 35 |
+
| `TOTAL_EPISODES` | `400` | Prompts × generations rollouts |
|
| 36 |
+
| `DIFFICULTY` | `easy` | `easy` / `medium` / `hard` |
|
| 37 |
+
| `MAX_STEPS` | `18` | Steps per episode |
|
| 38 |
+
| `NUM_GENERATIONS` | `4` | GRPO group size |
|
| 39 |
+
| `OUTPUT_DIR` | `runs/unsloth-grpo` | LoRA adapter output |
|
| 40 |
+
| `PUSH_REPO` | `${HF_USERNAME}/cernenv-grpo-qwen2.5-3b` | Hub repo for adapters |
|
| 41 |
+
| `AUTOSTART` | `0` | Set to `1` to start training on Space boot |
|
| 42 |
+
|
| 43 |
+
## How to use
|
| 44 |
+
|
| 45 |
+
This Space exposes a tiny FastAPI control panel:
|
| 46 |
+
- `GET /` — status + current run info
|
| 47 |
+
- `POST /train` — start / restart a training run
|
| 48 |
+
- `GET /logs` — live tail of `training.log`
|
| 49 |
+
- `GET /metrics` — reward + success-rate snapshots
|
| 50 |
+
|
| 51 |
+
Click **"Start training"** in the UI, or set `AUTOSTART=1` in the Space variables to kick off immediately on boot.
|
| 52 |
+
|
| 53 |
+
When training finishes, the LoRA adapters are pushed to `PUSH_REPO`.
|
| 54 |
+
|
| 55 |
+
## Local equivalent
|
| 56 |
+
|
| 57 |
+
The same training run is reproducible locally with:
|
| 58 |
+
|
| 59 |
+
```bash
|
| 60 |
+
PYTHONPATH=. python -m training.training_unsloth \
|
| 61 |
+
--model_name unsloth/Qwen2.5-3B-Instruct \
|
| 62 |
+
--difficulty easy --total_episodes 400 --max_steps 18 \
|
| 63 |
+
--output_dir runs/unsloth-grpo
|
| 64 |
+
```
|
[External] Apr ‘26 OpenEnv Hackathon Themes & Judging Criteria.txt
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Theme #1 - Multi-Agent Interactions
|
| 2 |
+
Environments for this theme involve cooperation, competition, negotiation, and coalition formation. Learning from these environments will enable agents to model the beliefs and incentives of others in partially observable settings. This drives theory-of-mind reasoning and emergent strategic behavior.
|
| 3 |
+
Expected Outcome: an environment that can be used to train multi-agent task handling in a LLM
|
| 4 |
+
Example environments: Market simulations, compute-allocation negotiations, collaborative puzzle worlds, mixed cooperative/competitive strategy games.
|
| 5 |
+
Theme #2 - (Super) Long-Horizon Planning & Instruction Following
|
| 6 |
+
You will build environments that require deep, multi-step reasoning with sparse or delayed rewards. After using these environments, the goal is to enable agents to decompose goals, track state over extended trajectories, and recover from early mistakes. The aim is to push beyond shallow next-token reasoning toward structured planning and durable internal representations.
|
| 7 |
+
Expected Outcome: an environment that can capture and improve LLM behaviour on challenging long horizon tasks that need long running sessions beyond context memory limits.
|
| 8 |
+
Example environments: (Think of OpenClaw workflows with Multi-turn tasks). Research-planning simulators, large-scale codebase refactoring tasks, strategic resource management worlds, long-horizon logistics optimization, extremely complicated long-horizon instruction following (e.g., 300 instructions scattered around).
|
| 9 |
+
Theme #3 - World Modeling
|
| 10 |
+
#3.1 Professional Tasks
|
| 11 |
+
Here you will develop environments that require real interaction with tools, APIs, or dynamic systems where the model is expected to do real hard work instead of exploiting short-cuts to arrive at the desired outcome. Learning from these environments will enable agents to maintain consistent internal state, update beliefs based on outcomes, and orchestrate multi-step workflows. The goal is to strengthen causal reasoning and persistent world models.
|
| 12 |
+
Expected Outcome: an environment capturing nuances of a defined partially observable world and improve LLM interaction with it
|
| 13 |
+
Example environments: Dynamic browser/API ecosystems, enterprise applications, scientific workflow loops (papers → code → experiments), economic simulations with feedback, tool-discovery benchmarks.
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
#3.2 Personalized Tasks
|
| 17 |
+
Here we will develop an environment that offers real personalized task handling, imagine replying to personal messages or handling dinner conflicts due to work conflicts, replying to tough emails. Think any personal assistant tasks
|
| 18 |
+
|
| 19 |
+
Expected Outcome: An environment that gives the model a realistic simulation of handling personal tasks, conflicts and managing them as delegations
|
| 20 |
+
|
| 21 |
+
Example environments: Executive Assistant Meeting Planner, Dinner and drive planning, email and message replying, shopping, etc
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
Theme #4 - Self-Improvement
|
| 25 |
+
The focus here is to create environments where agents can learn to generate new challenges, escalate difficulty, and improve through self-play or adaptive curricula. Rather than optimizing fixed tasks, the goal is for agents to learn to drive their own capability growth. The objective is recursive skill amplification.
|
| 26 |
+
Expected Outcome: an environment for improving self-play of a LLM over a defined set of tasks
|
| 27 |
+
Example environments: Self-play negotiation arenas, auto-generated math/proof tasks, evolving coding competitions, adaptive RL curricula.
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
Theme #5: Wild Card - Impress Us!
|
| 31 |
+
We do not want to limit your focus if your idea doesn’t fit the boxes above, we want and WILL reward out of box tasks, please be creative but remember to add submissions that meaningfully add value to LLM training on a certain task.
|
| 32 |
+
Guidelines for Problem Statement
|
| 33 |
+
* It is NOT mandatory to choose the same problem statement as Round 1. Only choose the same problem statement if it aligns with the above provided Hackathon themes.
|
| 34 |
+
* You can start working on your problem statement once you have finalized it. Post-training can be done onsite on 25th & 26th when you receive compute credits for HuggingFace.
|
| 35 |
+
* Before the onsite, we suggest you work on building the environment, agent behaviours, reward model and evaluate if your work aligns with the judging criteria given below.
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
Judging Criteria
|
| 41 |
+
Minimum requirements:
|
| 42 |
+
* Usage of OpenEnv (latest release)
|
| 43 |
+
* Show a minimal training script for your environment using Unsloth or HF TRL in Colab
|
| 44 |
+
* Write a mini-blog on HuggingFace or mini-video on YouTube talking about your submission, <2 minutes
|
| 45 |
+
* Your OpenEnv compliant environment should be hosted on Hugging Face Spaces.
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
Judging Overview
|
| 49 |
+
* Evaluation: Teams will be scored based on the following criteria:
|
| 50 |
+
1. Environment Innovation (40%): Is the environment novel, creative, or challenging? Does it meaningfully test the agent’s behavior?
|
| 51 |
+
2. Storytelling (30%): Does the team clearly explain the problem, environment, and agent behavior? Is the demo engaging and easy to follow?
|
| 52 |
+
3. Showing Improvement in Rewards (20%): Does the demo provide observable evidence of training progress (reward curves, metrics, or before/after behavior)?
|
| 53 |
+
4. Reward and Training Script/Pipeline Setup (10%): Is the reward logic coherent, and does the pipeline produce meaningful improvement in the agent’s inference (how it acts in the environment)?
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
OpenEnv Hackathon - What Judges Look For
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
This guide tells you what makes a strong submission for the OpenEnv Hackathon (India 2026).
|
| 60 |
+
Read it before you start building, and again before you submit.
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
For the list of themes and example problems, refer to the top sections.
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
NOTE: Please remember only one submission per team. If you have multiple ideas, pick the best one and go for it. Please make sure that the URL link of your environment is submitted as judges will pull the environment from the URL to evaluate it. Changes or commits after the submission deadline will not be considered.
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
TL;DR
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
Build an environment that an LLM could actually be trained on to get measurably better at
|
| 73 |
+
something interesting. Then show that training. Then tell the story.
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
A messy but ambitious environment with real training evidence beats a polished but boring one.
|
| 77 |
+
Pick a problem that excites you (that energy comes through in the pitch).
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
Judging Criteria
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
Criterion: Environment Innovation
|
| 84 |
+
Weight: 40%
|
| 85 |
+
What it means:
|
| 86 |
+
Is the environment novel, creative, or genuinely challenging?
|
| 87 |
+
Does it meaningfully test agent behavior in a way that hasn't been done before?
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
Criterion: Storytelling & Presentation
|
| 91 |
+
Weight: 30%
|
| 92 |
+
What it means:
|
| 93 |
+
Can you clearly explain the problem, the environment, and what the agent learned?
|
| 94 |
+
Is the demo engaging and easy to follow for a non-technical audience?
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
Criterion: Showing Improvement in Rewards
|
| 98 |
+
Weight: 20%
|
| 99 |
+
What it means:
|
| 100 |
+
Is there observable evidence of training progress? Reward curves, before/after behavior,
|
| 101 |
+
comparison against a baseline -- anything that proves the agent learned something.
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
Criterion: Reward & Training Pipeline
|
| 105 |
+
Weight: 10%
|
| 106 |
+
What it means:
|
| 107 |
+
Is the reward logic coherent? Does the pipeline produce meaningful improvement in the trained
|
| 108 |
+
agent's behavior?
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
Minimum Submission Requirements
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
NOTE: These are non-negotiable. Submissions missing any of these are at a serious disadvantage.
|
| 115 |
+
* Use OpenEnv (latest release). Build on top of the framework; don’t reinvent the wheel.
|
| 116 |
+
* A working training script using Unsloth or Hugging Face TRL, ideally as a Colab notebook so judges can re-run it.
|
| 117 |
+
* Evidence that you actually trained; at minimum, loss and reward plots from a real run.
|
| 118 |
+
* A short writeup: a mini-blog on Hugging Face or a < 2 minute video on YouTube explaining what your environment does and what you trained, or a short slide deck of presentation. Please make sure that all materials are linked from your README file so that judges can access them easily.
|
| 119 |
+
* Push your environment to a Hugging Face Space so it’s discoverable and runnable.
|
| 120 |
+
* A README that motivates the problem, explains how the env works, and shows results.
|
| 121 |
+
* README should have a link to the environment in the Hugging Face Space. It should also have all additional references to other materials (e.g. videos, blog posts, slides, presentations, etc.) that you want to include.
|
| 122 |
+
* Please do not include big video files in your Env submission on HF Hub as we would like to have a small size for each env (Please use url as reference link to additional materials).
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
What Makes a Submission Stand Out
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
Pick an ambitious, original problem
|
| 129 |
+
The themes (problems) are deliberately open. Use them as launching pads, not boxes. Judges have seen a lot of chess, snake, tic-tac-toe, and grid-world clones. To score well on innovation,
|
| 130 |
+
you need a genuinely fresh angle. Some questions to ask yourself:
|
| 131 |
+
* Does this environment exist to teach an LLM something it currently can’t do well?
|
| 132 |
+
* Is the domain underexplored in RL/LLM training?
|
| 133 |
+
* Could a researcher write a paper about training on this?
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
Design a reward signal that actually teaches
|
| 137 |
+
A great environment has a reward function that:
|
| 138 |
+
* Provides a rich, informative signal (not just 0/1 at the end)
|
| 139 |
+
* Captures something hard to measure in a clever way
|
| 140 |
+
* Uses OpenEnv’s Rubric system thoughtfully (composable rubrics > monolithic scoring)
|
| 141 |
+
* Is hard to game; an agent that exploits the reward without solving the task should not get high scores
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
Show real training, end to end
|
| 145 |
+
The bar isn’t “training script exists.” The bar is “training script runs against the environment, the
|
| 146 |
+
agent learns, and you can show it.” Concretely:
|
| 147 |
+
* Your training loop should connect to your environment (not a static dataset)
|
| 148 |
+
* Train long enough that the curves mean something
|
| 149 |
+
* Compare a trained agent vs. a random/untrained baseline; quantitative and/or qualitative
|
| 150 |
+
* Include the plots and numbers in your README and writeup
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
Make your plots readable
|
| 154 |
+
Reviewers spend seconds, not minutes, on each plot. Help them out:
|
| 155 |
+
* Label both axes (e.g. “training step” / “episode” on x, “reward” / “loss” on y) and include units where they apply
|
| 156 |
+
* Save plots as .png or .jpg and commit them to the repo (don’t leave them only in a Colab cell or a deleted Wandb run) (if you ran via Wandb, please include the link to that specific run of your plots)
|
| 157 |
+
* Embed the key plots in your README with a one-line caption explaining what each one shows If you have multiple runs (baseline vs. trained, ablations, etc.), put them on the same axes so the comparison is obvious
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
Tell a story, not an API doc
|
| 161 |
+
Your README, blog, and pitch should answer:
|
| 162 |
+
1. Problem) what capability gap or interesting domain are you targeting?
|
| 163 |
+
2. Environment) what does the agent see, do, and get rewarded for?
|
| 164 |
+
3. Results) what changed after training? Show it.
|
| 165 |
+
4. Why does it matter) who would care, and why?
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
A reviewer should be able to read your README in 3~5 minutes and want to try your
|
| 169 |
+
environment.
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
NOTE: If you have a video, HF post, or anything else interesting, please make sure that it’s linked
|
| 173 |
+
from your README as a link.
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
Engineer it cleanly (table stakes)
|
| 177 |
+
Engineering quality matters less than ambition, but sloppy work hurts. Make sure you:
|
| 178 |
+
* Use OpenEnv’s Environment / MCPEnvironment base classes properly
|
| 179 |
+
* Respect the client / server separation (clients should never import server internals)
|
| 180 |
+
* Follow the standard Gym-style API (reset, step, state)
|
| 181 |
+
* Have a valid openenv.yaml manifest
|
| 182 |
+
* Don’t use reserved tool names (reset, step, state, close) for MCP tools
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
Final Note
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
Judges are looking for environments that push the frontier of what we can train LLMs to do. Be
|
| 189 |
+
ambitious. Pick a problem you find genuinely interesting; that almost always produces better
|
| 190 |
+
work than chasing what you think judges want. Good luck.
|
[External] Meta OpenEnv Hackathon Participant Help Guide.txt
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Hackathon Self-Serve Guide: Build an RL Environment, Train an LLM, Ship a Demo
|
| 2 |
+
0) What you are building
|
| 3 |
+
The core idea is not just to fine-tune a text model, but to build a specialized LLM system that can act inside an environment, get feedback, and improve through reinforcement learning. The practical stack discussed here is:
|
| 4 |
+
Environment → verifier/reward functions → TRL trainer → Unsloth for efficiency → deployment on OpenEnv / Spaces.
|
| 5 |
+
A strong project usually looks like one of these,
|
| 6 |
+
Please refer to [External] Apr ‘26 OpenEnv Hackathon Themes for theme guidelines on selecting & forming problem statements.
|
| 7 |
+
1) Start with the right project idea
|
| 8 |
+
Pick a task that has all three of these properties:
|
| 9 |
+
1. The model can act step by step
|
| 10 |
+
2. You can verify success programmatically
|
| 11 |
+
3. The task is hard enough to be interesting, but not so hard that the model never succeeds
|
| 12 |
+
This last point matters a lot. RL only works if the probability of getting a good answer is greater than zero. If your task is so hard that the model never gets any reward, you will burn compute and learn nothing.
|
| 13 |
+
Please refer to [External] Apr ‘26 OpenEnv Hackathon Themes for theme guidelines on selecting & forming problem statements.
|
| 14 |
+
A useful rule: prefer tasks with crisp verification over tasks that only “look good” to a human. RL gets easier when the reward is objective.
|
| 15 |
+
2) Understand the minimum RL loop before you build
|
| 16 |
+
At a high level, your loop is:
|
| 17 |
+
1. Give the model a prompt
|
| 18 |
+
2. Let it generate an action, strategy, answer, or code
|
| 19 |
+
3. Execute that output in an environment or verifier
|
| 20 |
+
4. Convert the result into a reward
|
| 21 |
+
5. Update the model so higher-reward behavior becomes more likely
|
| 22 |
+
That is the practical mental model for RL here. The system samples many outputs, scores them, and shifts probability mass away from bad outputs and toward better ones.
|
| 23 |
+
One especially useful framing is that RL is like a more efficient version of repeated in-context improvement. Instead of repeatedly stuffing previous examples into the context, you let backpropagation store what worked into the weights.
|
| 24 |
+
3) Decide whether you need SFT first
|
| 25 |
+
Use this simple rule:
|
| 26 |
+
* If you have a lot of good data, use SFT
|
| 27 |
+
* If you do not have data but can verify outputs, use RL
|
| 28 |
+
* In many practical cases, do a little SFT first, then RL
|
| 29 |
+
Why this matters:
|
| 30 |
+
* SFT is generally more sample-efficient
|
| 31 |
+
* RL is useful when you can test outcomes but cannot cheaply author ideal traces
|
| 32 |
+
* RL often needs some warm start, formatting priming, or easy tasks first so that good rollouts happen at all
|
| 33 |
+
For hackathon teams, the best path is usually:
|
| 34 |
+
1. Start from a capable base/instruct model
|
| 35 |
+
2. Add light formatting or task scaffolding if needed
|
| 36 |
+
3. Use RL for improvement, not as magic from scratch
|
| 37 |
+
4) Design the environment before you design the trainer
|
| 38 |
+
Treat the environment as a first-class artifact. It should define:
|
| 39 |
+
* reset(): start a fresh episode
|
| 40 |
+
* step(action): apply an action and return the next result
|
| 41 |
+
* state() / observation: what the agent sees
|
| 42 |
+
* reward: what counts as progress or success
|
| 43 |
+
OpenEnv standardizes this so the same training code can work across many environments, instead of every team inventing a different API. That is one of the main reasons to use it in a hackathon.
|
| 44 |
+
Think about your environment in this order:
|
| 45 |
+
1. What does the agent observe?
|
| 46 |
+
2. What actions can it take?
|
| 47 |
+
3. What ends an episode?
|
| 48 |
+
4. How do you compute reward?
|
| 49 |
+
5. How do you stop abuse, infinite loops, or cheating?
|
| 50 |
+
5) Build the environment using OpenEnv
|
| 51 |
+
The intended workflow is to bootstrap an environment skeleton and then fill in the behavior. OpenEnv’s CLI creates the scaffolding for you. The environment is implemented as a Python package and exposed via a FastAPI app.
|
| 52 |
+
Your implementation typically defines:
|
| 53 |
+
* action dataclass
|
| 54 |
+
* observation dataclass
|
| 55 |
+
* state representation
|
| 56 |
+
* environment methods like reset and step
|
| 57 |
+
* FastAPI wrapper / client-server interface
|
| 58 |
+
That gives you a clean separation:
|
| 59 |
+
* the environment handles world dynamics and scoring,
|
| 60 |
+
* the trainer handles optimization,
|
| 61 |
+
* and the model just learns to act inside the interface.
|
| 62 |
+
6) Keep the task simple at first
|
| 63 |
+
Do not begin with your hardest benchmark. Start with the easiest version of your environment that still proves the concept. This is where curriculum learning helps.
|
| 64 |
+
A good progression:
|
| 65 |
+
1. easy tasks with short horizons,
|
| 66 |
+
2. medium tasks with a little more branching,
|
| 67 |
+
3. harder tasks only after the model starts getting non-zero reward.
|
| 68 |
+
The principle is simple: make success possible early. If the model never sees successful trajectories, learning stalls.
|
| 69 |
+
7) Design rewards carefully
|
| 70 |
+
Your reward function is your task specification. If it is weak, incomplete, or easy to exploit, the model will optimize the wrong thing very efficiently.
|
| 71 |
+
A strong reward design usually includes multiple components, for example:
|
| 72 |
+
* execution success,
|
| 73 |
+
* correctness,
|
| 74 |
+
* format compliance,
|
| 75 |
+
* timeouts,
|
| 76 |
+
* resource usage,
|
| 77 |
+
* safety constraints,
|
| 78 |
+
* and anti-cheating checks.
|
| 79 |
+
One explicit recommendation was to use multiple independent reward functions, not just one. If you only have a single reward signal, it is easier for the model to hack it. Multiple independent checks reduce that risk.
|
| 80 |
+
For example, for a coding environment:
|
| 81 |
+
* reward passing tests,
|
| 82 |
+
* penalize timeouts,
|
| 83 |
+
* reward format compliance,
|
| 84 |
+
* reject use of forbidden globals,
|
| 85 |
+
* and separately verify the function contract.
|
| 86 |
+
8) Protect yourself against reward hacking
|
| 87 |
+
Reward hacking is one of the biggest practical failure modes. The model may learn shortcuts that maximize your reward without solving the real task. Examples mentioned include:
|
| 88 |
+
* editing timers,
|
| 89 |
+
* caching results,
|
| 90 |
+
* abusing globals,
|
| 91 |
+
* mutating protected state,
|
| 92 |
+
* or exploiting environment bugs.
|
| 93 |
+
What to do:
|
| 94 |
+
1. Use multiple independent reward functions
|
| 95 |
+
2. Lock down execution where possible
|
| 96 |
+
3. Add time limits
|
| 97 |
+
4. Avoid unrestricted global state
|
| 98 |
+
5. Sample outputs frequently and inspect them
|
| 99 |
+
6. Terminate or roll back runs if behavior drifts badly
|
| 100 |
+
A particularly practical recommendation was to use a locked-down function or restricted execution approach so the model cannot rely on undeclared globals or hidden cached state.
|
| 101 |
+
Also, do not just let training run forever without checking generations. Periodic human inspection is still necessary.
|
| 102 |
+
9) Use process-aware feedback when you can
|
| 103 |
+
Naively assigning the same final reward to every token is inefficient. If possible, use richer supervision that distinguishes good intermediate steps from bad ones. That is the idea behind process supervision.
|
| 104 |
+
In practice, this can be approximated by:
|
| 105 |
+
* line-by-line checks,
|
| 106 |
+
* step-level verifiers,
|
| 107 |
+
* program trace analysis,
|
| 108 |
+
* or LLM-as-a-judge for intermediate reasoning.
|
| 109 |
+
But be careful: LLM-as-a-judge can itself be gamed. Use it as one signal, not the only signal.
|
| 110 |
+
For a hackathon, outcome-based verification plus a few lightweight process checks is usually the sweet spot.
|
| 111 |
+
10) Pick the right training stack
|
| 112 |
+
The intended stack here is:
|
| 113 |
+
* TRL for RL training algorithms
|
| 114 |
+
* Unsloth to make RL training and inference more efficient
|
| 115 |
+
* OpenEnv to standardize environment interaction
|
| 116 |
+
This combination works because:
|
| 117 |
+
* OpenEnv gives you a common environment interface
|
| 118 |
+
* TRL gives you RL trainers like GRPO
|
| 119 |
+
* Unsloth reduces memory use and improves efficiency on top of TRL
|
| 120 |
+
One of the practical examples used the same prompt repeated many times, routed through an environment, with TRL driving training and Unsloth helping with performance.
|
| 121 |
+
11) Prefer GRPO / RLVR style training for verifiable tasks
|
| 122 |
+
The RL setup discussed here leans toward RL with verifiable rewards:
|
| 123 |
+
* instead of a learned reward model,
|
| 124 |
+
* use a verifier, test harness, regex check, executor, or environment.
|
| 125 |
+
GRPO was described as a more efficient evolution relative to older PPO-style setups, especially by simplifying away parts like the value model.
|
| 126 |
+
For hackathon purposes, the key practical takeaway is:
|
| 127 |
+
* if the task is verifiable,
|
| 128 |
+
* build the verifier first,
|
| 129 |
+
* then plug that verifier into RL training.
|
| 130 |
+
12) Keep inference fast
|
| 131 |
+
One important point: in RL for LLMs, inference can dominate total runtime. Over time, rollout generation often becomes the bottleneck, not the optimizer step.
|
| 132 |
+
That means your project speed depends heavily on:
|
| 133 |
+
* fast sampling,
|
| 134 |
+
* tight environment loops,
|
| 135 |
+
* low-overhead execution,
|
| 136 |
+
* and efficient model runtime.
|
| 137 |
+
This is one reason Unsloth matters in the stack, and another reason to avoid overly heavy environments early in the hackathon.
|
| 138 |
+
13) Deploy your environment early
|
| 139 |
+
OpenEnv environments are designed to be deployed as Hugging Face Spaces, which provide:
|
| 140 |
+
* a running server,
|
| 141 |
+
* a Git repository,
|
| 142 |
+
* and a container registry.
|
| 143 |
+
That gives you several ways to work:
|
| 144 |
+
* interact with the remote Space directly,
|
| 145 |
+
* install the client code from the repo,
|
| 146 |
+
* pull and run the container locally,
|
| 147 |
+
* or run the FastAPI app locally via Python/Uvicorn.
|
| 148 |
+
Why this is good for a hackathon:
|
| 149 |
+
* one shared source of truth,
|
| 150 |
+
* easier collaboration,
|
| 151 |
+
* easier demos,
|
| 152 |
+
* easier switching between local and remote execution.
|
| 153 |
+
A good habit is to deploy an early version of the environment before training seriously. That catches API and packaging issues early.
|
| 154 |
+
14) Scale only after the environment is stable
|
| 155 |
+
There was a dedicated tutorial flow around:
|
| 156 |
+
1. environment,
|
| 157 |
+
2. deployment,
|
| 158 |
+
3. scaling,
|
| 159 |
+
4. training with TRL and Wordle.
|
| 160 |
+
Follow the same order.
|
| 161 |
+
Do not start with scale. First confirm:
|
| 162 |
+
* reset works,
|
| 163 |
+
* step works,
|
| 164 |
+
* rewards are sensible,
|
| 165 |
+
* timeouts work,
|
| 166 |
+
* logs are visible,
|
| 167 |
+
* and the environment can be run locally and remotely.
|
| 168 |
+
Only then:
|
| 169 |
+
* increase batch sizes,
|
| 170 |
+
* duplicate prompts or tasks,
|
| 171 |
+
* expand task diversity,
|
| 172 |
+
* and benchmark throughput.
|
| 173 |
+
15) Monitor the right things during training
|
| 174 |
+
Do not watch only one scalar. Monitor:
|
| 175 |
+
* overall reward,
|
| 176 |
+
* individual reward function columns,
|
| 177 |
+
* success indicators,
|
| 178 |
+
* timeout frequency,
|
| 179 |
+
* and generated strategies over time.
|
| 180 |
+
A very concrete suggestion was:
|
| 181 |
+
* watch whether the reward is going up,
|
| 182 |
+
* and separately watch critical columns like “function works.”
|
| 183 |
+
Also inspect actual generations during training. A rising reward is not enough if the model is learning to exploit bugs.
|
| 184 |
+
16) Save models correctly
|
| 185 |
+
If you use QLoRA / LoRA-style training, be careful when saving. One explicit warning was:
|
| 186 |
+
Do not upcast a 4-bit model to 16-bit and then merge the LoRA weights naively. That can badly damage model quality. Instead, use the proper merged-save path, or use the adapters directly.
|
| 187 |
+
For participants, that means:
|
| 188 |
+
* keep your training save path simple,
|
| 189 |
+
* test post-training inference immediately,
|
| 190 |
+
* and do not leave export until the end.
|
| 191 |
+
17) How to structure your team over the hackathon
|
| 192 |
+
A very effective team split is:
|
| 193 |
+
Person A: Environment
|
| 194 |
+
* builds reset/step/state
|
| 195 |
+
* adds timeouts and safety constraints
|
| 196 |
+
* makes local and remote execution work
|
| 197 |
+
Person B: Verifier / Rewards
|
| 198 |
+
* writes multiple reward functions
|
| 199 |
+
* adds anti-hacking checks
|
| 200 |
+
* makes failure cases visible
|
| 201 |
+
Person C: Training
|
| 202 |
+
* sets up TRL + Unsloth
|
| 203 |
+
* runs experiments
|
| 204 |
+
* tracks metrics and generations
|
| 205 |
+
Person D: Demo / Product
|
| 206 |
+
* prepares the Space demo
|
| 207 |
+
* creates a simple interface
|
| 208 |
+
* records examples and final benchmarks
|
| 209 |
+
This split matches the way the stack naturally decomposes in practice.
|
| 210 |
+
18) A practical 1-day execution plan
|
| 211 |
+
Phase 1: Pick a narrow task
|
| 212 |
+
Choose a small, verifiable environment. Avoid huge long-horizon tasks first.
|
| 213 |
+
Phase 2: Build the environment
|
| 214 |
+
Use OpenEnv init, implement reset/step/state, and get a local loop working.
|
| 215 |
+
Phase 3: Build rewards
|
| 216 |
+
Add at least 2–4 independent reward checks, plus timeout and anti-cheat logic.
|
| 217 |
+
Phase 4: Deploy
|
| 218 |
+
Push to a Space or run locally via container/Uvicorn so teammates can use the same environment.
|
| 219 |
+
Phase 5: Train small
|
| 220 |
+
Run a tiny TRL + Unsloth experiment first. Look at outputs, not just metrics.
|
| 221 |
+
Phase 6: Inspect for hacking
|
| 222 |
+
Sample generations. Check for globals, hacks, environment abuse, or suspicious shortcuts.
|
| 223 |
+
Phase 7: Add curriculum
|
| 224 |
+
If the model gets zero reward too often, simplify tasks or add easier start states.
|
| 225 |
+
Phase 8: Train bigger
|
| 226 |
+
Only after the loop is stable should you increase scale, batch size, or environment diversity.
|
| 227 |
+
Phase 9: Save and demo
|
| 228 |
+
Export the trained model correctly, test inference, and show before/after behavior.
|
| 229 |
+
19) What judges or reviewers will likely find compelling
|
| 230 |
+
The strongest hackathon projects usually show:
|
| 231 |
+
* a clear environment design,
|
| 232 |
+
* objective reward functions,
|
| 233 |
+
* evidence that the model improved,
|
| 234 |
+
* prevention against reward hacking,
|
| 235 |
+
* a reproducible deployment story,
|
| 236 |
+
* and a sharp demo.
|
| 237 |
+
A simple but strong demo format is:
|
| 238 |
+
1. baseline model attempt,
|
| 239 |
+
2. reward/verifier output,
|
| 240 |
+
3. trained model attempt,
|
| 241 |
+
4. measurable improvement,
|
| 242 |
+
5. short explanation of safeguards.
|
| 243 |
+
20) Suggested problem statement theme directions
|
| 244 |
+
Please Refer to [External] Apr ‘26 OpenEnv Hackathon Themes
|
| 245 |
+
21) Common mistakes to avoid
|
| 246 |
+
* Picking a task so hard that success probability is zero
|
| 247 |
+
* Using only one reward function
|
| 248 |
+
* Not checking for reward hacking
|
| 249 |
+
* Training before the environment is stable
|
| 250 |
+
* Relying only on average reward and not inspecting outputs
|
| 251 |
+
* Forgetting timeouts and sandbox limits
|
| 252 |
+
* Saving LoRA/QLoRA models incorrectly
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
22) Learning Resources
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
(Recommended) RL Environment Lecture Chapters:
|
| 259 |
+
RL Mega Lecture
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
Module 1: Why OpenEnv? (~7 min)
|
| 265 |
+
▸ Workshop 8:02–15:05 — https://www.youtube.com/watch?v=1jU05MlENOI&t=482s
|
| 266 |
+
▸ Sanyam: RL loop, fragmented env APIs, OpenEnv as universal interface, Gymnasium spec + Docker
|
| 267 |
+
▸ Alt: Mega Lecture 40:01–46:00 — https://www.youtube.com/watch?v=Jew4lhAiqnw&t=2401s
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
Module 2: Using Existing Envs (~7.5 min)
|
| 271 |
+
▸ Workshop 35:33–43:05 — https://www.youtube.com/watch?v=1jU05MlENOI&t=2133s
|
| 272 |
+
▸ Ben: Hub org, env collections, 3 Space interfaces (server/repo/registry), from_hub
|
| 273 |
+
▸ Alt: Mega Lecture 1:24:11–1:30:00 — https://www.youtube.com/watch?v=Jew4lhAiqnw&t=5051s
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
Module 3: Deploying Envs (~9 min)
|
| 277 |
+
▸ Mega Lecture 1:30:00–1:39:07 — https://www.youtube.com/watch?v=Jew4lhAiqnw&t=5400s
|
| 278 |
+
▸ Ben: live openenv init, scaffold, running locally, openenv push, Docker run from Space
|
| 279 |
+
▸ Alt: Workshop 43:05–48:30 — https://www.youtube.com/watch?v=1jU05MlENOI&t=2585s
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
Module 4: Building Your Own (~6.5 min)
|
| 283 |
+
▸ Workshop 43:45–50:20 — https://www.youtube.com/watch?v=1jU05MlENOI&t=2625s
|
| 284 |
+
▸ Ben: scaffold files, business logic (reset/step), models, client, publishing
|
| 285 |
+
▸ Alt: Mega Lecture 1:33:30–1:39:07 — https://www.youtube.com/watch?v=Jew4lhAiqnw&t=5610s
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
Module 5: Training + TRL (~14 min)
|
| 289 |
+
▸ Mega Lecture 1:53:20–2:07:12 — https://www.youtube.com/watch?v=Jew4lhAiqnw&t=6800s
|
| 290 |
+
▸ Lewis: Wordle GRPO walkthrough — rollout function, reward shaping, GRPOTrainer, live training
|
| 291 |
+
▸ Alt: Workshop 22:24–34:12 — https://www.youtube.com/watch?v=1jU05MlENOI&t=1344s
|
client.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""WebSocket client for CERNenv.
|
| 2 |
+
|
| 3 |
+
Wraps OpenEnv's ``EnvClient`` so users can ``await client.reset()`` and
|
| 4 |
+
``await client.step(action)`` against a running CERNenv server.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from typing import Any, Dict
|
| 10 |
+
|
| 11 |
+
from openenv.core import EnvClient
|
| 12 |
+
from openenv.core.client_types import StepResult
|
| 13 |
+
|
| 14 |
+
from models import CollisionObservation, ExperimentAction
|
| 15 |
+
from server.environment import CernState
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class CernEnv(EnvClient[ExperimentAction, CollisionObservation, CernState]):
|
| 19 |
+
"""Async WebSocket client for the CERN environment."""
|
| 20 |
+
|
| 21 |
+
def _step_payload(self, action: ExperimentAction) -> Dict[str, Any]:
|
| 22 |
+
return action.model_dump()
|
| 23 |
+
|
| 24 |
+
def _parse_result(self, payload: Dict[str, Any]) -> StepResult[CollisionObservation]:
|
| 25 |
+
obs_data = payload.get("observation", payload)
|
| 26 |
+
observation = CollisionObservation(**obs_data)
|
| 27 |
+
return StepResult(
|
| 28 |
+
observation=observation,
|
| 29 |
+
reward=payload.get("reward", observation.reward),
|
| 30 |
+
done=payload.get("done", observation.done),
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
def _parse_state(self, payload: Dict[str, Any]) -> CernState:
|
| 34 |
+
return CernState(**payload)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
__all__ = ["CernEnv"]
|
models.py
ADDED
|
@@ -0,0 +1,600 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data models for CERNenv: an LHC (Large Hadron Collider) style particle
|
| 3 |
+
physics discovery POMDP (Partially Observable Markov Decision Process).
|
| 4 |
+
|
| 5 |
+
The agent is a Large Language Model (LLM) acting as a high-energy physicist.
|
| 6 |
+
Each step it picks one structured action (configure beams, allocate
|
| 7 |
+
luminosity, run a trigger, fit a spectrum, request systematics, submit a
|
| 8 |
+
discovery claim, etc.) and receives a noisy detector-style observation.
|
| 9 |
+
The latent particle and detector parameters are the hidden ground truth.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
from enum import Enum
|
| 15 |
+
from typing import Any, Dict, List, Optional
|
| 16 |
+
|
| 17 |
+
from pydantic import BaseModel, Field
|
| 18 |
+
|
| 19 |
+
from openenv.core.env_server.types import Action, Observation
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# ── Action vocabulary ───────────────────────────────────────────────────────
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class ActionType(str, Enum):
|
| 26 |
+
# ── Beam & data acquisition (DAQ) ─────────────────────────────────
|
| 27 |
+
CONFIGURE_BEAM = "configure_beam"
|
| 28 |
+
ALLOCATE_LUMINOSITY = "allocate_luminosity"
|
| 29 |
+
SET_TRIGGER = "set_trigger"
|
| 30 |
+
COLLECT_COLLISIONS = "collect_collisions"
|
| 31 |
+
|
| 32 |
+
# ── Reconstruction & calibration ─────────────────────────────────
|
| 33 |
+
CALIBRATE_DETECTOR = "calibrate_detector"
|
| 34 |
+
RECONSTRUCT_TRACKS = "reconstruct_tracks"
|
| 35 |
+
SELECT_CHANNEL = "select_channel"
|
| 36 |
+
|
| 37 |
+
# ── Analysis ──────────────────────────────────────────────────────
|
| 38 |
+
BUILD_INVARIANT_MASS = "build_invariant_mass"
|
| 39 |
+
SUBTRACT_BACKGROUND = "subtract_background"
|
| 40 |
+
FIT_RESONANCE = "fit_resonance"
|
| 41 |
+
SCAN_BUMP = "scan_bump"
|
| 42 |
+
MEASURE_ANGULAR = "measure_angular"
|
| 43 |
+
ESTIMATE_SIGNIFICANCE = "estimate_significance"
|
| 44 |
+
|
| 45 |
+
# ── Systematics & meta ───────────────────────────────────────────
|
| 46 |
+
REQUEST_SYSTEMATICS = "request_systematics"
|
| 47 |
+
REQUEST_THEORY_REVIEW = "request_theory_review"
|
| 48 |
+
|
| 49 |
+
# ── Final ─────────────────────────────────────────────────────────
|
| 50 |
+
SUBMIT_DISCOVERY_CLAIM = "submit_discovery_claim"
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
DAQ_ACTIONS = frozenset({
|
| 54 |
+
ActionType.CONFIGURE_BEAM,
|
| 55 |
+
ActionType.ALLOCATE_LUMINOSITY,
|
| 56 |
+
ActionType.SET_TRIGGER,
|
| 57 |
+
ActionType.COLLECT_COLLISIONS,
|
| 58 |
+
})
|
| 59 |
+
|
| 60 |
+
RECO_ACTIONS = frozenset({
|
| 61 |
+
ActionType.CALIBRATE_DETECTOR,
|
| 62 |
+
ActionType.RECONSTRUCT_TRACKS,
|
| 63 |
+
ActionType.SELECT_CHANNEL,
|
| 64 |
+
})
|
| 65 |
+
|
| 66 |
+
ANALYSIS_ACTIONS = frozenset({
|
| 67 |
+
ActionType.BUILD_INVARIANT_MASS,
|
| 68 |
+
ActionType.SUBTRACT_BACKGROUND,
|
| 69 |
+
ActionType.FIT_RESONANCE,
|
| 70 |
+
ActionType.SCAN_BUMP,
|
| 71 |
+
ActionType.MEASURE_ANGULAR,
|
| 72 |
+
ActionType.ESTIMATE_SIGNIFICANCE,
|
| 73 |
+
})
|
| 74 |
+
|
| 75 |
+
META_ACTIONS = frozenset({
|
| 76 |
+
ActionType.REQUEST_SYSTEMATICS,
|
| 77 |
+
ActionType.REQUEST_THEORY_REVIEW,
|
| 78 |
+
ActionType.SUBMIT_DISCOVERY_CLAIM,
|
| 79 |
+
})
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# ── Detector channels & physics primitives ────────────────────────────────
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class DetectorChannel(str, Enum):
|
| 86 |
+
"""Final-state decay channel the agent reconstructs in.
|
| 87 |
+
|
| 88 |
+
Channels affect signal acceptance and background composition. Picking a
|
| 89 |
+
channel where the true particle does not decay yields low signal yield
|
| 90 |
+
no matter how much luminosity is collected — this is intentional.
|
| 91 |
+
"""
|
| 92 |
+
|
| 93 |
+
DIPHOTON = "diphoton" # γγ
|
| 94 |
+
DILEPTON_EE = "dilepton_ee" # e+ e-
|
| 95 |
+
DILEPTON_MUMU = "dilepton_mumu" # μ+ μ-
|
| 96 |
+
DIJET = "dijet" # jj
|
| 97 |
+
FOUR_LEPTON = "four_lepton" # 4ℓ
|
| 98 |
+
BB = "bb" # b b-bar
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class TriggerType(str, Enum):
|
| 102 |
+
"""Hardware-level event selection."""
|
| 103 |
+
|
| 104 |
+
LOW_PT = "low_pt" # broad acceptance, lots of background
|
| 105 |
+
HIGH_PT = "high_pt" # high-mass focus, lower QCD
|
| 106 |
+
DIPHOTON_HLT = "diphoton_hlt"
|
| 107 |
+
DILEPTON_HLT = "dilepton_hlt"
|
| 108 |
+
JET_HLT = "jet_hlt"
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class BeamEnergy(str, Enum):
|
| 112 |
+
"""LHC-style center-of-mass energies (TeV)."""
|
| 113 |
+
|
| 114 |
+
E_7 = "7TeV"
|
| 115 |
+
E_8 = "8TeV"
|
| 116 |
+
E_13 = "13TeV"
|
| 117 |
+
E_14 = "14TeV"
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# ── Tool / instrument registry (for prompts and tool-fit reward) ──────────
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class ToolCategory(str, Enum):
|
| 124 |
+
DAQ = "daq"
|
| 125 |
+
RECONSTRUCTION = "reconstruction"
|
| 126 |
+
CALIBRATION = "calibration"
|
| 127 |
+
ANALYSIS = "analysis"
|
| 128 |
+
STATISTICS = "statistics"
|
| 129 |
+
SYSTEMATICS = "systematics"
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class ToolSpec(BaseModel):
|
| 133 |
+
name: str
|
| 134 |
+
category: ToolCategory
|
| 135 |
+
description: str = ""
|
| 136 |
+
typical_runtime_hours: float = 0.5
|
| 137 |
+
typical_cost_musd: float = 0.0 # in millions of USD (compute / beam time proxy)
|
| 138 |
+
requires_gpu: bool = False
|
| 139 |
+
channels: List[str] = Field(default_factory=list)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
TOOL_REGISTRY: Dict[str, ToolSpec] = {
|
| 143 |
+
"ATLAS_HLT": ToolSpec(
|
| 144 |
+
name="ATLAS_HLT",
|
| 145 |
+
category=ToolCategory.DAQ,
|
| 146 |
+
description="ATLAS High-Level Trigger system for online event selection",
|
| 147 |
+
typical_runtime_hours=0.0,
|
| 148 |
+
channels=["diphoton", "dilepton_ee", "dilepton_mumu", "four_lepton", "dijet", "bb"],
|
| 149 |
+
),
|
| 150 |
+
"CMS_HLT": ToolSpec(
|
| 151 |
+
name="CMS_HLT",
|
| 152 |
+
category=ToolCategory.DAQ,
|
| 153 |
+
description="CMS High-Level Trigger system",
|
| 154 |
+
typical_runtime_hours=0.0,
|
| 155 |
+
channels=["diphoton", "dilepton_ee", "dilepton_mumu", "four_lepton", "dijet", "bb"],
|
| 156 |
+
),
|
| 157 |
+
"GEANT4": ToolSpec(
|
| 158 |
+
name="GEANT4",
|
| 159 |
+
category=ToolCategory.RECONSTRUCTION,
|
| 160 |
+
description="Detector simulation toolkit for full event reconstruction",
|
| 161 |
+
typical_runtime_hours=1.0,
|
| 162 |
+
typical_cost_musd=0.05,
|
| 163 |
+
requires_gpu=False,
|
| 164 |
+
),
|
| 165 |
+
"Athena": ToolSpec(
|
| 166 |
+
name="Athena",
|
| 167 |
+
category=ToolCategory.RECONSTRUCTION,
|
| 168 |
+
description="ATLAS reconstruction framework",
|
| 169 |
+
typical_runtime_hours=0.8,
|
| 170 |
+
),
|
| 171 |
+
"CMSSW": ToolSpec(
|
| 172 |
+
name="CMSSW",
|
| 173 |
+
category=ToolCategory.RECONSTRUCTION,
|
| 174 |
+
description="CMS reconstruction software",
|
| 175 |
+
typical_runtime_hours=0.8,
|
| 176 |
+
),
|
| 177 |
+
"ECAL_calibration": ToolSpec(
|
| 178 |
+
name="ECAL_calibration",
|
| 179 |
+
category=ToolCategory.CALIBRATION,
|
| 180 |
+
description="Electromagnetic calorimeter energy-scale calibration",
|
| 181 |
+
typical_runtime_hours=0.3,
|
| 182 |
+
),
|
| 183 |
+
"Tracker_alignment": ToolSpec(
|
| 184 |
+
name="Tracker_alignment",
|
| 185 |
+
category=ToolCategory.CALIBRATION,
|
| 186 |
+
description="Inner tracker alignment for momentum precision",
|
| 187 |
+
typical_runtime_hours=0.4,
|
| 188 |
+
),
|
| 189 |
+
"ROOT_RooFit": ToolSpec(
|
| 190 |
+
name="ROOT_RooFit",
|
| 191 |
+
category=ToolCategory.ANALYSIS,
|
| 192 |
+
description="Maximum-likelihood spectrum fitting toolkit",
|
| 193 |
+
typical_runtime_hours=0.2,
|
| 194 |
+
),
|
| 195 |
+
"MadGraph": ToolSpec(
|
| 196 |
+
name="MadGraph",
|
| 197 |
+
category=ToolCategory.ANALYSIS,
|
| 198 |
+
description="Matrix-element generator for signal+background templates",
|
| 199 |
+
typical_runtime_hours=1.5,
|
| 200 |
+
typical_cost_musd=0.02,
|
| 201 |
+
),
|
| 202 |
+
"Pythia8": ToolSpec(
|
| 203 |
+
name="Pythia8",
|
| 204 |
+
category=ToolCategory.ANALYSIS,
|
| 205 |
+
description="Parton-shower and hadronisation generator",
|
| 206 |
+
typical_runtime_hours=0.5,
|
| 207 |
+
),
|
| 208 |
+
"BumpHunter": ToolSpec(
|
| 209 |
+
name="BumpHunter",
|
| 210 |
+
category=ToolCategory.STATISTICS,
|
| 211 |
+
description="Sliding-window local-significance bump-hunting algorithm",
|
| 212 |
+
typical_runtime_hours=0.1,
|
| 213 |
+
),
|
| 214 |
+
"CLs_fit": ToolSpec(
|
| 215 |
+
name="CLs_fit",
|
| 216 |
+
category=ToolCategory.STATISTICS,
|
| 217 |
+
description="Modified-frequentist CLs limits and significance",
|
| 218 |
+
typical_runtime_hours=0.1,
|
| 219 |
+
),
|
| 220 |
+
"Asimov_significance": ToolSpec(
|
| 221 |
+
name="Asimov_significance",
|
| 222 |
+
category=ToolCategory.STATISTICS,
|
| 223 |
+
description="Asymptotic significance from Asimov dataset",
|
| 224 |
+
typical_runtime_hours=0.05,
|
| 225 |
+
),
|
| 226 |
+
"JES_systematics": ToolSpec(
|
| 227 |
+
name="JES_systematics",
|
| 228 |
+
category=ToolCategory.SYSTEMATICS,
|
| 229 |
+
description="Jet energy-scale systematic study",
|
| 230 |
+
typical_runtime_hours=0.4,
|
| 231 |
+
),
|
| 232 |
+
"Luminosity_calibration": ToolSpec(
|
| 233 |
+
name="Luminosity_calibration",
|
| 234 |
+
category=ToolCategory.SYSTEMATICS,
|
| 235 |
+
description="Van der Meer scan luminosity calibration",
|
| 236 |
+
typical_runtime_hours=0.3,
|
| 237 |
+
),
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
# ── Action schema ──────────────────────────────────────────────────────────
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class ExperimentAction(Action):
|
| 245 |
+
"""One structured experimental step at the LHC."""
|
| 246 |
+
|
| 247 |
+
action_type: ActionType = Field(
|
| 248 |
+
...,
|
| 249 |
+
description=(
|
| 250 |
+
"Discrete LHC pipeline step. The environment enforces physics "
|
| 251 |
+
"prerequisites: you cannot fit a spectrum before collecting data, "
|
| 252 |
+
"or claim a discovery before estimating significance."
|
| 253 |
+
),
|
| 254 |
+
)
|
| 255 |
+
method: Optional[str] = Field(
|
| 256 |
+
None,
|
| 257 |
+
description=(
|
| 258 |
+
"Optional named instrument or framework (e.g. 'ROOT_RooFit', "
|
| 259 |
+
"'BumpHunter', 'Pythia8'). Affects cost, runtime, and tool-fit reward."
|
| 260 |
+
),
|
| 261 |
+
)
|
| 262 |
+
parameters: Dict[str, Any] = Field(
|
| 263 |
+
default_factory=dict,
|
| 264 |
+
description=(
|
| 265 |
+
"Action-specific settings such as beam energy, integrated luminosity "
|
| 266 |
+
"(fb^-1), trigger selection, decay channel, mass window, fit model."
|
| 267 |
+
),
|
| 268 |
+
)
|
| 269 |
+
justification: Optional[str] = Field(
|
| 270 |
+
None,
|
| 271 |
+
description="Short scientific rationale for picking this step now.",
|
| 272 |
+
)
|
| 273 |
+
confidence: float = Field(
|
| 274 |
+
0.5, ge=0.0, le=1.0,
|
| 275 |
+
description="Agent confidence in the chosen step.",
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
# ── Outputs ────────────────────────────────────────────────────────────────
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
class OutputType(str, Enum):
|
| 283 |
+
BEAM_CONFIG = "beam_config"
|
| 284 |
+
LUMINOSITY_LOG = "luminosity_log"
|
| 285 |
+
TRIGGER_REPORT = "trigger_report"
|
| 286 |
+
COLLISION_BATCH = "collision_batch"
|
| 287 |
+
CALIBRATION_REPORT = "calibration_report"
|
| 288 |
+
RECONSTRUCTION = "reconstruction"
|
| 289 |
+
CHANNEL_SELECTION = "channel_selection"
|
| 290 |
+
INVARIANT_MASS_HIST = "invariant_mass_hist"
|
| 291 |
+
BACKGROUND_SUBTRACTION = "background_subtraction"
|
| 292 |
+
FIT_RESULT = "fit_result"
|
| 293 |
+
BUMP_SCAN = "bump_scan"
|
| 294 |
+
ANGULAR_RESULT = "angular_result"
|
| 295 |
+
SIGNIFICANCE = "significance"
|
| 296 |
+
SYSTEMATICS_REPORT = "systematics_report"
|
| 297 |
+
THEORY_REVIEW = "theory_review"
|
| 298 |
+
DISCOVERY_CLAIM = "discovery_claim"
|
| 299 |
+
FAILURE_REPORT = "failure_report"
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
class IntermediateOutput(BaseModel):
|
| 303 |
+
"""A single noisy detector or analysis artifact."""
|
| 304 |
+
|
| 305 |
+
output_type: OutputType
|
| 306 |
+
step_index: int
|
| 307 |
+
success: bool = True
|
| 308 |
+
quality_score: float = Field(1.0, ge=0.0, le=1.0)
|
| 309 |
+
summary: str = ""
|
| 310 |
+
data: Dict[str, Any] = Field(default_factory=dict)
|
| 311 |
+
uncertainty: float = Field(0.0, ge=0.0, le=1.0)
|
| 312 |
+
warnings: List[str] = Field(default_factory=list)
|
| 313 |
+
artifacts_available: List[str] = Field(default_factory=list)
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
# ── Observable state components ───────────────────────────────────────────
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
class ResourceUsage(BaseModel):
|
| 320 |
+
"""Agent-visible resource counters."""
|
| 321 |
+
|
| 322 |
+
budget_used_musd: float = 0.0
|
| 323 |
+
budget_remaining_musd: float = 100.0
|
| 324 |
+
luminosity_used_fb: float = 0.0
|
| 325 |
+
luminosity_remaining_fb: float = 300.0
|
| 326 |
+
time_used_days: float = 0.0
|
| 327 |
+
time_remaining_days: float = 365.0
|
| 328 |
+
compute_hours_used: float = 0.0
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
class PipelineStepRecord(BaseModel):
|
| 332 |
+
step_index: int
|
| 333 |
+
action_type: ActionType
|
| 334 |
+
method: Optional[str] = None
|
| 335 |
+
parameters: Dict[str, Any] = Field(default_factory=dict)
|
| 336 |
+
output_summary: str = ""
|
| 337 |
+
output_type: OutputType
|
| 338 |
+
success: bool = True
|
| 339 |
+
quality_score: float = 1.0
|
| 340 |
+
cost_musd: float = 0.0
|
| 341 |
+
luminosity_cost_fb: float = 0.0
|
| 342 |
+
time_cost_days: float = 0.0
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
class PaperReference(BaseModel):
|
| 346 |
+
title: str
|
| 347 |
+
citation: Optional[str] = None
|
| 348 |
+
doi: Optional[str] = None
|
| 349 |
+
arxiv_id: Optional[str] = None
|
| 350 |
+
url: Optional[str] = None
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
class ExpectedFinding(BaseModel):
|
| 354 |
+
finding: str
|
| 355 |
+
category: str = "claim"
|
| 356 |
+
keywords: List[str] = Field(default_factory=list)
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
class TaskSpec(BaseModel):
|
| 360 |
+
"""The physics question the agent is given for this episode."""
|
| 361 |
+
|
| 362 |
+
problem_statement: str = "Discover and characterise an unknown resonance."
|
| 363 |
+
target_collider: str = "LHC"
|
| 364 |
+
beam_energy_options: List[str] = Field(
|
| 365 |
+
default_factory=lambda: [e.value for e in BeamEnergy],
|
| 366 |
+
)
|
| 367 |
+
available_channels: List[str] = Field(
|
| 368 |
+
default_factory=lambda: [c.value for c in DetectorChannel],
|
| 369 |
+
)
|
| 370 |
+
available_triggers: List[str] = Field(
|
| 371 |
+
default_factory=lambda: [t.value for t in TriggerType],
|
| 372 |
+
)
|
| 373 |
+
available_tools: List[str] = Field(
|
| 374 |
+
default_factory=lambda: list(TOOL_REGISTRY.keys()),
|
| 375 |
+
)
|
| 376 |
+
mass_search_window_gev: List[float] = Field(default_factory=lambda: [50.0, 1000.0])
|
| 377 |
+
budget_limit_musd: float = 100.0
|
| 378 |
+
luminosity_budget_fb: float = 300.0
|
| 379 |
+
time_limit_days: float = 365.0
|
| 380 |
+
prior_observations: List[str] = Field(default_factory=list)
|
| 381 |
+
success_criteria: List[str] = Field(default_factory=list)
|
| 382 |
+
paper_references: List[PaperReference] = Field(default_factory=list)
|
| 383 |
+
expected_findings: List[ExpectedFinding] = Field(default_factory=list)
|
| 384 |
+
difficulty: str = "medium"
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
class DiscoveryClaim(BaseModel):
|
| 388 |
+
"""Structured final claim graded against hidden truth."""
|
| 389 |
+
|
| 390 |
+
claim: str = ""
|
| 391 |
+
mass_estimate_gev: Optional[float] = None
|
| 392 |
+
mass_uncertainty_gev: Optional[float] = None
|
| 393 |
+
width_estimate_gev: Optional[float] = None
|
| 394 |
+
significance_sigma: Optional[float] = None
|
| 395 |
+
decay_channel: Optional[str] = None
|
| 396 |
+
spin_hypothesis: Optional[int] = None # 0, 1, 2
|
| 397 |
+
parity: Optional[str] = None # "+", "-"
|
| 398 |
+
cross_section_fb: Optional[float] = None
|
| 399 |
+
confidence: float = Field(0.5, ge=0.0, le=1.0)
|
| 400 |
+
evidence_steps: List[int] = Field(default_factory=list)
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
class CollisionObservation(Observation):
|
| 404 |
+
"""Full observable state returned to the agent each step.
|
| 405 |
+
|
| 406 |
+
Excludes the hidden particle truth and hidden detector systematics.
|
| 407 |
+
"""
|
| 408 |
+
|
| 409 |
+
task: TaskSpec = Field(default_factory=TaskSpec)
|
| 410 |
+
step_index: int = 0
|
| 411 |
+
pipeline_history: List[PipelineStepRecord] = Field(default_factory=list)
|
| 412 |
+
available_channels: List[str] = Field(default_factory=list)
|
| 413 |
+
available_triggers: List[str] = Field(default_factory=list)
|
| 414 |
+
available_tools: List[str] = Field(default_factory=list)
|
| 415 |
+
resource_usage: ResourceUsage = Field(default_factory=ResourceUsage)
|
| 416 |
+
latest_output: Optional[IntermediateOutput] = None
|
| 417 |
+
all_outputs: List[IntermediateOutput] = Field(default_factory=list)
|
| 418 |
+
candidate_masses_gev: List[float] = Field(default_factory=list)
|
| 419 |
+
candidate_significances: List[float] = Field(default_factory=list)
|
| 420 |
+
selected_channel: Optional[str] = None
|
| 421 |
+
selected_beam_energy: Optional[str] = None
|
| 422 |
+
cumulative_significance: float = 0.0
|
| 423 |
+
uncertainty_summary: Dict[str, float] = Field(default_factory=dict)
|
| 424 |
+
rule_violations: List[str] = Field(default_factory=list)
|
| 425 |
+
step_reward_breakdown: Dict[str, float] = Field(default_factory=dict)
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
# ── Agent-facing prompt helpers ───────────────────────────────────────────
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
AGENT_ACTION_GUIDANCE: Dict[ActionType, str] = {
|
| 432 |
+
ActionType.CONFIGURE_BEAM: (
|
| 433 |
+
"Pick the LHC center-of-mass energy. Higher energy reaches heavier "
|
| 434 |
+
"resonances but costs more per fb^-1. Required before collecting data."
|
| 435 |
+
),
|
| 436 |
+
ActionType.ALLOCATE_LUMINOSITY: (
|
| 437 |
+
"Schedule a chunk of integrated luminosity (fb^-1). More luminosity "
|
| 438 |
+
"means more events but uses budget and time. Required before collecting."
|
| 439 |
+
),
|
| 440 |
+
ActionType.SET_TRIGGER: (
|
| 441 |
+
"Choose a hardware/HLT trigger. Match the trigger to the channel of "
|
| 442 |
+
"interest; mismatched triggers throw away signal."
|
| 443 |
+
),
|
| 444 |
+
ActionType.COLLECT_COLLISIONS: (
|
| 445 |
+
"Run the experiment. Returns a noisy raw event count plus background "
|
| 446 |
+
"estimate, conditioned on beam, luminosity, trigger, and channel."
|
| 447 |
+
),
|
| 448 |
+
ActionType.CALIBRATE_DETECTOR: (
|
| 449 |
+
"Apply ECAL/tracker calibration. Reduces systematic uncertainty; "
|
| 450 |
+
"neglecting it inflates fit uncertainty later."
|
| 451 |
+
),
|
| 452 |
+
ActionType.RECONSTRUCT_TRACKS: (
|
| 453 |
+
"Reconstruct charged-particle tracks and physics objects. Required "
|
| 454 |
+
"before any analysis-level step."
|
| 455 |
+
),
|
| 456 |
+
ActionType.SELECT_CHANNEL: (
|
| 457 |
+
"Pick the decay channel to study (γγ, ℓℓ, jj, 4ℓ, bb). Wrong channel "
|
| 458 |
+
"= small signal acceptance regardless of luminosity."
|
| 459 |
+
),
|
| 460 |
+
ActionType.BUILD_INVARIANT_MASS: (
|
| 461 |
+
"Construct the invariant-mass histogram in the chosen channel and "
|
| 462 |
+
"mass window."
|
| 463 |
+
),
|
| 464 |
+
ActionType.SUBTRACT_BACKGROUND: (
|
| 465 |
+
"Fit a smooth background model and subtract it to expose any peak."
|
| 466 |
+
),
|
| 467 |
+
ActionType.FIT_RESONANCE: (
|
| 468 |
+
"Fit a Breit-Wigner / Crystal Ball line shape. Returns mass, width, "
|
| 469 |
+
"and statistical uncertainty."
|
| 470 |
+
),
|
| 471 |
+
ActionType.SCAN_BUMP: (
|
| 472 |
+
"Run a sliding-window bump hunt over the mass window. Reports the "
|
| 473 |
+
"most-significant candidate region."
|
| 474 |
+
),
|
| 475 |
+
ActionType.MEASURE_ANGULAR: (
|
| 476 |
+
"Measure decay angular distribution to constrain spin/parity. "
|
| 477 |
+
"Useful only after a peak is identified."
|
| 478 |
+
),
|
| 479 |
+
ActionType.ESTIMATE_SIGNIFICANCE: (
|
| 480 |
+
"Compute the statistical significance of a candidate signal in σ. "
|
| 481 |
+
"Required before claiming a discovery."
|
| 482 |
+
),
|
| 483 |
+
ActionType.REQUEST_SYSTEMATICS: (
|
| 484 |
+
"Run a systematics study (JES, luminosity, calibration). Improves "
|
| 485 |
+
"uncertainty estimates and reduces overconfidence penalty."
|
| 486 |
+
),
|
| 487 |
+
ActionType.REQUEST_THEORY_REVIEW: (
|
| 488 |
+
"Ask a theorist sub-agent to review the evidence; small extra signal "
|
| 489 |
+
"but not a substitute for missing data."
|
| 490 |
+
),
|
| 491 |
+
ActionType.SUBMIT_DISCOVERY_CLAIM: (
|
| 492 |
+
"Submit a structured discovery claim. Graded on mass calibration, "
|
| 493 |
+
"significance, channel, spin hypothesis, and overconfidence."
|
| 494 |
+
),
|
| 495 |
+
}
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
AGENT_ENVIRONMENT_RULES: List[str] = [
|
| 499 |
+
"Each successful action returns summarized evidence; do not repeat steps.",
|
| 500 |
+
"Hard prerequisites are enforced: data collection requires beam+luminosity+trigger; "
|
| 501 |
+
"analysis requires reconstruction and a chosen channel.",
|
| 502 |
+
"A discovery claim requires a fitted resonance and an estimated significance.",
|
| 503 |
+
"Tools listed in available_tools are pre-filtered for this episode; prefer them.",
|
| 504 |
+
"Submitting an overconfident wrong claim is heavily penalised.",
|
| 505 |
+
]
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
def build_agent_system_prompt() -> str:
|
| 509 |
+
lines = [
|
| 510 |
+
"You are an expert high-energy physicist running an analysis at the LHC.",
|
| 511 |
+
"",
|
| 512 |
+
"At each turn you observe the experiment state and pick one structured next step",
|
| 513 |
+
"to maximise the probability of correctly characterising a hidden resonance.",
|
| 514 |
+
"",
|
| 515 |
+
"Environment rules:",
|
| 516 |
+
]
|
| 517 |
+
lines.extend(f" - {rule}" for rule in AGENT_ENVIRONMENT_RULES)
|
| 518 |
+
lines.append("")
|
| 519 |
+
lines.append("Action guidance:")
|
| 520 |
+
lines.extend(
|
| 521 |
+
f" - {a.value}: {AGENT_ACTION_GUIDANCE[a]}" for a in ActionType
|
| 522 |
+
)
|
| 523 |
+
lines.extend([
|
| 524 |
+
"",
|
| 525 |
+
"Respond with ONLY a single valid JSON object, no extra prose:",
|
| 526 |
+
'{"action_type": "...", "method": null, "parameters": {}, "justification": "...", "confidence": 0.8}',
|
| 527 |
+
"",
|
| 528 |
+
"For submit_discovery_claim, structure parameters['claim'] as:",
|
| 529 |
+
'{"mass_estimate_gev": 125.0, "mass_uncertainty_gev": 0.5, "width_estimate_gev": 0.004,'
|
| 530 |
+
' "significance_sigma": 5.2, "decay_channel": "diphoton", "spin_hypothesis": 0,'
|
| 531 |
+
' "parity": "+", "cross_section_fb": 50.0, "confidence": 0.9}',
|
| 532 |
+
])
|
| 533 |
+
return "\n".join(lines)
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
def build_agent_observation_context(
|
| 537 |
+
obs: CollisionObservation,
|
| 538 |
+
*,
|
| 539 |
+
max_tools: int = 6,
|
| 540 |
+
max_channels: int = 4,
|
| 541 |
+
) -> str:
|
| 542 |
+
parts: List[str] = []
|
| 543 |
+
|
| 544 |
+
parts.append(
|
| 545 |
+
f"Mass search window: [{obs.task.mass_search_window_gev[0]:.0f}, "
|
| 546 |
+
f"{obs.task.mass_search_window_gev[1]:.0f}] GeV; "
|
| 547 |
+
f"difficulty={obs.task.difficulty}."
|
| 548 |
+
)
|
| 549 |
+
|
| 550 |
+
chans = list(dict.fromkeys(obs.available_channels or obs.task.available_channels))
|
| 551 |
+
if chans:
|
| 552 |
+
parts.append("Available channels: " + ", ".join(chans[:max_channels]))
|
| 553 |
+
|
| 554 |
+
tools = list(dict.fromkeys(obs.available_tools or obs.task.available_tools))
|
| 555 |
+
if tools:
|
| 556 |
+
parts.append("Available tools: " + ", ".join(tools[:max_tools]))
|
| 557 |
+
|
| 558 |
+
if obs.selected_channel:
|
| 559 |
+
parts.append(f"Selected channel: {obs.selected_channel}")
|
| 560 |
+
if obs.selected_beam_energy:
|
| 561 |
+
parts.append(f"Beam energy: {obs.selected_beam_energy}")
|
| 562 |
+
|
| 563 |
+
if obs.candidate_masses_gev:
|
| 564 |
+
masses = [f"{m:.1f}" for m in obs.candidate_masses_gev[:3]]
|
| 565 |
+
sigmas = [f"{s:.1f}" for s in obs.candidate_significances[:3]]
|
| 566 |
+
parts.append(
|
| 567 |
+
"Candidate peaks (GeV / σ): "
|
| 568 |
+
+ ", ".join(f"{m}/{s}" for m, s in zip(masses, sigmas))
|
| 569 |
+
)
|
| 570 |
+
|
| 571 |
+
return "\n".join(parts)
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
__all__ = [
|
| 575 |
+
"ActionType",
|
| 576 |
+
"DAQ_ACTIONS",
|
| 577 |
+
"RECO_ACTIONS",
|
| 578 |
+
"ANALYSIS_ACTIONS",
|
| 579 |
+
"META_ACTIONS",
|
| 580 |
+
"DetectorChannel",
|
| 581 |
+
"TriggerType",
|
| 582 |
+
"BeamEnergy",
|
| 583 |
+
"ToolCategory",
|
| 584 |
+
"ToolSpec",
|
| 585 |
+
"TOOL_REGISTRY",
|
| 586 |
+
"ExperimentAction",
|
| 587 |
+
"OutputType",
|
| 588 |
+
"IntermediateOutput",
|
| 589 |
+
"ResourceUsage",
|
| 590 |
+
"PipelineStepRecord",
|
| 591 |
+
"PaperReference",
|
| 592 |
+
"ExpectedFinding",
|
| 593 |
+
"TaskSpec",
|
| 594 |
+
"DiscoveryClaim",
|
| 595 |
+
"CollisionObservation",
|
| 596 |
+
"AGENT_ACTION_GUIDANCE",
|
| 597 |
+
"AGENT_ENVIRONMENT_RULES",
|
| 598 |
+
"build_agent_system_prompt",
|
| 599 |
+
"build_agent_observation_context",
|
| 600 |
+
]
|
openenv.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: cernenv
|
| 3 |
+
type: space
|
| 4 |
+
runtime: fastapi
|
| 5 |
+
app: server.app:app
|
| 6 |
+
port: 8000
|
pyproject.toml
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=45", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "openenv-cernenv"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "RL environment for autonomous particle physics agents at the LHC"
|
| 9 |
+
requires-python = ">=3.10,<3.13"
|
| 10 |
+
dependencies = [
|
| 11 |
+
"openenv-core[core]>=0.2.0",
|
| 12 |
+
"numpy>=1.24.0",
|
| 13 |
+
"scipy>=1.10.0",
|
| 14 |
+
"pydantic>=2.0.0",
|
| 15 |
+
]
|
| 16 |
+
|
| 17 |
+
[project.optional-dependencies]
|
| 18 |
+
dev = [
|
| 19 |
+
"pytest>=8.0.0",
|
| 20 |
+
"pytest-cov>=4.0.0",
|
| 21 |
+
]
|
| 22 |
+
train = [
|
| 23 |
+
"accelerate>=1.0.0",
|
| 24 |
+
"datasets>=2.18.0",
|
| 25 |
+
"matplotlib>=3.8.0",
|
| 26 |
+
"peft>=0.10.0",
|
| 27 |
+
"torch>=2.2.0",
|
| 28 |
+
"transformers>=4.44.0",
|
| 29 |
+
"trl>=0.9.0",
|
| 30 |
+
]
|
| 31 |
+
|
| 32 |
+
[project.scripts]
|
| 33 |
+
cernenv-server = "server.app:main"
|
| 34 |
+
|
| 35 |
+
[tool.uv]
|
| 36 |
+
package = false
|
| 37 |
+
|
| 38 |
+
[tool.setuptools]
|
| 39 |
+
include-package-data = true
|
| 40 |
+
packages = [
|
| 41 |
+
"cernenv",
|
| 42 |
+
"cernenv.server",
|
| 43 |
+
"cernenv.server.simulator",
|
| 44 |
+
"cernenv.server.rules",
|
| 45 |
+
"cernenv.server.rewards",
|
| 46 |
+
"cernenv.server.tasks",
|
| 47 |
+
"cernenv.server.physics",
|
| 48 |
+
"cernenv.training",
|
| 49 |
+
"cernenv.tests",
|
| 50 |
+
]
|
| 51 |
+
|
| 52 |
+
[tool.setuptools.package-dir]
|
| 53 |
+
cernenv = "."
|
| 54 |
+
"cernenv.server" = "server"
|
| 55 |
+
"cernenv.server.simulator" = "server/simulator"
|
| 56 |
+
"cernenv.server.rules" = "server/rules"
|
| 57 |
+
"cernenv.server.rewards" = "server/rewards"
|
| 58 |
+
"cernenv.server.tasks" = "server/tasks"
|
| 59 |
+
"cernenv.server.physics" = "server/physics"
|
| 60 |
+
"cernenv.training" = "training"
|
| 61 |
+
"cernenv.tests" = "tests"
|
scripts/__init__.py
ADDED
|
File without changes
|
scripts/_build_spaces.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Stage env- and trainer-Space directories from the repo root.
|
| 2 |
+
|
| 3 |
+
Each Space needs a *single* directory containing the full repo plus the
|
| 4 |
+
right Dockerfile + README front-matter at its root. This script copies
|
| 5 |
+
the repo into a staging directory, drops in the Space-specific
|
| 6 |
+
``Dockerfile`` / ``README.md``, and prints the staging path.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
import shutil
|
| 13 |
+
import sys
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
REPO_ROOT = Path(__file__).resolve().parent.parent
|
| 18 |
+
|
| 19 |
+
EXCLUDES = {
|
| 20 |
+
".venv",
|
| 21 |
+
"__pycache__",
|
| 22 |
+
".git",
|
| 23 |
+
".cursor",
|
| 24 |
+
".DS_Store",
|
| 25 |
+
"runs",
|
| 26 |
+
"wandb",
|
| 27 |
+
"node_modules",
|
| 28 |
+
".pytest_cache",
|
| 29 |
+
".mypy_cache",
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _ignore(_dir: str, names):
|
| 34 |
+
return [n for n in names if n in EXCLUDES or n.endswith((".pyc", ".log"))]
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _stage(stage_dir: Path) -> None:
|
| 38 |
+
if stage_dir.exists():
|
| 39 |
+
shutil.rmtree(stage_dir)
|
| 40 |
+
shutil.copytree(REPO_ROOT, stage_dir, ignore=_ignore, symlinks=False)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def build_env_space(stage_dir: Path) -> None:
|
| 44 |
+
_stage(stage_dir)
|
| 45 |
+
|
| 46 |
+
dockerfile = """\
|
| 47 |
+
# CERNenv environment Space (Docker, CPU)
|
| 48 |
+
FROM python:3.11-slim
|
| 49 |
+
|
| 50 |
+
ENV PYTHONUNBUFFERED=1 \\
|
| 51 |
+
PIP_NO_CACHE_DIR=1 \\
|
| 52 |
+
PYTHONPATH=/home/user/app
|
| 53 |
+
|
| 54 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \\
|
| 55 |
+
git curl ca-certificates build-essential \\
|
| 56 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 57 |
+
|
| 58 |
+
RUN useradd -ms /bin/bash user
|
| 59 |
+
USER user
|
| 60 |
+
WORKDIR /home/user/app
|
| 61 |
+
|
| 62 |
+
COPY --chown=user:user space/env/requirements.txt /tmp/requirements.txt
|
| 63 |
+
RUN python -m pip install --upgrade pip && \\
|
| 64 |
+
python -m pip install --user -r /tmp/requirements.txt
|
| 65 |
+
|
| 66 |
+
COPY --chown=user:user . /home/user/app
|
| 67 |
+
|
| 68 |
+
EXPOSE 7860
|
| 69 |
+
|
| 70 |
+
CMD [\"python\", \"-m\", \"uvicorn\", \"server.app:app\", \"--host\", \"0.0.0.0\", \"--port\", \"7860\"]
|
| 71 |
+
"""
|
| 72 |
+
(stage_dir / "Dockerfile").write_text(dockerfile)
|
| 73 |
+
|
| 74 |
+
readme = (stage_dir / "space" / "env" / "README.md").read_text()
|
| 75 |
+
(stage_dir / "README.md").write_text(readme)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def build_trainer_space(stage_dir: Path) -> None:
|
| 79 |
+
_stage(stage_dir)
|
| 80 |
+
|
| 81 |
+
dockerfile = """\
|
| 82 |
+
# CERNenv trainer Space (Docker, A100)
|
| 83 |
+
FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04
|
| 84 |
+
|
| 85 |
+
ENV DEBIAN_FRONTEND=noninteractive \\
|
| 86 |
+
PYTHONUNBUFFERED=1 \\
|
| 87 |
+
PIP_NO_CACHE_DIR=1 \\
|
| 88 |
+
HF_HOME=/home/user/.cache/huggingface \\
|
| 89 |
+
TRANSFORMERS_CACHE=/home/user/.cache/huggingface/transformers \\
|
| 90 |
+
PYTHONPATH=/home/user/app
|
| 91 |
+
|
| 92 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \\
|
| 93 |
+
python3.11 python3.11-venv python3.11-dev python3-pip \\
|
| 94 |
+
git curl ca-certificates build-essential \\
|
| 95 |
+
&& rm -rf /var/lib/apt/lists/* \\
|
| 96 |
+
&& ln -sf /usr/bin/python3.11 /usr/local/bin/python \\
|
| 97 |
+
&& ln -sf /usr/bin/python3.11 /usr/local/bin/python3
|
| 98 |
+
|
| 99 |
+
RUN useradd -ms /bin/bash user
|
| 100 |
+
USER user
|
| 101 |
+
ENV PATH=\"/home/user/.local/bin:${PATH}\"
|
| 102 |
+
WORKDIR /home/user/app
|
| 103 |
+
|
| 104 |
+
COPY --chown=user:user space/training/requirements.txt /tmp/requirements.txt
|
| 105 |
+
RUN python -m pip install --upgrade pip && \\
|
| 106 |
+
python -m pip install --user -r /tmp/requirements.txt
|
| 107 |
+
|
| 108 |
+
COPY --chown=user:user . /home/user/app
|
| 109 |
+
|
| 110 |
+
EXPOSE 7860
|
| 111 |
+
|
| 112 |
+
CMD [\"python\", \"-m\", \"uvicorn\", \"space.training.app:app\", \"--host\", \"0.0.0.0\", \"--port\", \"7860\"]
|
| 113 |
+
"""
|
| 114 |
+
(stage_dir / "Dockerfile").write_text(dockerfile)
|
| 115 |
+
|
| 116 |
+
readme = (stage_dir / "space" / "training" / "README.md").read_text()
|
| 117 |
+
(stage_dir / "README.md").write_text(readme)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def main() -> None: # pragma: no cover
|
| 121 |
+
parser = argparse.ArgumentParser()
|
| 122 |
+
parser.add_argument("kind", choices=["env", "trainer"])
|
| 123 |
+
parser.add_argument("--stage_dir", required=True)
|
| 124 |
+
args = parser.parse_args()
|
| 125 |
+
|
| 126 |
+
stage_dir = Path(args.stage_dir).resolve()
|
| 127 |
+
if args.kind == "env":
|
| 128 |
+
build_env_space(stage_dir)
|
| 129 |
+
else:
|
| 130 |
+
build_trainer_space(stage_dir)
|
| 131 |
+
print(stage_dir)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
if __name__ == "__main__": # pragma: no cover
|
| 135 |
+
main()
|
scripts/baseline_agents.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Built-in agents for evaluating CERNenv.
|
| 2 |
+
|
| 3 |
+
These do **not** use any neural model — they are deterministic / random
|
| 4 |
+
policies you can use as baselines and oracles. They consume a
|
| 5 |
+
``CollisionObservation`` and return an ``ExperimentAction``.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import random
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
from typing import List, Optional, Protocol
|
| 13 |
+
|
| 14 |
+
from models import ActionType, CollisionObservation, ExperimentAction
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class CernAgent(Protocol):
|
| 18 |
+
name: str
|
| 19 |
+
|
| 20 |
+
def reset(self) -> None: ...
|
| 21 |
+
|
| 22 |
+
def act(self, obs: CollisionObservation) -> ExperimentAction: ...
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# ── Random agent ─────────────────────────────────────────────────────────
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
|
| 29 |
+
class RandomAgent:
|
| 30 |
+
"""Picks a uniformly random valid action; useful as a worst-case baseline."""
|
| 31 |
+
|
| 32 |
+
name: str = "random"
|
| 33 |
+
seed: int = 0
|
| 34 |
+
|
| 35 |
+
def __post_init__(self) -> None:
|
| 36 |
+
self._rng = random.Random(self.seed)
|
| 37 |
+
|
| 38 |
+
def reset(self) -> None:
|
| 39 |
+
self._rng = random.Random(self.seed)
|
| 40 |
+
|
| 41 |
+
def act(self, obs: CollisionObservation) -> ExperimentAction:
|
| 42 |
+
action_type = self._rng.choice(list(ActionType))
|
| 43 |
+
params: dict = {}
|
| 44 |
+
if action_type == ActionType.CONFIGURE_BEAM:
|
| 45 |
+
params = {"beam_energy": self._rng.choice(obs.task.beam_energy_options or ["13TeV"])}
|
| 46 |
+
elif action_type == ActionType.SELECT_CHANNEL:
|
| 47 |
+
params = {"channel": self._rng.choice(obs.task.available_channels or ["diphoton"])}
|
| 48 |
+
elif action_type == ActionType.SET_TRIGGER:
|
| 49 |
+
params = {"trigger": self._rng.choice(obs.task.available_triggers or ["high_pt"])}
|
| 50 |
+
elif action_type == ActionType.ALLOCATE_LUMINOSITY:
|
| 51 |
+
params = {"luminosity_fb": self._rng.uniform(20.0, 100.0)}
|
| 52 |
+
elif action_type == ActionType.COLLECT_COLLISIONS:
|
| 53 |
+
params = {"luminosity_fb": self._rng.uniform(20.0, 100.0)}
|
| 54 |
+
elif action_type == ActionType.BUILD_INVARIANT_MASS:
|
| 55 |
+
lo, hi = obs.task.mass_search_window_gev
|
| 56 |
+
params = {"mass_window_gev": [lo, hi]}
|
| 57 |
+
elif action_type == ActionType.SUBMIT_DISCOVERY_CLAIM:
|
| 58 |
+
mass = obs.candidate_masses_gev[-1] if obs.candidate_masses_gev else (
|
| 59 |
+
0.5 * (obs.task.mass_search_window_gev[0] + obs.task.mass_search_window_gev[1])
|
| 60 |
+
)
|
| 61 |
+
params = {
|
| 62 |
+
"claim": {
|
| 63 |
+
"mass_estimate_gev": mass,
|
| 64 |
+
"mass_uncertainty_gev": 5.0,
|
| 65 |
+
"significance_sigma": obs.cumulative_significance,
|
| 66 |
+
"decay_channel": obs.selected_channel or "diphoton",
|
| 67 |
+
"spin_hypothesis": int(self._rng.choice([0, 1, 2])),
|
| 68 |
+
"parity": self._rng.choice(["+", "-"]),
|
| 69 |
+
"confidence": self._rng.uniform(0.4, 0.9),
|
| 70 |
+
}
|
| 71 |
+
}
|
| 72 |
+
return ExperimentAction(
|
| 73 |
+
action_type=action_type,
|
| 74 |
+
parameters=params,
|
| 75 |
+
confidence=0.4,
|
| 76 |
+
justification="random baseline",
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
# ── Heuristic agent ──────────────────────────────────────────────────────
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@dataclass
|
| 84 |
+
class HeuristicAgent:
|
| 85 |
+
"""A scripted analysis-flow agent using high-yield channels and
|
| 86 |
+
sensible default parameters. Acts as the strong non-LLM baseline.
|
| 87 |
+
"""
|
| 88 |
+
|
| 89 |
+
name: str = "heuristic"
|
| 90 |
+
|
| 91 |
+
def __post_init__(self) -> None:
|
| 92 |
+
self._reset_plan()
|
| 93 |
+
|
| 94 |
+
def reset(self) -> None:
|
| 95 |
+
self._reset_plan()
|
| 96 |
+
|
| 97 |
+
def _reset_plan(self) -> None:
|
| 98 |
+
self._plan: List[ExperimentAction] = [
|
| 99 |
+
ExperimentAction(
|
| 100 |
+
action_type=ActionType.CONFIGURE_BEAM,
|
| 101 |
+
parameters={"beam_energy": "13TeV"},
|
| 102 |
+
confidence=0.9,
|
| 103 |
+
justification="13 TeV maximises reach within budget",
|
| 104 |
+
),
|
| 105 |
+
ExperimentAction(
|
| 106 |
+
action_type=ActionType.SELECT_CHANNEL,
|
| 107 |
+
parameters={"channel": "diphoton"},
|
| 108 |
+
confidence=0.7,
|
| 109 |
+
justification="diphoton has clean low-background signature",
|
| 110 |
+
),
|
| 111 |
+
ExperimentAction(
|
| 112 |
+
action_type=ActionType.SET_TRIGGER,
|
| 113 |
+
parameters={"trigger": "diphoton_hlt"},
|
| 114 |
+
confidence=0.9,
|
| 115 |
+
justification="match trigger to channel",
|
| 116 |
+
),
|
| 117 |
+
ExperimentAction(
|
| 118 |
+
action_type=ActionType.ALLOCATE_LUMINOSITY,
|
| 119 |
+
parameters={"luminosity_fb": 80.0},
|
| 120 |
+
confidence=0.8,
|
| 121 |
+
justification="bulk allocation for the first run",
|
| 122 |
+
),
|
| 123 |
+
ExperimentAction(
|
| 124 |
+
action_type=ActionType.COLLECT_COLLISIONS,
|
| 125 |
+
parameters={"luminosity_fb": 80.0},
|
| 126 |
+
confidence=0.8,
|
| 127 |
+
justification="run physics",
|
| 128 |
+
),
|
| 129 |
+
ExperimentAction(
|
| 130 |
+
action_type=ActionType.RECONSTRUCT_TRACKS,
|
| 131 |
+
method="Athena",
|
| 132 |
+
confidence=0.9,
|
| 133 |
+
justification="reconstruct objects",
|
| 134 |
+
),
|
| 135 |
+
ExperimentAction(
|
| 136 |
+
action_type=ActionType.CALIBRATE_DETECTOR,
|
| 137 |
+
method="ECAL_calibration",
|
| 138 |
+
confidence=0.8,
|
| 139 |
+
justification="reduce systematic uncertainty",
|
| 140 |
+
),
|
| 141 |
+
ExperimentAction(
|
| 142 |
+
action_type=ActionType.BUILD_INVARIANT_MASS,
|
| 143 |
+
parameters={"mass_window_gev": [80.0, 800.0], "n_bins": 60},
|
| 144 |
+
confidence=0.8,
|
| 145 |
+
justification="broad-window histogram",
|
| 146 |
+
),
|
| 147 |
+
ExperimentAction(
|
| 148 |
+
action_type=ActionType.SUBTRACT_BACKGROUND,
|
| 149 |
+
confidence=0.7,
|
| 150 |
+
justification="smooth-fit subtraction",
|
| 151 |
+
),
|
| 152 |
+
ExperimentAction(
|
| 153 |
+
action_type=ActionType.SCAN_BUMP,
|
| 154 |
+
method="BumpHunter",
|
| 155 |
+
confidence=0.8,
|
| 156 |
+
justification="locate candidate peak",
|
| 157 |
+
),
|
| 158 |
+
ExperimentAction(
|
| 159 |
+
action_type=ActionType.FIT_RESONANCE,
|
| 160 |
+
method="ROOT_RooFit",
|
| 161 |
+
confidence=0.85,
|
| 162 |
+
justification="fit Breit-Wigner peak",
|
| 163 |
+
),
|
| 164 |
+
ExperimentAction(
|
| 165 |
+
action_type=ActionType.REQUEST_SYSTEMATICS,
|
| 166 |
+
method="Luminosity_calibration",
|
| 167 |
+
confidence=0.7,
|
| 168 |
+
justification="pin down dominant systematics",
|
| 169 |
+
),
|
| 170 |
+
ExperimentAction(
|
| 171 |
+
action_type=ActionType.ESTIMATE_SIGNIFICANCE,
|
| 172 |
+
method="Asimov_significance",
|
| 173 |
+
confidence=0.85,
|
| 174 |
+
justification="quantify discovery significance",
|
| 175 |
+
),
|
| 176 |
+
ExperimentAction(
|
| 177 |
+
action_type=ActionType.MEASURE_ANGULAR,
|
| 178 |
+
confidence=0.7,
|
| 179 |
+
justification="probe spin",
|
| 180 |
+
),
|
| 181 |
+
]
|
| 182 |
+
self._idx = 0
|
| 183 |
+
self._claim_submitted = False
|
| 184 |
+
|
| 185 |
+
def act(self, obs: CollisionObservation) -> ExperimentAction:
|
| 186 |
+
if self._idx < len(self._plan):
|
| 187 |
+
a = self._plan[self._idx]
|
| 188 |
+
self._idx += 1
|
| 189 |
+
return a
|
| 190 |
+
if not self._claim_submitted:
|
| 191 |
+
self._claim_submitted = True
|
| 192 |
+
mass = obs.candidate_masses_gev[-1] if obs.candidate_masses_gev else 125.0
|
| 193 |
+
sig = obs.cumulative_significance or 5.0
|
| 194 |
+
return ExperimentAction(
|
| 195 |
+
action_type=ActionType.SUBMIT_DISCOVERY_CLAIM,
|
| 196 |
+
parameters={
|
| 197 |
+
"claim": {
|
| 198 |
+
"mass_estimate_gev": mass,
|
| 199 |
+
"mass_uncertainty_gev": 1.0,
|
| 200 |
+
"width_estimate_gev": 0.01,
|
| 201 |
+
"significance_sigma": sig,
|
| 202 |
+
"decay_channel": obs.selected_channel or "diphoton",
|
| 203 |
+
"spin_hypothesis": 0,
|
| 204 |
+
"parity": "+",
|
| 205 |
+
"cross_section_fb": 50.0,
|
| 206 |
+
"confidence": 0.8,
|
| 207 |
+
}
|
| 208 |
+
},
|
| 209 |
+
confidence=0.85,
|
| 210 |
+
justification="submit best calibrated claim",
|
| 211 |
+
)
|
| 212 |
+
return ExperimentAction(
|
| 213 |
+
action_type=ActionType.REQUEST_THEORY_REVIEW,
|
| 214 |
+
confidence=0.3,
|
| 215 |
+
justification="filler step (claim already submitted)",
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
# ── Oracle agent ─────────────────────────────────────────────────────────
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
@dataclass
|
| 223 |
+
class OracleAgent:
|
| 224 |
+
"""An oracle that *peeks* at the latent particle truth (only available
|
| 225 |
+
for in-process evaluation; never used remotely). This is the upper bound
|
| 226 |
+
of what a perfect agent could achieve given the noise budget.
|
| 227 |
+
"""
|
| 228 |
+
|
| 229 |
+
name: str = "oracle"
|
| 230 |
+
truth: Optional[dict] = None # set externally before the episode
|
| 231 |
+
|
| 232 |
+
def reset(self) -> None:
|
| 233 |
+
self._stage = 0
|
| 234 |
+
self._claim_submitted = False
|
| 235 |
+
|
| 236 |
+
def act(self, obs: CollisionObservation) -> ExperimentAction:
|
| 237 |
+
truth = self.truth or {}
|
| 238 |
+
true_channel = truth.get("primary_channel", obs.selected_channel or "diphoton")
|
| 239 |
+
trigger_for_channel = {
|
| 240 |
+
"diphoton": "diphoton_hlt",
|
| 241 |
+
"dilepton_ee": "dilepton_hlt",
|
| 242 |
+
"dilepton_mumu": "dilepton_hlt",
|
| 243 |
+
"four_lepton": "dilepton_hlt",
|
| 244 |
+
"dijet": "jet_hlt",
|
| 245 |
+
"bb": "jet_hlt",
|
| 246 |
+
}.get(true_channel, "high_pt")
|
| 247 |
+
|
| 248 |
+
plan = [
|
| 249 |
+
ExperimentAction(action_type=ActionType.CONFIGURE_BEAM, parameters={"beam_energy": "13TeV"}, confidence=0.95),
|
| 250 |
+
ExperimentAction(action_type=ActionType.SELECT_CHANNEL, parameters={"channel": true_channel}, confidence=0.99),
|
| 251 |
+
ExperimentAction(action_type=ActionType.SET_TRIGGER, parameters={"trigger": trigger_for_channel}, confidence=0.95),
|
| 252 |
+
ExperimentAction(action_type=ActionType.ALLOCATE_LUMINOSITY, parameters={"luminosity_fb": 120.0}, confidence=0.9),
|
| 253 |
+
ExperimentAction(action_type=ActionType.COLLECT_COLLISIONS, parameters={"luminosity_fb": 120.0}, confidence=0.9),
|
| 254 |
+
ExperimentAction(action_type=ActionType.RECONSTRUCT_TRACKS, method="Athena", confidence=0.95),
|
| 255 |
+
ExperimentAction(action_type=ActionType.CALIBRATE_DETECTOR, method="ECAL_calibration", confidence=0.9),
|
| 256 |
+
ExperimentAction(
|
| 257 |
+
action_type=ActionType.BUILD_INVARIANT_MASS,
|
| 258 |
+
parameters={
|
| 259 |
+
"mass_window_gev": [
|
| 260 |
+
max(50.0, float(truth.get("mass_gev", 100.0)) - 50.0),
|
| 261 |
+
float(truth.get("mass_gev", 100.0)) + 80.0,
|
| 262 |
+
],
|
| 263 |
+
"n_bins": 80,
|
| 264 |
+
},
|
| 265 |
+
confidence=0.95,
|
| 266 |
+
),
|
| 267 |
+
ExperimentAction(action_type=ActionType.SUBTRACT_BACKGROUND, confidence=0.9),
|
| 268 |
+
ExperimentAction(action_type=ActionType.FIT_RESONANCE, method="ROOT_RooFit", confidence=0.95),
|
| 269 |
+
ExperimentAction(action_type=ActionType.REQUEST_SYSTEMATICS, method="Luminosity_calibration", confidence=0.9),
|
| 270 |
+
ExperimentAction(action_type=ActionType.ESTIMATE_SIGNIFICANCE, method="Asimov_significance", confidence=0.95),
|
| 271 |
+
ExperimentAction(action_type=ActionType.MEASURE_ANGULAR, confidence=0.85),
|
| 272 |
+
]
|
| 273 |
+
if self._stage < len(plan):
|
| 274 |
+
a = plan[self._stage]
|
| 275 |
+
self._stage += 1
|
| 276 |
+
return a
|
| 277 |
+
|
| 278 |
+
if not self._claim_submitted:
|
| 279 |
+
self._claim_submitted = True
|
| 280 |
+
return ExperimentAction(
|
| 281 |
+
action_type=ActionType.SUBMIT_DISCOVERY_CLAIM,
|
| 282 |
+
parameters={
|
| 283 |
+
"claim": {
|
| 284 |
+
"mass_estimate_gev": float(truth.get("mass_gev", 125.0)),
|
| 285 |
+
"mass_uncertainty_gev": 0.5,
|
| 286 |
+
"width_estimate_gev": float(truth.get("width_gev", 0.01)),
|
| 287 |
+
"significance_sigma": max(obs.cumulative_significance, 5.0),
|
| 288 |
+
"decay_channel": true_channel,
|
| 289 |
+
"spin_hypothesis": int(truth.get("spin", 0)),
|
| 290 |
+
"parity": str(truth.get("parity", "+")),
|
| 291 |
+
"cross_section_fb": float(truth.get("cross_section_fb", 50.0)),
|
| 292 |
+
"confidence": 0.95,
|
| 293 |
+
}
|
| 294 |
+
},
|
| 295 |
+
confidence=0.95,
|
| 296 |
+
justification="oracle claim from hidden truth",
|
| 297 |
+
)
|
| 298 |
+
return ExperimentAction(
|
| 299 |
+
action_type=ActionType.REQUEST_THEORY_REVIEW,
|
| 300 |
+
confidence=0.5,
|
| 301 |
+
justification="oracle filler",
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
__all__ = ["CernAgent", "RandomAgent", "HeuristicAgent", "OracleAgent"]
|
scripts/push_to_hub.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Push CERNenv artefacts to the Hugging Face Hub.
|
| 2 |
+
|
| 3 |
+
Two subcommands:
|
| 4 |
+
|
| 5 |
+
* ``model`` — push trained LoRA adapters (output of ``training_unsloth.py``)
|
| 6 |
+
to a model repo. Generates a model card describing the run.
|
| 7 |
+
|
| 8 |
+
* ``space`` — push a directory as a Hugging Face Space
|
| 9 |
+
(e.g. ``space/training`` for the trainer Space, or the project root
|
| 10 |
+
to publish the env Space). Front-matter is taken from the README.md
|
| 11 |
+
inside the directory.
|
| 12 |
+
|
| 13 |
+
Usage:
|
| 14 |
+
python -m scripts.push_to_hub model \\
|
| 15 |
+
--adapter_dir runs/unsloth-grpo \\
|
| 16 |
+
--repo_id YOUR_HF_USERNAME/cernenv-grpo-qwen2.5-3b \\
|
| 17 |
+
--base_model unsloth/Qwen2.5-3B-Instruct
|
| 18 |
+
|
| 19 |
+
python -m scripts.push_to_hub space \\
|
| 20 |
+
--space_dir space/training \\
|
| 21 |
+
--repo_id YOUR_HF_USERNAME/cernenv-trainer \\
|
| 22 |
+
--hardware a100-large
|
| 23 |
+
|
| 24 |
+
python -m scripts.push_to_hub space \\
|
| 25 |
+
--space_dir . \\
|
| 26 |
+
--repo_id YOUR_HF_USERNAME/cernenv \\
|
| 27 |
+
--include "models.py" "server/**" "openenv.yaml" "pyproject.toml" "client.py" "README.md"
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
from __future__ import annotations
|
| 31 |
+
|
| 32 |
+
import argparse
|
| 33 |
+
import logging
|
| 34 |
+
import os
|
| 35 |
+
import sys
|
| 36 |
+
from pathlib import Path
|
| 37 |
+
from typing import Iterable, List, Optional
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
|
| 41 |
+
logger = logging.getLogger(__name__)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
DEFAULT_SPACE_EXCLUDES: List[str] = [
|
| 45 |
+
".venv/**",
|
| 46 |
+
"__pycache__/**",
|
| 47 |
+
"**/__pycache__/**",
|
| 48 |
+
"*.pyc",
|
| 49 |
+
".cursor/**",
|
| 50 |
+
".git/**",
|
| 51 |
+
".DS_Store",
|
| 52 |
+
"**/.DS_Store",
|
| 53 |
+
"runs/**",
|
| 54 |
+
"training/runs/**",
|
| 55 |
+
"training/plots/**",
|
| 56 |
+
"wandb/**",
|
| 57 |
+
"*.zip",
|
| 58 |
+
"*.apk",
|
| 59 |
+
"*.png",
|
| 60 |
+
"*.jpg",
|
| 61 |
+
"*.jpeg",
|
| 62 |
+
"[External]*.txt",
|
| 63 |
+
"Hackathon FAQs*.txt",
|
| 64 |
+
"*.log",
|
| 65 |
+
]
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _hf_login() -> None:
|
| 69 |
+
from huggingface_hub import login
|
| 70 |
+
|
| 71 |
+
token = os.environ.get("HF_TOKEN")
|
| 72 |
+
if not token:
|
| 73 |
+
raise SystemExit(
|
| 74 |
+
"HF_TOKEN environment variable is required (write-scoped Hugging Face token)."
|
| 75 |
+
)
|
| 76 |
+
login(token=token)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _model_card(*, repo_id: str, base_model: str, run_dir: Path) -> str:
|
| 80 |
+
return f"""---
|
| 81 |
+
license: bsd-3-clause
|
| 82 |
+
library_name: peft
|
| 83 |
+
base_model: {base_model}
|
| 84 |
+
tags:
|
| 85 |
+
- cernenv
|
| 86 |
+
- openenv
|
| 87 |
+
- reinforcement-learning
|
| 88 |
+
- grpo
|
| 89 |
+
- unsloth
|
| 90 |
+
- lora
|
| 91 |
+
- particle-physics
|
| 92 |
+
---
|
| 93 |
+
|
| 94 |
+
# {repo_id}
|
| 95 |
+
|
| 96 |
+
LoRA (Low-Rank Adaptation) adapters trained with **GRPO** (Group-Relative
|
| 97 |
+
Policy Optimization) inside the **CERNenv** OpenEnv environment — an
|
| 98 |
+
LHC (Large Hadron Collider) particle-discovery POMDP (Partially Observable
|
| 99 |
+
Markov Decision Process).
|
| 100 |
+
|
| 101 |
+
The agent (this model) plays the role of a high-energy physicist running an
|
| 102 |
+
analysis: it configures the beam, allocates luminosity, picks decay
|
| 103 |
+
channels and triggers, reconstructs events, fits resonances, estimates
|
| 104 |
+
significance, and finally submits a structured discovery claim that is
|
| 105 |
+
graded against a hidden ground-truth particle.
|
| 106 |
+
|
| 107 |
+
* Base model: `{base_model}`
|
| 108 |
+
* RL framework: TRL (Transformer Reinforcement Learning) GRPO
|
| 109 |
+
* Acceleration: Unsloth + 4-bit + LoRA
|
| 110 |
+
* Environment: [CERNenv](https://huggingface.co/spaces/{repo_id.split('/')[0]}/cernenv)
|
| 111 |
+
|
| 112 |
+
## Usage
|
| 113 |
+
|
| 114 |
+
```python
|
| 115 |
+
from peft import PeftModel
|
| 116 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 117 |
+
|
| 118 |
+
base = "{base_model}"
|
| 119 |
+
adapter = "{repo_id}"
|
| 120 |
+
|
| 121 |
+
tokenizer = AutoTokenizer.from_pretrained(base)
|
| 122 |
+
model = AutoModelForCausalLM.from_pretrained(base, device_map="auto")
|
| 123 |
+
model = PeftModel.from_pretrained(model, adapter)
|
| 124 |
+
```
|
| 125 |
+
|
| 126 |
+
See the CERNenv repo for full evaluation, plots, and the `LLMAgent` wrapper.
|
| 127 |
+
"""
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def push_model(
|
| 131 |
+
*,
|
| 132 |
+
adapter_dir: str,
|
| 133 |
+
repo_id: str,
|
| 134 |
+
base_model: str,
|
| 135 |
+
private: bool,
|
| 136 |
+
) -> None:
|
| 137 |
+
from huggingface_hub import HfApi, create_repo
|
| 138 |
+
|
| 139 |
+
_hf_login()
|
| 140 |
+
api = HfApi()
|
| 141 |
+
|
| 142 |
+
run_dir = Path(adapter_dir)
|
| 143 |
+
if not run_dir.exists():
|
| 144 |
+
raise SystemExit(f"adapter_dir not found: {run_dir}")
|
| 145 |
+
|
| 146 |
+
create_repo(repo_id=repo_id, repo_type="model", private=private, exist_ok=True)
|
| 147 |
+
|
| 148 |
+
card_path = run_dir / "README.md"
|
| 149 |
+
card_path.write_text(_model_card(repo_id=repo_id, base_model=base_model, run_dir=run_dir))
|
| 150 |
+
|
| 151 |
+
logger.info("uploading %s → %s", run_dir, repo_id)
|
| 152 |
+
api.upload_folder(
|
| 153 |
+
folder_path=str(run_dir),
|
| 154 |
+
repo_id=repo_id,
|
| 155 |
+
repo_type="model",
|
| 156 |
+
commit_message="Upload CERNenv GRPO LoRA adapters",
|
| 157 |
+
)
|
| 158 |
+
logger.info("done: https://huggingface.co/%s", repo_id)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def push_space(
|
| 162 |
+
*,
|
| 163 |
+
space_dir: str,
|
| 164 |
+
repo_id: str,
|
| 165 |
+
hardware: Optional[str],
|
| 166 |
+
private: bool,
|
| 167 |
+
include: Optional[List[str]],
|
| 168 |
+
exclude: Optional[List[str]],
|
| 169 |
+
) -> None:
|
| 170 |
+
from huggingface_hub import HfApi, create_repo
|
| 171 |
+
|
| 172 |
+
_hf_login()
|
| 173 |
+
api = HfApi()
|
| 174 |
+
|
| 175 |
+
src = Path(space_dir).resolve()
|
| 176 |
+
if not src.exists():
|
| 177 |
+
raise SystemExit(f"space_dir not found: {src}")
|
| 178 |
+
|
| 179 |
+
create_repo(
|
| 180 |
+
repo_id=repo_id,
|
| 181 |
+
repo_type="space",
|
| 182 |
+
space_sdk="docker",
|
| 183 |
+
space_hardware=hardware,
|
| 184 |
+
private=private,
|
| 185 |
+
exist_ok=True,
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
effective_exclude = list(DEFAULT_SPACE_EXCLUDES)
|
| 189 |
+
if exclude:
|
| 190 |
+
effective_exclude.extend(exclude)
|
| 191 |
+
|
| 192 |
+
logger.info("uploading %s → space:%s", src, repo_id)
|
| 193 |
+
logger.info("ignore patterns: %s", effective_exclude)
|
| 194 |
+
api.upload_folder(
|
| 195 |
+
folder_path=str(src),
|
| 196 |
+
repo_id=repo_id,
|
| 197 |
+
repo_type="space",
|
| 198 |
+
commit_message="Update CERNenv Space",
|
| 199 |
+
allow_patterns=include,
|
| 200 |
+
ignore_patterns=effective_exclude,
|
| 201 |
+
)
|
| 202 |
+
logger.info("done: https://huggingface.co/spaces/%s", repo_id)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def main() -> None: # pragma: no cover
|
| 206 |
+
parser = argparse.ArgumentParser()
|
| 207 |
+
sub = parser.add_subparsers(dest="cmd", required=True)
|
| 208 |
+
|
| 209 |
+
m = sub.add_parser("model", help="push trained LoRA adapters to the Hub")
|
| 210 |
+
m.add_argument("--adapter_dir", required=True)
|
| 211 |
+
m.add_argument("--repo_id", required=True)
|
| 212 |
+
m.add_argument("--base_model", required=True)
|
| 213 |
+
m.add_argument("--private", action="store_true")
|
| 214 |
+
|
| 215 |
+
s = sub.add_parser("space", help="push a directory as an HF Space")
|
| 216 |
+
s.add_argument("--space_dir", required=True)
|
| 217 |
+
s.add_argument("--repo_id", required=True)
|
| 218 |
+
s.add_argument("--hardware", default=None,
|
| 219 |
+
help="e.g. a100-large, t4-small, l4-medium")
|
| 220 |
+
s.add_argument("--private", action="store_true")
|
| 221 |
+
s.add_argument("--include", nargs="*", default=None,
|
| 222 |
+
help="glob patterns to include")
|
| 223 |
+
s.add_argument("--exclude", nargs="*", default=None,
|
| 224 |
+
help="glob patterns to exclude")
|
| 225 |
+
|
| 226 |
+
args = parser.parse_args()
|
| 227 |
+
|
| 228 |
+
if args.cmd == "model":
|
| 229 |
+
push_model(
|
| 230 |
+
adapter_dir=args.adapter_dir,
|
| 231 |
+
repo_id=args.repo_id,
|
| 232 |
+
base_model=args.base_model,
|
| 233 |
+
private=args.private,
|
| 234 |
+
)
|
| 235 |
+
elif args.cmd == "space":
|
| 236 |
+
push_space(
|
| 237 |
+
space_dir=args.space_dir,
|
| 238 |
+
repo_id=args.repo_id,
|
| 239 |
+
hardware=args.hardware,
|
| 240 |
+
private=args.private,
|
| 241 |
+
include=args.include,
|
| 242 |
+
exclude=args.exclude,
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
if __name__ == "__main__": # pragma: no cover
|
| 247 |
+
main()
|
scripts/run_agent.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Run a (non-LLM) baseline agent against the in-process environment.
|
| 2 |
+
|
| 3 |
+
Usage:
|
| 4 |
+
python -m scripts.run_agent --agent heuristic --scenario easy_diphoton_160 --seed 7
|
| 5 |
+
python -m scripts.run_agent --agent oracle --difficulty hard --episodes 5
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import argparse
|
| 11 |
+
import json
|
| 12 |
+
from dataclasses import asdict
|
| 13 |
+
from typing import Any, Dict, List
|
| 14 |
+
|
| 15 |
+
from server.environment import CERNCollisionEnvironment
|
| 16 |
+
from scripts.baseline_agents import (
|
| 17 |
+
HeuristicAgent,
|
| 18 |
+
OracleAgent,
|
| 19 |
+
RandomAgent,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
AGENT_REGISTRY = {
|
| 24 |
+
"random": RandomAgent,
|
| 25 |
+
"heuristic": HeuristicAgent,
|
| 26 |
+
"oracle": OracleAgent,
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def run_episode(
|
| 31 |
+
*,
|
| 32 |
+
agent_name: str,
|
| 33 |
+
difficulty: str | None,
|
| 34 |
+
scenario: str | None,
|
| 35 |
+
seed: int,
|
| 36 |
+
max_steps: int,
|
| 37 |
+
verbose: bool,
|
| 38 |
+
) -> Dict[str, Any]:
|
| 39 |
+
env = CERNCollisionEnvironment(max_steps=max_steps)
|
| 40 |
+
obs = env.reset(seed=seed, scenario=scenario, difficulty=difficulty)
|
| 41 |
+
|
| 42 |
+
agent_cls = AGENT_REGISTRY[agent_name]
|
| 43 |
+
if agent_name == "random":
|
| 44 |
+
agent = agent_cls(seed=seed)
|
| 45 |
+
else:
|
| 46 |
+
agent = agent_cls()
|
| 47 |
+
if agent_name == "oracle":
|
| 48 |
+
agent.truth = env.hidden_truth()
|
| 49 |
+
|
| 50 |
+
agent.reset()
|
| 51 |
+
|
| 52 |
+
total_reward = 0.0
|
| 53 |
+
step_log: List[Dict[str, Any]] = []
|
| 54 |
+
while not obs.done:
|
| 55 |
+
action = agent.act(obs)
|
| 56 |
+
obs = env.step(action)
|
| 57 |
+
total_reward += float(obs.reward or 0.0)
|
| 58 |
+
if verbose:
|
| 59 |
+
print(
|
| 60 |
+
f" step {obs.step_index:2d} {action.action_type.value:24s} "
|
| 61 |
+
f"rew={obs.reward:+.3f} done={obs.done}"
|
| 62 |
+
)
|
| 63 |
+
step_log.append(
|
| 64 |
+
{
|
| 65 |
+
"step": obs.step_index,
|
| 66 |
+
"action": action.action_type.value,
|
| 67 |
+
"reward": float(obs.reward or 0.0),
|
| 68 |
+
"violations": obs.rule_violations,
|
| 69 |
+
}
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
summary = {
|
| 73 |
+
"agent": agent_name,
|
| 74 |
+
"scenario": env.state.scenario_name,
|
| 75 |
+
"difficulty": env.state.difficulty,
|
| 76 |
+
"seed": seed,
|
| 77 |
+
"total_reward": total_reward,
|
| 78 |
+
"cumulative_reward": float(env.state.cumulative_reward),
|
| 79 |
+
"terminal_reward": env.state.terminal_reward,
|
| 80 |
+
"discovered": env.state.discovered,
|
| 81 |
+
"correct_mass": env.state.correct_mass,
|
| 82 |
+
"correct_channel": env.state.correct_channel,
|
| 83 |
+
"correct_spin": env.state.correct_spin,
|
| 84 |
+
"steps": len(step_log),
|
| 85 |
+
"truth": env.hidden_truth(),
|
| 86 |
+
"log": step_log,
|
| 87 |
+
}
|
| 88 |
+
return summary
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def main() -> None:
|
| 92 |
+
parser = argparse.ArgumentParser()
|
| 93 |
+
parser.add_argument("--agent", choices=list(AGENT_REGISTRY), default="heuristic")
|
| 94 |
+
parser.add_argument("--scenario", default=None)
|
| 95 |
+
parser.add_argument("--difficulty", choices=["easy", "medium", "hard"], default=None)
|
| 96 |
+
parser.add_argument("--seed", type=int, default=0)
|
| 97 |
+
parser.add_argument("--episodes", type=int, default=1)
|
| 98 |
+
parser.add_argument("--max-steps", type=int, default=40)
|
| 99 |
+
parser.add_argument("--out", default=None, help="Optional path to dump JSON results")
|
| 100 |
+
parser.add_argument("--quiet", action="store_true")
|
| 101 |
+
args = parser.parse_args()
|
| 102 |
+
|
| 103 |
+
rollouts: List[Dict[str, Any]] = []
|
| 104 |
+
for ep in range(args.episodes):
|
| 105 |
+
seed = args.seed + ep
|
| 106 |
+
summary = run_episode(
|
| 107 |
+
agent_name=args.agent,
|
| 108 |
+
difficulty=args.difficulty,
|
| 109 |
+
scenario=args.scenario,
|
| 110 |
+
seed=seed,
|
| 111 |
+
max_steps=args.max_steps,
|
| 112 |
+
verbose=not args.quiet and args.episodes == 1,
|
| 113 |
+
)
|
| 114 |
+
rollouts.append(summary)
|
| 115 |
+
print(
|
| 116 |
+
f"[{ep + 1}/{args.episodes}] agent={args.agent} "
|
| 117 |
+
f"scenario={summary['scenario']} reward={summary['total_reward']:+.3f} "
|
| 118 |
+
f"discovered={summary['discovered']} correct_mass={summary['correct_mass']} "
|
| 119 |
+
f"correct_channel={summary['correct_channel']}"
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
if args.out:
|
| 123 |
+
with open(args.out, "w") as f:
|
| 124 |
+
json.dump(rollouts, f, indent=2, default=str)
|
| 125 |
+
print(f"saved → {args.out}")
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
if __name__ == "__main__":
|
| 129 |
+
main()
|
server/Dockerfile
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CERNenv server: OpenEnv FastAPI image
|
| 2 |
+
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
|
| 3 |
+
FROM ${BASE_IMAGE} AS builder
|
| 4 |
+
|
| 5 |
+
WORKDIR /app
|
| 6 |
+
|
| 7 |
+
RUN apt-get update && \
|
| 8 |
+
apt-get install -y --no-install-recommends git curl && \
|
| 9 |
+
rm -rf /var/lib/apt/lists/*
|
| 10 |
+
|
| 11 |
+
ARG ENV_NAME=cernenv
|
| 12 |
+
|
| 13 |
+
COPY . /app/env
|
| 14 |
+
|
| 15 |
+
WORKDIR /app/env
|
| 16 |
+
|
| 17 |
+
RUN if ! command -v uv >/dev/null 2>&1; then \
|
| 18 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
| 19 |
+
mv /root/.local/bin/uv /usr/local/bin/uv && \
|
| 20 |
+
mv /root/.local/bin/uvx /usr/local/bin/uvx; \
|
| 21 |
+
fi
|
| 22 |
+
|
| 23 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 24 |
+
if [ -f uv.lock ]; then \
|
| 25 |
+
uv sync --frozen --no-install-project --no-editable; \
|
| 26 |
+
else \
|
| 27 |
+
uv sync --no-install-project --no-editable; \
|
| 28 |
+
fi
|
| 29 |
+
|
| 30 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 31 |
+
if [ -f uv.lock ]; then \
|
| 32 |
+
uv sync --frozen --no-editable; \
|
| 33 |
+
else \
|
| 34 |
+
uv sync --no-editable; \
|
| 35 |
+
fi
|
| 36 |
+
|
| 37 |
+
FROM ${BASE_IMAGE}
|
| 38 |
+
|
| 39 |
+
WORKDIR /app
|
| 40 |
+
|
| 41 |
+
COPY --from=builder /app/env/.venv /app/.venv
|
| 42 |
+
COPY --from=builder /app/env /app/env
|
| 43 |
+
|
| 44 |
+
ENV PATH="/app/.venv/bin:$PATH"
|
| 45 |
+
ENV PYTHONPATH="/app/env:$PYTHONPATH"
|
| 46 |
+
|
| 47 |
+
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
| 48 |
+
CMD curl -f http://localhost:8000/health || exit 1
|
| 49 |
+
|
| 50 |
+
CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
|
server/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""CERNenv server package."""
|
server/app.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI app exposing ``CERNCollisionEnvironment`` over the OpenEnv HTTP API."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
from openenv.core.env_server import create_fastapi_app
|
| 9 |
+
|
| 10 |
+
from models import CollisionObservation, ExperimentAction
|
| 11 |
+
from server.environment import CERNCollisionEnvironment
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def make_env_factory(
|
| 15 |
+
max_steps: int,
|
| 16 |
+
default_difficulty: Optional[str],
|
| 17 |
+
):
|
| 18 |
+
def factory() -> CERNCollisionEnvironment:
|
| 19 |
+
return CERNCollisionEnvironment(
|
| 20 |
+
max_steps=max_steps,
|
| 21 |
+
default_difficulty=default_difficulty,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
return factory
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def build_app(
|
| 28 |
+
*,
|
| 29 |
+
max_steps: int = 40,
|
| 30 |
+
default_difficulty: Optional[str] = None,
|
| 31 |
+
):
|
| 32 |
+
"""Construct the FastAPI app with a per-session environment factory."""
|
| 33 |
+
factory = make_env_factory(max_steps=max_steps, default_difficulty=default_difficulty)
|
| 34 |
+
return create_fastapi_app(factory, ExperimentAction, CollisionObservation)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
app = build_app(
|
| 38 |
+
max_steps=int(os.getenv("CERNENV_MAX_STEPS", "40")),
|
| 39 |
+
default_difficulty=os.getenv("CERNENV_DEFAULT_DIFFICULTY") or None,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def main() -> None: # pragma: no cover - CLI entrypoint
|
| 44 |
+
import uvicorn
|
| 45 |
+
|
| 46 |
+
host = os.getenv("HOST", "0.0.0.0")
|
| 47 |
+
port = int(os.getenv("PORT", "8000"))
|
| 48 |
+
uvicorn.run("server.app:app", host=host, port=port, log_level="info")
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
if __name__ == "__main__": # pragma: no cover
|
| 52 |
+
main()
|
server/environment.py
ADDED
|
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""``CERNCollisionEnvironment``: orchestrates simulator + rules + rewards.
|
| 2 |
+
|
| 3 |
+
This is the OpenEnv-compatible ``Environment`` that the FastAPI app exposes.
|
| 4 |
+
It owns one episode at a time:
|
| 5 |
+
|
| 6 |
+
reset(seed) → builds a fresh latent state from a sampled scenario.
|
| 7 |
+
step(action) → validates → generates noisy output → updates state →
|
| 8 |
+
computes reward → builds the agent observation.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import logging
|
| 14 |
+
import uuid
|
| 15 |
+
from typing import Any, List, Optional
|
| 16 |
+
|
| 17 |
+
from openenv.core.env_server import Environment, State
|
| 18 |
+
|
| 19 |
+
from models import (
|
| 20 |
+
AGENT_ENVIRONMENT_RULES,
|
| 21 |
+
ActionType,
|
| 22 |
+
CollisionObservation,
|
| 23 |
+
DiscoveryClaim,
|
| 24 |
+
ExperimentAction,
|
| 25 |
+
IntermediateOutput,
|
| 26 |
+
OutputType,
|
| 27 |
+
PipelineStepRecord,
|
| 28 |
+
ResourceUsage,
|
| 29 |
+
TaskSpec,
|
| 30 |
+
build_agent_system_prompt,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
from server.rewards import (
|
| 34 |
+
RewardWeights,
|
| 35 |
+
compute_step_reward,
|
| 36 |
+
compute_terminal_reward,
|
| 37 |
+
)
|
| 38 |
+
from server.rules import RulesEngine, ViolationCode
|
| 39 |
+
from server.simulator import (
|
| 40 |
+
NoiseModel,
|
| 41 |
+
OutputGenerator,
|
| 42 |
+
TransitionEngine,
|
| 43 |
+
compute_action_cost,
|
| 44 |
+
)
|
| 45 |
+
from server.simulator.latent_state import FullLatentState
|
| 46 |
+
from server.tasks import sample_scenario, Scenario
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
logger = logging.getLogger(__name__)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# ── State container ──────────────────────────────────────────────────────
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class CernState(State):
|
| 56 |
+
"""OpenEnv State subclass: includes hidden truth & runtime stats."""
|
| 57 |
+
|
| 58 |
+
scenario_name: Optional[str] = None
|
| 59 |
+
difficulty: Optional[str] = None
|
| 60 |
+
episode_done: bool = False
|
| 61 |
+
cumulative_reward: float = 0.0
|
| 62 |
+
terminal_reward: Optional[float] = None
|
| 63 |
+
discovered: Optional[bool] = None
|
| 64 |
+
correct_mass: Optional[bool] = None
|
| 65 |
+
correct_channel: Optional[bool] = None
|
| 66 |
+
correct_spin: Optional[bool] = None
|
| 67 |
+
truth_mass_gev: Optional[float] = None
|
| 68 |
+
truth_channel: Optional[str] = None
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# ── Environment ──────────────────────────────────────────────────────────
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class CERNCollisionEnvironment(Environment[ExperimentAction, CollisionObservation, CernState]):
|
| 75 |
+
"""LHC particle-discovery POMDP environment."""
|
| 76 |
+
|
| 77 |
+
SUPPORTS_CONCURRENT_SESSIONS = True
|
| 78 |
+
|
| 79 |
+
def __init__(
|
| 80 |
+
self,
|
| 81 |
+
*,
|
| 82 |
+
max_steps: int = 40,
|
| 83 |
+
default_difficulty: Optional[str] = None,
|
| 84 |
+
default_scenario_name: Optional[str] = None,
|
| 85 |
+
reward_weights: Optional[RewardWeights] = None,
|
| 86 |
+
) -> None:
|
| 87 |
+
super().__init__()
|
| 88 |
+
self.max_steps = max_steps
|
| 89 |
+
self.default_difficulty = default_difficulty
|
| 90 |
+
self.default_scenario_name = default_scenario_name
|
| 91 |
+
self.reward_weights = reward_weights or RewardWeights()
|
| 92 |
+
|
| 93 |
+
self._state = CernState()
|
| 94 |
+
self._scenario: Optional[Scenario] = None
|
| 95 |
+
self._latent: Optional[FullLatentState] = None
|
| 96 |
+
self._task: Optional[TaskSpec] = None
|
| 97 |
+
self._noise: Optional[NoiseModel] = None
|
| 98 |
+
self._output_gen: Optional[OutputGenerator] = None
|
| 99 |
+
self._transition: Optional[TransitionEngine] = None
|
| 100 |
+
self._rules: Optional[RulesEngine] = None
|
| 101 |
+
self._history: List[PipelineStepRecord] = []
|
| 102 |
+
self._all_outputs: List[IntermediateOutput] = []
|
| 103 |
+
|
| 104 |
+
# ── Environment API ────────────────────────────────────────────────
|
| 105 |
+
|
| 106 |
+
@property
|
| 107 |
+
def state(self) -> CernState:
|
| 108 |
+
return self._state
|
| 109 |
+
|
| 110 |
+
def reset(
|
| 111 |
+
self,
|
| 112 |
+
seed: Optional[int] = None,
|
| 113 |
+
episode_id: Optional[str] = None,
|
| 114 |
+
**kwargs: Any,
|
| 115 |
+
) -> CollisionObservation:
|
| 116 |
+
difficulty = kwargs.get("difficulty") or self.default_difficulty
|
| 117 |
+
scenario_name = kwargs.get("scenario") or self.default_scenario_name
|
| 118 |
+
|
| 119 |
+
scenario = sample_scenario(
|
| 120 |
+
difficulty=difficulty,
|
| 121 |
+
name=scenario_name,
|
| 122 |
+
seed=seed,
|
| 123 |
+
)
|
| 124 |
+
self._scenario = scenario
|
| 125 |
+
self._latent = scenario.fresh_latent()
|
| 126 |
+
self._task = scenario.task
|
| 127 |
+
if seed is not None:
|
| 128 |
+
self._latent.rng_seed = int(seed)
|
| 129 |
+
self._noise = NoiseModel(seed=self._latent.rng_seed)
|
| 130 |
+
self._output_gen = OutputGenerator(self._noise)
|
| 131 |
+
self._transition = TransitionEngine()
|
| 132 |
+
self._rules = RulesEngine(
|
| 133 |
+
mass_search_window_gev=tuple(self._task.mass_search_window_gev),
|
| 134 |
+
)
|
| 135 |
+
self._history = []
|
| 136 |
+
self._all_outputs = []
|
| 137 |
+
|
| 138 |
+
self._state = CernState(
|
| 139 |
+
episode_id=episode_id or f"ep-{uuid.uuid4().hex[:8]}",
|
| 140 |
+
step_count=0,
|
| 141 |
+
scenario_name=scenario.name,
|
| 142 |
+
difficulty=scenario.difficulty,
|
| 143 |
+
episode_done=False,
|
| 144 |
+
cumulative_reward=0.0,
|
| 145 |
+
truth_mass_gev=self._latent.particle.mass_gev,
|
| 146 |
+
truth_channel=self._latent.particle.primary_channel,
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
obs = self._build_observation(
|
| 150 |
+
latest_output=None,
|
| 151 |
+
done=False,
|
| 152 |
+
reward=0.0,
|
| 153 |
+
step_breakdown={},
|
| 154 |
+
rule_violations=[],
|
| 155 |
+
)
|
| 156 |
+
return obs
|
| 157 |
+
|
| 158 |
+
def step(
|
| 159 |
+
self,
|
| 160 |
+
action: ExperimentAction,
|
| 161 |
+
timeout_s: Optional[float] = None,
|
| 162 |
+
**kwargs: Any,
|
| 163 |
+
) -> CollisionObservation:
|
| 164 |
+
if self._latent is None:
|
| 165 |
+
self.reset()
|
| 166 |
+
if self._state.episode_done:
|
| 167 |
+
return self._build_terminal_observation(reason="episode already complete")
|
| 168 |
+
|
| 169 |
+
assert self._rules is not None
|
| 170 |
+
assert self._output_gen is not None
|
| 171 |
+
assert self._transition is not None
|
| 172 |
+
|
| 173 |
+
prev_state = self._latent.model_copy(deep=True)
|
| 174 |
+
rule_result = self._rules.validate(action, self._latent)
|
| 175 |
+
|
| 176 |
+
if not rule_result.allowed:
|
| 177 |
+
output = IntermediateOutput(
|
| 178 |
+
output_type=OutputType.FAILURE_REPORT,
|
| 179 |
+
step_index=self._state.step_count,
|
| 180 |
+
success=False,
|
| 181 |
+
quality_score=0.0,
|
| 182 |
+
summary="Action rejected: " + "; ".join(rule_result.messages),
|
| 183 |
+
warnings=rule_result.messages,
|
| 184 |
+
)
|
| 185 |
+
else:
|
| 186 |
+
output = self._output_gen.generate(
|
| 187 |
+
action=action,
|
| 188 |
+
state=self._latent,
|
| 189 |
+
step_index=self._state.step_count,
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
# Apply transition (state mutation + cost accounting)
|
| 193 |
+
if rule_result.allowed:
|
| 194 |
+
self._transition.step(self._latent, action, output)
|
| 195 |
+
else:
|
| 196 |
+
cost = compute_action_cost(action, output)
|
| 197 |
+
self._latent.resources.budget_used_musd += cost["musd"]
|
| 198 |
+
self._latent.resources.time_used_days += cost["days"]
|
| 199 |
+
self._latent.step_count += 1
|
| 200 |
+
|
| 201 |
+
self._all_outputs.append(output)
|
| 202 |
+
cost = compute_action_cost(action, output)
|
| 203 |
+
record = PipelineStepRecord(
|
| 204 |
+
step_index=self._state.step_count,
|
| 205 |
+
action_type=action.action_type,
|
| 206 |
+
method=action.method,
|
| 207 |
+
parameters=action.parameters,
|
| 208 |
+
output_summary=output.summary,
|
| 209 |
+
output_type=output.output_type,
|
| 210 |
+
success=output.success,
|
| 211 |
+
quality_score=float(output.quality_score),
|
| 212 |
+
cost_musd=float(cost["musd"]),
|
| 213 |
+
luminosity_cost_fb=float(cost["luminosity_fb"]),
|
| 214 |
+
time_cost_days=float(cost["days"]),
|
| 215 |
+
)
|
| 216 |
+
self._history.append(record)
|
| 217 |
+
|
| 218 |
+
step_reward = compute_step_reward(
|
| 219 |
+
action=action,
|
| 220 |
+
output=output,
|
| 221 |
+
state_before=prev_state,
|
| 222 |
+
state_after=self._latent,
|
| 223 |
+
rule_result=rule_result,
|
| 224 |
+
weights=self.reward_weights,
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
self._state.cumulative_reward += step_reward.reward
|
| 228 |
+
self._state.step_count += 1
|
| 229 |
+
|
| 230 |
+
terminal_now = (
|
| 231 |
+
action.action_type == ActionType.SUBMIT_DISCOVERY_CLAIM
|
| 232 |
+
and rule_result.allowed
|
| 233 |
+
)
|
| 234 |
+
time_up = (
|
| 235 |
+
self._state.step_count >= self.max_steps
|
| 236 |
+
or self._latent.resources.budget_exhausted
|
| 237 |
+
or self._latent.resources.time_exhausted
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
terminal_reward_value = 0.0
|
| 241 |
+
if terminal_now:
|
| 242 |
+
claim = self._claim_from_action(action)
|
| 243 |
+
term = compute_terminal_reward(
|
| 244 |
+
state=self._latent,
|
| 245 |
+
claim=claim,
|
| 246 |
+
weights=self.reward_weights,
|
| 247 |
+
)
|
| 248 |
+
terminal_reward_value = term.reward
|
| 249 |
+
self._state.cumulative_reward += terminal_reward_value
|
| 250 |
+
self._state.terminal_reward = terminal_reward_value
|
| 251 |
+
self._state.discovered = term.discovered
|
| 252 |
+
self._state.correct_mass = term.correct_mass
|
| 253 |
+
self._state.correct_channel = term.correct_channel
|
| 254 |
+
self._state.correct_spin = term.correct_spin
|
| 255 |
+
|
| 256 |
+
done = terminal_now or time_up
|
| 257 |
+
if done:
|
| 258 |
+
self._state.episode_done = True
|
| 259 |
+
|
| 260 |
+
observation = self._build_observation(
|
| 261 |
+
latest_output=output,
|
| 262 |
+
done=done,
|
| 263 |
+
reward=step_reward.reward + terminal_reward_value,
|
| 264 |
+
step_breakdown=step_reward.breakdown.components,
|
| 265 |
+
rule_violations=[
|
| 266 |
+
*(v.value for v in rule_result.violations),
|
| 267 |
+
*(v.value for v in rule_result.soft_violations),
|
| 268 |
+
],
|
| 269 |
+
)
|
| 270 |
+
return observation
|
| 271 |
+
|
| 272 |
+
# ── Helpers ────────────────────────────────────────────────────────
|
| 273 |
+
|
| 274 |
+
def _claim_from_action(self, action: ExperimentAction) -> DiscoveryClaim:
|
| 275 |
+
raw = action.parameters.get("claim") or {}
|
| 276 |
+
try:
|
| 277 |
+
return DiscoveryClaim(**raw)
|
| 278 |
+
except Exception as exc: # pragma: no cover - defensive
|
| 279 |
+
logger.warning("Malformed claim, defaulting: %s", exc)
|
| 280 |
+
return DiscoveryClaim()
|
| 281 |
+
|
| 282 |
+
def _build_terminal_observation(self, reason: str) -> CollisionObservation:
|
| 283 |
+
obs = self._build_observation(
|
| 284 |
+
latest_output=IntermediateOutput(
|
| 285 |
+
output_type=OutputType.FAILURE_REPORT,
|
| 286 |
+
step_index=self._state.step_count,
|
| 287 |
+
success=False,
|
| 288 |
+
summary=reason,
|
| 289 |
+
),
|
| 290 |
+
done=True,
|
| 291 |
+
reward=0.0,
|
| 292 |
+
step_breakdown={},
|
| 293 |
+
rule_violations=["episode_terminated"],
|
| 294 |
+
)
|
| 295 |
+
return obs
|
| 296 |
+
|
| 297 |
+
def _build_observation(
|
| 298 |
+
self,
|
| 299 |
+
*,
|
| 300 |
+
latest_output: Optional[IntermediateOutput],
|
| 301 |
+
done: bool,
|
| 302 |
+
reward: float,
|
| 303 |
+
step_breakdown: dict,
|
| 304 |
+
rule_violations: list,
|
| 305 |
+
) -> CollisionObservation:
|
| 306 |
+
assert self._latent is not None
|
| 307 |
+
assert self._task is not None
|
| 308 |
+
|
| 309 |
+
res = self._latent.resources
|
| 310 |
+
usage = ResourceUsage(
|
| 311 |
+
budget_used_musd=res.budget_used_musd,
|
| 312 |
+
budget_remaining_musd=res.budget_remaining,
|
| 313 |
+
luminosity_used_fb=res.luminosity_used_fb,
|
| 314 |
+
luminosity_remaining_fb=res.luminosity_remaining,
|
| 315 |
+
time_used_days=res.time_used_days,
|
| 316 |
+
time_remaining_days=res.time_remaining,
|
| 317 |
+
compute_hours_used=res.compute_hours_used,
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
obs = CollisionObservation(
|
| 321 |
+
done=done,
|
| 322 |
+
reward=float(reward),
|
| 323 |
+
task=self._task,
|
| 324 |
+
step_index=self._state.step_count,
|
| 325 |
+
pipeline_history=list(self._history),
|
| 326 |
+
available_channels=self._task.available_channels,
|
| 327 |
+
available_triggers=self._task.available_triggers,
|
| 328 |
+
available_tools=self._task.available_tools,
|
| 329 |
+
resource_usage=usage,
|
| 330 |
+
latest_output=latest_output,
|
| 331 |
+
all_outputs=list(self._all_outputs),
|
| 332 |
+
candidate_masses_gev=list(self._latent.candidate_masses_gev),
|
| 333 |
+
candidate_significances=list(self._latent.candidate_significances),
|
| 334 |
+
selected_channel=self._latent.selected_channel,
|
| 335 |
+
selected_beam_energy=self._latent.selected_beam_energy,
|
| 336 |
+
cumulative_significance=float(
|
| 337 |
+
self._latent.progress.best_significance_sigma or 0.0
|
| 338 |
+
),
|
| 339 |
+
uncertainty_summary={
|
| 340 |
+
"energy_scale_unc_gev": self._latent.detector.energy_scale_uncertainty,
|
| 341 |
+
"luminosity_unc": self._latent.detector.luminosity_uncertainty,
|
| 342 |
+
"resolution_gev": self._latent.detector.detector_resolution_gev,
|
| 343 |
+
},
|
| 344 |
+
rule_violations=rule_violations,
|
| 345 |
+
step_reward_breakdown=dict(step_breakdown),
|
| 346 |
+
)
|
| 347 |
+
return obs
|
| 348 |
+
|
| 349 |
+
# ── Convenience for diagnostics ────────────────────────────────────
|
| 350 |
+
|
| 351 |
+
def hidden_truth(self) -> Optional[dict]:
|
| 352 |
+
"""Reveal the hidden particle (debug / evaluation only)."""
|
| 353 |
+
if self._latent is None:
|
| 354 |
+
return None
|
| 355 |
+
return self._latent.particle.model_dump()
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
__all__ = [
|
| 359 |
+
"CernState",
|
| 360 |
+
"CERNCollisionEnvironment",
|
| 361 |
+
"AGENT_ENVIRONMENT_RULES",
|
| 362 |
+
"build_agent_system_prompt",
|
| 363 |
+
]
|
server/requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv-core[core]>=0.2.0
|
| 2 |
+
numpy>=1.24.0
|
| 3 |
+
scipy>=1.10.0
|
| 4 |
+
pydantic>=2.0.0
|
| 5 |
+
fastapi>=0.110.0
|
| 6 |
+
uvicorn>=0.27.0
|
server/rewards/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Reward components for CERNenv."""
|
| 2 |
+
|
| 3 |
+
from .reward_function import (
|
| 4 |
+
RewardBreakdown,
|
| 5 |
+
RewardWeights,
|
| 6 |
+
StepReward,
|
| 7 |
+
TerminalReward,
|
| 8 |
+
compute_step_reward,
|
| 9 |
+
compute_terminal_reward,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
"RewardBreakdown",
|
| 14 |
+
"RewardWeights",
|
| 15 |
+
"StepReward",
|
| 16 |
+
"TerminalReward",
|
| 17 |
+
"compute_step_reward",
|
| 18 |
+
"compute_terminal_reward",
|
| 19 |
+
]
|
server/rewards/reward_function.py
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Decomposed reward function.
|
| 2 |
+
|
| 3 |
+
Two stages:
|
| 4 |
+
1. **Per-step reward** ``compute_step_reward``: shapes behaviour with small
|
| 5 |
+
incentives (progress, evidence quality, valid prerequisites) and
|
| 6 |
+
penalties (rule violations, repeated work, wasted resources).
|
| 7 |
+
2. **Terminal reward** ``compute_terminal_reward``: graded only when the
|
| 8 |
+
agent submits a discovery claim or runs out of resources. Compares the
|
| 9 |
+
submitted claim against the hidden ``LatentParticle`` truth.
|
| 10 |
+
|
| 11 |
+
The terminal reward is intentionally dominant so the policy must care about
|
| 12 |
+
the *correct* discovery, not just looking busy.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from typing import Dict, List, Optional
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
|
| 22 |
+
from models import (
|
| 23 |
+
ActionType,
|
| 24 |
+
DiscoveryClaim,
|
| 25 |
+
ExperimentAction,
|
| 26 |
+
IntermediateOutput,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
from server.rules.engine import RuleResult, ViolationCode
|
| 30 |
+
from server.simulator.latent_state import FullLatentState
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# ── Configuration ────────────────────────────────────────────────────────
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
|
| 37 |
+
class RewardWeights:
|
| 38 |
+
# ── per-step shaping ────────────────────────────────────────
|
| 39 |
+
valid_action: float = 0.05
|
| 40 |
+
progress_milestone: float = 0.25
|
| 41 |
+
evidence_quality: float = 0.20
|
| 42 |
+
tool_fit: float = 0.10
|
| 43 |
+
soft_violation: float = -0.05
|
| 44 |
+
hard_violation: float = -0.50
|
| 45 |
+
redundancy: float = -0.10
|
| 46 |
+
resource_overspend: float = -0.30
|
| 47 |
+
failure: float = -0.30
|
| 48 |
+
|
| 49 |
+
# ── terminal grading ────────────────────────────────────────
|
| 50 |
+
terminal_scale: float = 5.0 # multiplied with the convex sum below
|
| 51 |
+
|
| 52 |
+
mass_calibration: float = 0.30
|
| 53 |
+
significance_quality: float = 0.20
|
| 54 |
+
channel_correctness: float = 0.20
|
| 55 |
+
spin_correctness: float = 0.10
|
| 56 |
+
width_calibration: float = 0.05
|
| 57 |
+
confidence_calibration: float = 0.10
|
| 58 |
+
efficiency_bonus: float = 0.05
|
| 59 |
+
|
| 60 |
+
overconfident_wrong_penalty: float = 4.0 # subtracted from terminal
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# ── Outputs ──────────────────────────────────────────────────────────────
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@dataclass
|
| 67 |
+
class RewardBreakdown:
|
| 68 |
+
components: Dict[str, float] = field(default_factory=dict)
|
| 69 |
+
total: float = 0.0
|
| 70 |
+
|
| 71 |
+
def add(self, key: str, value: float) -> None:
|
| 72 |
+
self.components[key] = self.components.get(key, 0.0) + value
|
| 73 |
+
self.total += value
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@dataclass
|
| 77 |
+
class StepReward:
|
| 78 |
+
reward: float
|
| 79 |
+
breakdown: RewardBreakdown
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
@dataclass
|
| 83 |
+
class TerminalReward:
|
| 84 |
+
reward: float
|
| 85 |
+
breakdown: RewardBreakdown
|
| 86 |
+
discovered: bool
|
| 87 |
+
correct_mass: bool
|
| 88 |
+
correct_channel: bool
|
| 89 |
+
correct_spin: bool
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# ── Per-step ─────────────────────────────────────────────────────────────
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
_PROGRESS_FLAGS = [
|
| 96 |
+
"beam_configured",
|
| 97 |
+
"luminosity_allocated",
|
| 98 |
+
"trigger_set",
|
| 99 |
+
"collisions_collected",
|
| 100 |
+
"channel_selected",
|
| 101 |
+
"tracks_reconstructed",
|
| 102 |
+
"detector_calibrated",
|
| 103 |
+
"invariant_mass_built",
|
| 104 |
+
"background_subtracted",
|
| 105 |
+
"resonance_fitted",
|
| 106 |
+
"significance_estimated",
|
| 107 |
+
]
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def _milestone_progress(state_before: FullLatentState, state_after: FullLatentState) -> int:
|
| 111 |
+
"""Number of new progress milestones unlocked this step."""
|
| 112 |
+
delta = 0
|
| 113 |
+
for flag in _PROGRESS_FLAGS:
|
| 114 |
+
was = getattr(state_before.progress, flag)
|
| 115 |
+
now = getattr(state_after.progress, flag)
|
| 116 |
+
if now and not was:
|
| 117 |
+
delta += 1
|
| 118 |
+
return delta
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def compute_step_reward(
|
| 122 |
+
*,
|
| 123 |
+
action: ExperimentAction,
|
| 124 |
+
output: IntermediateOutput,
|
| 125 |
+
state_before: FullLatentState,
|
| 126 |
+
state_after: FullLatentState,
|
| 127 |
+
rule_result: RuleResult,
|
| 128 |
+
weights: RewardWeights = RewardWeights(),
|
| 129 |
+
) -> StepReward:
|
| 130 |
+
breakdown = RewardBreakdown()
|
| 131 |
+
|
| 132 |
+
if rule_result.allowed and output.success:
|
| 133 |
+
breakdown.add("valid_action", weights.valid_action)
|
| 134 |
+
if not output.success:
|
| 135 |
+
breakdown.add("failure", weights.failure)
|
| 136 |
+
|
| 137 |
+
# progress
|
| 138 |
+
new_milestones = _milestone_progress(state_before, state_after)
|
| 139 |
+
if new_milestones > 0:
|
| 140 |
+
breakdown.add("progress", weights.progress_milestone * new_milestones)
|
| 141 |
+
|
| 142 |
+
# evidence quality
|
| 143 |
+
if output.success:
|
| 144 |
+
breakdown.add("evidence_quality", weights.evidence_quality * float(output.quality_score))
|
| 145 |
+
|
| 146 |
+
# tool fit (named method exists in the recommended toolset)
|
| 147 |
+
if action.method:
|
| 148 |
+
breakdown.add("tool_fit", weights.tool_fit * 0.5)
|
| 149 |
+
|
| 150 |
+
# rule penalties
|
| 151 |
+
if rule_result.violations:
|
| 152 |
+
breakdown.add("hard_violation", weights.hard_violation * len(rule_result.violations))
|
| 153 |
+
if rule_result.soft_violations:
|
| 154 |
+
soft_redundant = sum(1 for v in rule_result.soft_violations if v == ViolationCode.REDUNDANT)
|
| 155 |
+
soft_other = len(rule_result.soft_violations) - soft_redundant
|
| 156 |
+
if soft_redundant:
|
| 157 |
+
breakdown.add("redundancy", weights.redundancy * soft_redundant)
|
| 158 |
+
if soft_other:
|
| 159 |
+
breakdown.add("soft_violation", weights.soft_violation * soft_other)
|
| 160 |
+
|
| 161 |
+
# resource overspend
|
| 162 |
+
res = state_after.resources
|
| 163 |
+
if res.budget_used_musd > res.budget_total_musd:
|
| 164 |
+
breakdown.add("budget_overspend", weights.resource_overspend)
|
| 165 |
+
if res.luminosity_used_fb > res.luminosity_total_fb:
|
| 166 |
+
breakdown.add("lumi_overspend", weights.resource_overspend)
|
| 167 |
+
if res.time_used_days > res.time_limit_days:
|
| 168 |
+
breakdown.add("time_overspend", weights.resource_overspend)
|
| 169 |
+
|
| 170 |
+
return StepReward(reward=float(breakdown.total), breakdown=breakdown)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
# ── Terminal grading ─────────────────────────────────────────────────────
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def _mass_score(true_mass: float, claim_mass: Optional[float], unc: Optional[float]) -> float:
|
| 177 |
+
"""1.0 within 1σ, smoothly decays to 0 by 5 GeV (or 5σ, whichever larger)."""
|
| 178 |
+
if claim_mass is None or true_mass <= 0:
|
| 179 |
+
return 0.0
|
| 180 |
+
err = abs(claim_mass - true_mass)
|
| 181 |
+
# Tolerance: max(1.0 GeV, 1% of true mass, claimed unc)
|
| 182 |
+
tol = max(1.0, 0.01 * true_mass)
|
| 183 |
+
if unc is not None and unc > 0:
|
| 184 |
+
tol = max(tol, float(unc))
|
| 185 |
+
if err <= tol:
|
| 186 |
+
return 1.0
|
| 187 |
+
if err >= 5 * tol:
|
| 188 |
+
return 0.0
|
| 189 |
+
return float(np.clip(1.0 - (err - tol) / (4 * tol), 0.0, 1.0))
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def _significance_score(state: FullLatentState, claim_sigma: Optional[float]) -> float:
|
| 193 |
+
"""High score when claimed σ matches measured σ and is ≥ 5."""
|
| 194 |
+
measured = state.progress.best_significance_sigma or 0.0
|
| 195 |
+
if claim_sigma is None:
|
| 196 |
+
return 0.0
|
| 197 |
+
over_claim = max(0.0, claim_sigma - measured)
|
| 198 |
+
base = float(np.clip(measured / 5.0, 0.0, 1.0))
|
| 199 |
+
penalty = float(np.clip(over_claim / 3.0, 0.0, 1.0))
|
| 200 |
+
return float(np.clip(base - 0.5 * penalty, 0.0, 1.0))
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def _confidence_calibration(claim_conf: float, mass_score: float, channel_correct: bool) -> float:
|
| 204 |
+
"""Reward agents whose confidence tracks their actual accuracy."""
|
| 205 |
+
actual = 0.5 * mass_score + 0.5 * (1.0 if channel_correct else 0.0)
|
| 206 |
+
err = abs(actual - claim_conf)
|
| 207 |
+
return float(np.clip(1.0 - err, 0.0, 1.0))
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def _efficiency_bonus(state: FullLatentState) -> float:
|
| 211 |
+
"""Reward leftover budget (encourages succinct experiments)."""
|
| 212 |
+
res = state.resources
|
| 213 |
+
score = 0.0
|
| 214 |
+
score += np.clip(res.budget_remaining / res.budget_total_musd, 0.0, 1.0)
|
| 215 |
+
score += np.clip(res.luminosity_remaining / res.luminosity_total_fb, 0.0, 1.0)
|
| 216 |
+
score += np.clip(res.time_remaining / res.time_limit_days, 0.0, 1.0)
|
| 217 |
+
return float(score / 3.0)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def compute_terminal_reward(
|
| 221 |
+
*,
|
| 222 |
+
state: FullLatentState,
|
| 223 |
+
claim: DiscoveryClaim,
|
| 224 |
+
weights: RewardWeights = RewardWeights(),
|
| 225 |
+
) -> TerminalReward:
|
| 226 |
+
breakdown = RewardBreakdown()
|
| 227 |
+
truth = state.particle
|
| 228 |
+
|
| 229 |
+
mass_score = _mass_score(truth.mass_gev, claim.mass_estimate_gev, claim.mass_uncertainty_gev)
|
| 230 |
+
breakdown.add("mass_calibration", weights.mass_calibration * mass_score)
|
| 231 |
+
|
| 232 |
+
sig_score = _significance_score(state, claim.significance_sigma)
|
| 233 |
+
breakdown.add("significance_quality", weights.significance_quality * sig_score)
|
| 234 |
+
|
| 235 |
+
channel_ok = claim.decay_channel == truth.primary_channel
|
| 236 |
+
breakdown.add("channel_correctness", weights.channel_correctness * (1.0 if channel_ok else 0.0))
|
| 237 |
+
|
| 238 |
+
spin_ok = claim.spin_hypothesis is not None and claim.spin_hypothesis == truth.spin
|
| 239 |
+
breakdown.add("spin_correctness", weights.spin_correctness * (1.0 if spin_ok else 0.0))
|
| 240 |
+
|
| 241 |
+
width_score = 0.0
|
| 242 |
+
if claim.width_estimate_gev is not None and truth.width_gev > 0:
|
| 243 |
+
rel = abs(claim.width_estimate_gev - truth.width_gev) / max(truth.width_gev, 1e-3)
|
| 244 |
+
width_score = float(np.clip(1.0 - rel, 0.0, 1.0))
|
| 245 |
+
breakdown.add("width_calibration", weights.width_calibration * width_score)
|
| 246 |
+
|
| 247 |
+
conf_score = _confidence_calibration(claim.confidence, mass_score, channel_ok)
|
| 248 |
+
breakdown.add("confidence_calibration", weights.confidence_calibration * conf_score)
|
| 249 |
+
|
| 250 |
+
eff_score = _efficiency_bonus(state)
|
| 251 |
+
breakdown.add("efficiency_bonus", weights.efficiency_bonus * eff_score)
|
| 252 |
+
|
| 253 |
+
discovered = (
|
| 254 |
+
mass_score >= 0.5
|
| 255 |
+
and channel_ok
|
| 256 |
+
and (claim.significance_sigma or 0.0) >= 4.5
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
raw = breakdown.total * weights.terminal_scale
|
| 260 |
+
|
| 261 |
+
# Overconfident-wrong penalty: high confidence but wrong channel & far mass
|
| 262 |
+
if claim.confidence >= 0.8 and (mass_score < 0.2 or not channel_ok):
|
| 263 |
+
raw -= weights.overconfident_wrong_penalty
|
| 264 |
+
breakdown.add("overconfident_wrong", -weights.overconfident_wrong_penalty)
|
| 265 |
+
|
| 266 |
+
return TerminalReward(
|
| 267 |
+
reward=float(raw),
|
| 268 |
+
breakdown=breakdown,
|
| 269 |
+
discovered=discovered,
|
| 270 |
+
correct_mass=mass_score >= 0.5,
|
| 271 |
+
correct_channel=channel_ok,
|
| 272 |
+
correct_spin=spin_ok,
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
__all__ = [
|
| 277 |
+
"RewardBreakdown",
|
| 278 |
+
"RewardWeights",
|
| 279 |
+
"StepReward",
|
| 280 |
+
"TerminalReward",
|
| 281 |
+
"compute_step_reward",
|
| 282 |
+
"compute_terminal_reward",
|
| 283 |
+
]
|
server/rules/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Rules engine: prerequisites, resources, redundancy, claim validity."""
|
| 2 |
+
|
| 3 |
+
from .engine import RuleResult, RulesEngine, ViolationCode
|
| 4 |
+
|
| 5 |
+
__all__ = ["RuleResult", "RulesEngine", "ViolationCode"]
|
server/rules/engine.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""RulesEngine for CERNenv.
|
| 2 |
+
|
| 3 |
+
Validates an incoming ``ExperimentAction`` against the current latent state
|
| 4 |
+
*before* it is executed. Rule violations are reported back as warnings on the
|
| 5 |
+
observation and feed into the per-step penalty in the reward function.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from dataclasses import dataclass, field
|
| 11 |
+
from enum import Enum
|
| 12 |
+
from typing import List, Optional
|
| 13 |
+
|
| 14 |
+
from models import (
|
| 15 |
+
ActionType,
|
| 16 |
+
DetectorChannel,
|
| 17 |
+
ExperimentAction,
|
| 18 |
+
TriggerType,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
from server.simulator.latent_state import FullLatentState
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class ViolationCode(str, Enum):
|
| 25 |
+
PREREQ_MISSING = "prerequisite_missing"
|
| 26 |
+
BUDGET_EXHAUSTED = "budget_exhausted"
|
| 27 |
+
LUMI_EXHAUSTED = "luminosity_exhausted"
|
| 28 |
+
TIME_EXHAUSTED = "time_exhausted"
|
| 29 |
+
REDUNDANT = "redundant"
|
| 30 |
+
INVALID_PARAMS = "invalid_parameters"
|
| 31 |
+
INVALID_CLAIM = "invalid_claim"
|
| 32 |
+
CHANNEL_MISMATCH = "channel_mismatch"
|
| 33 |
+
OUT_OF_WINDOW = "out_of_search_window"
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
|
| 37 |
+
class RuleResult:
|
| 38 |
+
allowed: bool
|
| 39 |
+
violations: List[ViolationCode] = field(default_factory=list)
|
| 40 |
+
messages: List[str] = field(default_factory=list)
|
| 41 |
+
soft_violations: List[ViolationCode] = field(default_factory=list)
|
| 42 |
+
|
| 43 |
+
def add(self, code: ViolationCode, msg: str, soft: bool = False) -> None:
|
| 44 |
+
self.messages.append(msg)
|
| 45 |
+
if soft:
|
| 46 |
+
self.soft_violations.append(code)
|
| 47 |
+
else:
|
| 48 |
+
self.violations.append(code)
|
| 49 |
+
self.allowed = False
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class RulesEngine:
|
| 53 |
+
"""Stateless validator (state is passed in)."""
|
| 54 |
+
|
| 55 |
+
def __init__(
|
| 56 |
+
self,
|
| 57 |
+
mass_search_window_gev: tuple[float, float] = (50.0, 1000.0),
|
| 58 |
+
) -> None:
|
| 59 |
+
self.mass_search_window_gev = mass_search_window_gev
|
| 60 |
+
|
| 61 |
+
# ── Public API ─────────────────────────────────────────────────────
|
| 62 |
+
|
| 63 |
+
def validate(
|
| 64 |
+
self,
|
| 65 |
+
action: ExperimentAction,
|
| 66 |
+
state: FullLatentState,
|
| 67 |
+
) -> RuleResult:
|
| 68 |
+
result = RuleResult(allowed=True)
|
| 69 |
+
|
| 70 |
+
# ── resource gating (hard) ────────────────────────────────
|
| 71 |
+
if state.resources.budget_exhausted:
|
| 72 |
+
result.add(ViolationCode.BUDGET_EXHAUSTED, "Budget fully spent.")
|
| 73 |
+
if state.resources.time_exhausted:
|
| 74 |
+
result.add(ViolationCode.TIME_EXHAUSTED, "Time budget exhausted.")
|
| 75 |
+
# luminosity exhaustion only blocks DAQ-style actions
|
| 76 |
+
if (
|
| 77 |
+
state.resources.luminosity_exhausted
|
| 78 |
+
and action.action_type in {
|
| 79 |
+
ActionType.ALLOCATE_LUMINOSITY,
|
| 80 |
+
ActionType.COLLECT_COLLISIONS,
|
| 81 |
+
}
|
| 82 |
+
):
|
| 83 |
+
result.add(ViolationCode.LUMI_EXHAUSTED, "Integrated luminosity budget spent.")
|
| 84 |
+
|
| 85 |
+
if not result.allowed:
|
| 86 |
+
return result
|
| 87 |
+
|
| 88 |
+
a = action.action_type
|
| 89 |
+
prog = state.progress
|
| 90 |
+
|
| 91 |
+
# ── prerequisites ──────────────────────────────────────────
|
| 92 |
+
if a == ActionType.COLLECT_COLLISIONS:
|
| 93 |
+
if not prog.beam_configured:
|
| 94 |
+
result.add(ViolationCode.PREREQ_MISSING, "Configure the beam first.")
|
| 95 |
+
if not prog.luminosity_allocated:
|
| 96 |
+
result.add(ViolationCode.PREREQ_MISSING, "Allocate luminosity first.")
|
| 97 |
+
if not prog.trigger_set:
|
| 98 |
+
result.add(ViolationCode.PREREQ_MISSING, "Set a trigger first.")
|
| 99 |
+
if not state.selected_channel:
|
| 100 |
+
result.add(ViolationCode.PREREQ_MISSING, "Select a decay channel first.")
|
| 101 |
+
|
| 102 |
+
elif a == ActionType.BUILD_INVARIANT_MASS:
|
| 103 |
+
if not prog.collisions_collected:
|
| 104 |
+
result.add(ViolationCode.PREREQ_MISSING, "Collect collisions before building histograms.")
|
| 105 |
+
if not prog.tracks_reconstructed:
|
| 106 |
+
result.add(ViolationCode.PREREQ_MISSING, "Reconstruct tracks before building histograms.")
|
| 107 |
+
|
| 108 |
+
elif a == ActionType.SUBTRACT_BACKGROUND:
|
| 109 |
+
if not prog.invariant_mass_built:
|
| 110 |
+
result.add(ViolationCode.PREREQ_MISSING, "Build invariant-mass histogram first.")
|
| 111 |
+
|
| 112 |
+
elif a == ActionType.FIT_RESONANCE:
|
| 113 |
+
if not prog.invariant_mass_built:
|
| 114 |
+
result.add(ViolationCode.PREREQ_MISSING, "Build the histogram before fitting.")
|
| 115 |
+
|
| 116 |
+
elif a == ActionType.MEASURE_ANGULAR:
|
| 117 |
+
if not (prog.resonance_fitted or prog.bump_scanned):
|
| 118 |
+
result.add(
|
| 119 |
+
ViolationCode.PREREQ_MISSING,
|
| 120 |
+
"Identify a peak (fit or bump scan) before angular analysis.",
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
elif a == ActionType.ESTIMATE_SIGNIFICANCE:
|
| 124 |
+
if not prog.collisions_collected:
|
| 125 |
+
result.add(ViolationCode.PREREQ_MISSING, "Collect data before significance estimation.")
|
| 126 |
+
|
| 127 |
+
elif a == ActionType.SUBMIT_DISCOVERY_CLAIM:
|
| 128 |
+
if not prog.resonance_fitted and not prog.bump_scanned:
|
| 129 |
+
result.add(ViolationCode.PREREQ_MISSING, "No fitted resonance or bump scan; cannot claim a discovery.")
|
| 130 |
+
if not prog.significance_estimated:
|
| 131 |
+
result.add(ViolationCode.PREREQ_MISSING, "Estimate significance before submitting a claim.")
|
| 132 |
+
|
| 133 |
+
# ── parameter & search-window validation (soft) ────────────
|
| 134 |
+
if a == ActionType.SELECT_CHANNEL:
|
| 135 |
+
channel = action.parameters.get("channel")
|
| 136 |
+
if channel:
|
| 137 |
+
try:
|
| 138 |
+
DetectorChannel(channel)
|
| 139 |
+
except ValueError:
|
| 140 |
+
result.add(ViolationCode.INVALID_PARAMS, f"Unknown channel '{channel}'.", soft=True)
|
| 141 |
+
|
| 142 |
+
if a == ActionType.SET_TRIGGER:
|
| 143 |
+
trig = action.parameters.get("trigger")
|
| 144 |
+
if trig:
|
| 145 |
+
try:
|
| 146 |
+
TriggerType(trig)
|
| 147 |
+
except ValueError:
|
| 148 |
+
result.add(ViolationCode.INVALID_PARAMS, f"Unknown trigger '{trig}'.", soft=True)
|
| 149 |
+
|
| 150 |
+
if a == ActionType.BUILD_INVARIANT_MASS:
|
| 151 |
+
window = action.parameters.get("mass_window_gev")
|
| 152 |
+
if window and len(window) == 2:
|
| 153 |
+
lo, hi = float(window[0]), float(window[1])
|
| 154 |
+
if hi <= lo:
|
| 155 |
+
result.add(
|
| 156 |
+
ViolationCode.INVALID_PARAMS,
|
| 157 |
+
f"Mass window [{lo}, {hi}] is non-positive.",
|
| 158 |
+
soft=True,
|
| 159 |
+
)
|
| 160 |
+
if lo > self.mass_search_window_gev[1] or hi < self.mass_search_window_gev[0]:
|
| 161 |
+
result.add(
|
| 162 |
+
ViolationCode.OUT_OF_WINDOW,
|
| 163 |
+
f"Histogram window [{lo}, {hi}] is outside the task search window "
|
| 164 |
+
f"{self.mass_search_window_gev}.",
|
| 165 |
+
soft=True,
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
# ── redundancy (soft) ─────────────────────────────────────
|
| 169 |
+
if a == ActionType.CONFIGURE_BEAM and prog.beam_configured:
|
| 170 |
+
result.add(ViolationCode.REDUNDANT, "Beam already configured; reconfiguring wastes budget.", soft=True)
|
| 171 |
+
if a == ActionType.SELECT_CHANNEL and prog.channel_selected:
|
| 172 |
+
result.add(ViolationCode.REDUNDANT, "Channel already selected.", soft=True)
|
| 173 |
+
if a == ActionType.RECONSTRUCT_TRACKS and prog.tracks_reconstructed:
|
| 174 |
+
result.add(ViolationCode.REDUNDANT, "Tracks already reconstructed.", soft=True)
|
| 175 |
+
if a == ActionType.CALIBRATE_DETECTOR and prog.detector_calibrated:
|
| 176 |
+
result.add(ViolationCode.REDUNDANT, "Detector already calibrated.", soft=True)
|
| 177 |
+
|
| 178 |
+
# ── claim sanity ──────────────────────────────────────────
|
| 179 |
+
if a == ActionType.SUBMIT_DISCOVERY_CLAIM:
|
| 180 |
+
claim = action.parameters.get("claim") or {}
|
| 181 |
+
mass = claim.get("mass_estimate_gev")
|
| 182 |
+
if mass is None:
|
| 183 |
+
result.add(ViolationCode.INVALID_CLAIM, "Claim missing mass estimate.")
|
| 184 |
+
else:
|
| 185 |
+
try:
|
| 186 |
+
m = float(mass)
|
| 187 |
+
except Exception:
|
| 188 |
+
result.add(ViolationCode.INVALID_CLAIM, "Claim mass is not numeric.")
|
| 189 |
+
else:
|
| 190 |
+
lo, hi = self.mass_search_window_gev
|
| 191 |
+
if not (lo <= m <= hi):
|
| 192 |
+
result.add(
|
| 193 |
+
ViolationCode.INVALID_CLAIM,
|
| 194 |
+
f"Claim mass {m} outside search window [{lo}, {hi}].",
|
| 195 |
+
soft=True,
|
| 196 |
+
)
|
| 197 |
+
if claim.get("significance_sigma") is None:
|
| 198 |
+
result.add(ViolationCode.INVALID_CLAIM, "Claim missing significance.", soft=True)
|
| 199 |
+
|
| 200 |
+
return result
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
__all__ = ["RuleResult", "RulesEngine", "ViolationCode"]
|
server/simulator/__init__.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Simulator: latent particle truth, noise model, output generation."""
|
| 2 |
+
|
| 3 |
+
from .latent_state import (
|
| 4 |
+
DetectorState,
|
| 5 |
+
ExperimentProgress,
|
| 6 |
+
FullLatentState,
|
| 7 |
+
LatentParticle,
|
| 8 |
+
ResourceState,
|
| 9 |
+
)
|
| 10 |
+
from .noise import NoiseModel
|
| 11 |
+
from .output_generator import OutputGenerator
|
| 12 |
+
from .transition import (
|
| 13 |
+
ACTION_COSTS,
|
| 14 |
+
TransitionEngine,
|
| 15 |
+
TransitionResult,
|
| 16 |
+
compute_action_cost,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
__all__ = [
|
| 20 |
+
"ACTION_COSTS",
|
| 21 |
+
"DetectorState",
|
| 22 |
+
"ExperimentProgress",
|
| 23 |
+
"FullLatentState",
|
| 24 |
+
"LatentParticle",
|
| 25 |
+
"NoiseModel",
|
| 26 |
+
"OutputGenerator",
|
| 27 |
+
"ResourceState",
|
| 28 |
+
"TransitionEngine",
|
| 29 |
+
"TransitionResult",
|
| 30 |
+
"compute_action_cost",
|
| 31 |
+
]
|
server/simulator/latent_state.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Latent (hidden) state of the LHC simulator.
|
| 2 |
+
|
| 3 |
+
The agent never sees these structures. They define the ground-truth particle
|
| 4 |
+
properties, detector imperfections, experiment progress flags, and the live
|
| 5 |
+
resource budget.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from typing import Dict, List, Optional
|
| 11 |
+
|
| 12 |
+
from pydantic import BaseModel, Field
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# ── Particle truth ────────────────────────────────────────────────────────
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class LatentParticle(BaseModel):
|
| 19 |
+
"""The hidden mystery particle that the agent must discover.
|
| 20 |
+
|
| 21 |
+
Defines the true mass, width, decay branching ratios, spin, parity,
|
| 22 |
+
production cross-section, and dominant decay channel. The agent has to
|
| 23 |
+
recover these values from noisy observations.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
name: str = "X"
|
| 27 |
+
mass_gev: float = 125.0
|
| 28 |
+
width_gev: float = 0.004
|
| 29 |
+
spin: int = 0 # 0, 1, or 2
|
| 30 |
+
parity: str = "+" # "+" or "-"
|
| 31 |
+
cross_section_fb: float = 50.0 # signal cross-section in femtobarns
|
| 32 |
+
decay_branching: Dict[str, float] = Field(
|
| 33 |
+
default_factory=lambda: {
|
| 34 |
+
"diphoton": 0.0023,
|
| 35 |
+
"dilepton_ee": 0.00003,
|
| 36 |
+
"dilepton_mumu": 0.00022,
|
| 37 |
+
"four_lepton": 0.000125,
|
| 38 |
+
"bb": 0.58,
|
| 39 |
+
"dijet": 0.30,
|
| 40 |
+
},
|
| 41 |
+
description="Branching ratio (BR) per decay channel, sums to ~1.",
|
| 42 |
+
)
|
| 43 |
+
primary_channel: str = "diphoton"
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# ── Detector & accelerator state ─────────────────────────────────────────
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class DetectorState(BaseModel):
|
| 50 |
+
"""Hidden detector and accelerator parameters that shape noise.
|
| 51 |
+
|
| 52 |
+
These influence resolution, trigger efficiency, pileup, and systematic
|
| 53 |
+
uncertainties applied to every observation.
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
detector_resolution_gev: float = 1.5 # absolute mass resolution σ_m
|
| 57 |
+
pileup_mu: float = 30.0 # average pileup interactions per crossing
|
| 58 |
+
trigger_efficiency: float = 0.85
|
| 59 |
+
luminosity_uncertainty: float = 0.025 # 2.5% relative uncertainty
|
| 60 |
+
energy_scale_offset: float = 0.0 # systematic shift in GeV
|
| 61 |
+
energy_scale_uncertainty: float = 0.3 # σ on the scale
|
| 62 |
+
background_shape_alpha: float = -2.5 # exponent of background ~ 1/m^|α|
|
| 63 |
+
qcd_background_strength: float = 1.0 # scale factor for hadronic background
|
| 64 |
+
detector_calibrated: bool = False
|
| 65 |
+
tracker_aligned: bool = False
|
| 66 |
+
# Channel-dependent reconstruction efficiency
|
| 67 |
+
channel_efficiency: Dict[str, float] = Field(
|
| 68 |
+
default_factory=lambda: {
|
| 69 |
+
"diphoton": 0.45,
|
| 70 |
+
"dilepton_ee": 0.55,
|
| 71 |
+
"dilepton_mumu": 0.70,
|
| 72 |
+
"four_lepton": 0.40,
|
| 73 |
+
"dijet": 0.80,
|
| 74 |
+
"bb": 0.50,
|
| 75 |
+
}
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# ── Experiment progress flags ────────────────────────────────────────────
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class ExperimentProgress(BaseModel):
|
| 83 |
+
"""Boolean milestones used by rules and reward shaping."""
|
| 84 |
+
|
| 85 |
+
beam_configured: bool = False
|
| 86 |
+
luminosity_allocated: bool = False
|
| 87 |
+
trigger_set: bool = False
|
| 88 |
+
collisions_collected: bool = False
|
| 89 |
+
detector_calibrated: bool = False
|
| 90 |
+
tracks_reconstructed: bool = False
|
| 91 |
+
channel_selected: bool = False
|
| 92 |
+
invariant_mass_built: bool = False
|
| 93 |
+
background_subtracted: bool = False
|
| 94 |
+
resonance_fitted: bool = False
|
| 95 |
+
bump_scanned: bool = False
|
| 96 |
+
angular_measured: bool = False
|
| 97 |
+
significance_estimated: bool = False
|
| 98 |
+
systematics_requested: bool = False
|
| 99 |
+
theory_review_requested: bool = False
|
| 100 |
+
claim_submitted: bool = False
|
| 101 |
+
|
| 102 |
+
n_events_collected: int = 0
|
| 103 |
+
n_signal_candidates: int = 0
|
| 104 |
+
n_background_estimate: int = 0
|
| 105 |
+
best_fit_mass_gev: Optional[float] = None
|
| 106 |
+
best_fit_width_gev: Optional[float] = None
|
| 107 |
+
best_significance_sigma: Optional[float] = None
|
| 108 |
+
best_channel: Optional[str] = None
|
| 109 |
+
best_beam_energy: Optional[str] = None
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# ── Resources ─────────────────────────────────────────────────────────────
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class ResourceState(BaseModel):
|
| 116 |
+
"""Live resource accounting (superset of the agent-visible ResourceUsage)."""
|
| 117 |
+
|
| 118 |
+
budget_total_musd: float = 100.0
|
| 119 |
+
budget_used_musd: float = 0.0
|
| 120 |
+
luminosity_total_fb: float = 300.0
|
| 121 |
+
luminosity_used_fb: float = 0.0
|
| 122 |
+
time_limit_days: float = 365.0
|
| 123 |
+
time_used_days: float = 0.0
|
| 124 |
+
compute_hours_used: float = 0.0
|
| 125 |
+
|
| 126 |
+
@property
|
| 127 |
+
def budget_remaining(self) -> float:
|
| 128 |
+
return max(0.0, self.budget_total_musd - self.budget_used_musd)
|
| 129 |
+
|
| 130 |
+
@property
|
| 131 |
+
def luminosity_remaining(self) -> float:
|
| 132 |
+
return max(0.0, self.luminosity_total_fb - self.luminosity_used_fb)
|
| 133 |
+
|
| 134 |
+
@property
|
| 135 |
+
def time_remaining(self) -> float:
|
| 136 |
+
return max(0.0, self.time_limit_days - self.time_used_days)
|
| 137 |
+
|
| 138 |
+
@property
|
| 139 |
+
def budget_exhausted(self) -> bool:
|
| 140 |
+
return self.budget_remaining <= 0
|
| 141 |
+
|
| 142 |
+
@property
|
| 143 |
+
def luminosity_exhausted(self) -> bool:
|
| 144 |
+
return self.luminosity_remaining <= 0
|
| 145 |
+
|
| 146 |
+
@property
|
| 147 |
+
def time_exhausted(self) -> bool:
|
| 148 |
+
return self.time_remaining <= 0
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
# ── Aggregate hidden state ───────────────────────────────────────────────
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class FullLatentState(BaseModel):
|
| 155 |
+
"""Complete hidden state of the simulated LHC analysis world."""
|
| 156 |
+
|
| 157 |
+
particle: LatentParticle = Field(default_factory=LatentParticle)
|
| 158 |
+
detector: DetectorState = Field(default_factory=DetectorState)
|
| 159 |
+
progress: ExperimentProgress = Field(default_factory=ExperimentProgress)
|
| 160 |
+
resources: ResourceState = Field(default_factory=ResourceState)
|
| 161 |
+
|
| 162 |
+
selected_channel: Optional[str] = None
|
| 163 |
+
selected_beam_energy: Optional[str] = None
|
| 164 |
+
selected_trigger: Optional[str] = None
|
| 165 |
+
|
| 166 |
+
candidate_masses_gev: List[float] = Field(default_factory=list)
|
| 167 |
+
candidate_significances: List[float] = Field(default_factory=list)
|
| 168 |
+
|
| 169 |
+
hidden_failure_conditions: List[str] = Field(default_factory=list)
|
| 170 |
+
rng_seed: int = 42
|
| 171 |
+
step_count: int = 0
|
server/simulator/noise.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Stochastic noise model for the LHC simulator.
|
| 2 |
+
|
| 3 |
+
All randomness is funneled through a single seeded ``numpy.Generator`` so
|
| 4 |
+
episodes are reproducible. The methods are physics-flavoured: Poisson event
|
| 5 |
+
counts, Gaussian-smeared masses, log-normal cross-sections, false discovery
|
| 6 |
+
helpers, and quality degradation.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
from typing import List
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class NoiseModel:
|
| 17 |
+
"""Centralised noise generator for the CERN simulator."""
|
| 18 |
+
|
| 19 |
+
def __init__(self, seed: int = 42):
|
| 20 |
+
self.rng = np.random.default_rng(seed)
|
| 21 |
+
|
| 22 |
+
def reseed(self, seed: int) -> None:
|
| 23 |
+
self.rng = np.random.default_rng(seed)
|
| 24 |
+
|
| 25 |
+
# ── counting / Poisson statistics ─────────────────────────────────
|
| 26 |
+
|
| 27 |
+
def poisson(self, lam: float) -> int:
|
| 28 |
+
return int(self.rng.poisson(max(lam, 0.0)))
|
| 29 |
+
|
| 30 |
+
def signal_yield(
|
| 31 |
+
self,
|
| 32 |
+
cross_section_fb: float,
|
| 33 |
+
luminosity_fb: float,
|
| 34 |
+
branching: float,
|
| 35 |
+
efficiency: float,
|
| 36 |
+
trigger_efficiency: float,
|
| 37 |
+
) -> int:
|
| 38 |
+
"""Expected signal events ~ σ × L × BR × ε_reco × ε_trig + Poisson noise.
|
| 39 |
+
|
| 40 |
+
BR = branching ratio of the decay channel.
|
| 41 |
+
ε_reco = channel reconstruction efficiency.
|
| 42 |
+
ε_trig = trigger acceptance.
|
| 43 |
+
"""
|
| 44 |
+
mu = cross_section_fb * luminosity_fb * branching * efficiency * trigger_efficiency
|
| 45 |
+
return self.poisson(mu)
|
| 46 |
+
|
| 47 |
+
def background_yield(
|
| 48 |
+
self,
|
| 49 |
+
baseline_per_fb: float,
|
| 50 |
+
luminosity_fb: float,
|
| 51 |
+
qcd_strength: float,
|
| 52 |
+
trigger_efficiency: float,
|
| 53 |
+
) -> int:
|
| 54 |
+
"""Expected background events scale linearly with luminosity."""
|
| 55 |
+
mu = baseline_per_fb * luminosity_fb * qcd_strength * trigger_efficiency
|
| 56 |
+
return self.poisson(mu)
|
| 57 |
+
|
| 58 |
+
# ── mass smearing ──────────────────────────────────────────────────
|
| 59 |
+
|
| 60 |
+
def smear_mass(
|
| 61 |
+
self,
|
| 62 |
+
true_mass_gev: float,
|
| 63 |
+
resolution_gev: float,
|
| 64 |
+
scale_offset_gev: float = 0.0,
|
| 65 |
+
) -> float:
|
| 66 |
+
return float(self.rng.normal(true_mass_gev + scale_offset_gev, resolution_gev))
|
| 67 |
+
|
| 68 |
+
def fit_mass_estimate(
|
| 69 |
+
self,
|
| 70 |
+
true_mass_gev: float,
|
| 71 |
+
n_signal: int,
|
| 72 |
+
resolution_gev: float,
|
| 73 |
+
scale_offset_gev: float,
|
| 74 |
+
) -> float:
|
| 75 |
+
"""Fitted mass ≈ true mass + Gaussian error scaling like 1/√N_signal."""
|
| 76 |
+
n_eff = max(n_signal, 1)
|
| 77 |
+
sigma = resolution_gev / np.sqrt(n_eff)
|
| 78 |
+
return float(self.rng.normal(true_mass_gev + scale_offset_gev, sigma))
|
| 79 |
+
|
| 80 |
+
def fit_mass_uncertainty(
|
| 81 |
+
self,
|
| 82 |
+
n_signal: int,
|
| 83 |
+
resolution_gev: float,
|
| 84 |
+
) -> float:
|
| 85 |
+
"""Statistical mass uncertainty from a peak with N_signal events."""
|
| 86 |
+
n_eff = max(n_signal, 1)
|
| 87 |
+
return float(resolution_gev / np.sqrt(n_eff))
|
| 88 |
+
|
| 89 |
+
# ── significance ───────────────────────────────────────────────────
|
| 90 |
+
|
| 91 |
+
def asimov_significance(
|
| 92 |
+
self,
|
| 93 |
+
n_signal: int,
|
| 94 |
+
n_background: int,
|
| 95 |
+
nuisance_inflation: float = 0.0,
|
| 96 |
+
) -> float:
|
| 97 |
+
"""Asymptotic Asimov-style significance Z = √(2[(s+b) ln(1+s/b) - s]).
|
| 98 |
+
|
| 99 |
+
A small nuisance_inflation term in [0,1] shrinks Z to mimic systematic
|
| 100 |
+
penalties when calibration / systematics studies are skipped.
|
| 101 |
+
"""
|
| 102 |
+
if n_background <= 0:
|
| 103 |
+
return 0.0
|
| 104 |
+
s = float(n_signal)
|
| 105 |
+
b = float(n_background)
|
| 106 |
+
if s <= 0:
|
| 107 |
+
return 0.0
|
| 108 |
+
term = (s + b) * np.log(1.0 + s / b) - s
|
| 109 |
+
z = float(np.sqrt(max(2.0 * term, 0.0)))
|
| 110 |
+
return float(z * (1.0 - nuisance_inflation))
|
| 111 |
+
|
| 112 |
+
# ── helpers ─────────────────────────────────────────────────────────
|
| 113 |
+
|
| 114 |
+
def coin_flip(self, p: float) -> bool:
|
| 115 |
+
return bool(self.rng.random() < p)
|
| 116 |
+
|
| 117 |
+
def jitter(self, mean: float, sigma: float) -> float:
|
| 118 |
+
return float(self.rng.normal(mean, sigma))
|
| 119 |
+
|
| 120 |
+
def quality_degradation(self, base_quality: float, factors: List[float]) -> float:
|
| 121 |
+
q = base_quality
|
| 122 |
+
for f in factors:
|
| 123 |
+
q *= f
|
| 124 |
+
return float(np.clip(q + self.rng.normal(0, 0.02), 0.0, 1.0))
|
| 125 |
+
|
| 126 |
+
def sample_qc_metric(
|
| 127 |
+
self, mean: float, std: float, clip_lo: float = 0.0, clip_hi: float = 1.0
|
| 128 |
+
) -> float:
|
| 129 |
+
return float(np.clip(self.rng.normal(mean, std), clip_lo, clip_hi))
|
| 130 |
+
|
| 131 |
+
def histogram(
|
| 132 |
+
self,
|
| 133 |
+
n_signal: int,
|
| 134 |
+
n_background: int,
|
| 135 |
+
true_mass_gev: float,
|
| 136 |
+
resolution_gev: float,
|
| 137 |
+
window_lo_gev: float,
|
| 138 |
+
window_hi_gev: float,
|
| 139 |
+
n_bins: int = 40,
|
| 140 |
+
background_alpha: float = -2.5,
|
| 141 |
+
) -> List[int]:
|
| 142 |
+
"""Generate a coarse invariant-mass histogram.
|
| 143 |
+
|
| 144 |
+
Signal is Gaussian around the (smeared) true mass with width
|
| 145 |
+
=resolution; background is a falling power-law shape.
|
| 146 |
+
"""
|
| 147 |
+
if window_hi_gev <= window_lo_gev:
|
| 148 |
+
return [0] * n_bins
|
| 149 |
+
edges = np.linspace(window_lo_gev, window_hi_gev, n_bins + 1)
|
| 150 |
+
centers = 0.5 * (edges[:-1] + edges[1:])
|
| 151 |
+
|
| 152 |
+
sig_mu = true_mass_gev
|
| 153 |
+
sig_pdf = np.exp(-0.5 * ((centers - sig_mu) / max(resolution_gev, 1e-3)) ** 2)
|
| 154 |
+
sig_pdf /= max(sig_pdf.sum(), 1e-9)
|
| 155 |
+
|
| 156 |
+
bg_pdf = np.power(np.clip(centers, 1.0, None), background_alpha)
|
| 157 |
+
bg_pdf /= max(bg_pdf.sum(), 1e-9)
|
| 158 |
+
|
| 159 |
+
sig_counts = self.rng.multinomial(max(n_signal, 0), sig_pdf)
|
| 160 |
+
bg_counts = self.rng.multinomial(max(n_background, 0), bg_pdf)
|
| 161 |
+
return (sig_counts + bg_counts).astype(int).tolist()
|
server/simulator/output_generator.py
ADDED
|
@@ -0,0 +1,586 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Builds the noisy ``IntermediateOutput`` returned to the agent each step.
|
| 2 |
+
|
| 3 |
+
The OutputGenerator never mutates state; it only inspects the latent state
|
| 4 |
+
plus the action and produces a structured artifact. State changes happen in
|
| 5 |
+
``TransitionEngine``.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from typing import Any, Dict, List, Optional
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
from models import (
|
| 15 |
+
ActionType,
|
| 16 |
+
DetectorChannel,
|
| 17 |
+
ExperimentAction,
|
| 18 |
+
IntermediateOutput,
|
| 19 |
+
OutputType,
|
| 20 |
+
TriggerType,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
from .latent_state import FullLatentState
|
| 24 |
+
from .noise import NoiseModel
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# ── Channel-specific background per fb^-1 (very rough physics-flavoured) ─
|
| 28 |
+
BACKGROUND_PER_FB: Dict[str, float] = {
|
| 29 |
+
"diphoton": 1500.0,
|
| 30 |
+
"dilepton_ee": 8000.0,
|
| 31 |
+
"dilepton_mumu": 9000.0,
|
| 32 |
+
"four_lepton": 80.0,
|
| 33 |
+
"dijet": 250000.0,
|
| 34 |
+
"bb": 50000.0,
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# ── Trigger ↔ channel affinity ───────────────────────────────────────────
|
| 39 |
+
TRIGGER_AFFINITY: Dict[str, Dict[str, float]] = {
|
| 40 |
+
"low_pt": {
|
| 41 |
+
"diphoton": 0.5,
|
| 42 |
+
"dilepton_ee": 0.6,
|
| 43 |
+
"dilepton_mumu": 0.6,
|
| 44 |
+
"four_lepton": 0.5,
|
| 45 |
+
"dijet": 0.9,
|
| 46 |
+
"bb": 0.7,
|
| 47 |
+
},
|
| 48 |
+
"high_pt": {
|
| 49 |
+
"diphoton": 0.9,
|
| 50 |
+
"dilepton_ee": 0.8,
|
| 51 |
+
"dilepton_mumu": 0.85,
|
| 52 |
+
"four_lepton": 0.85,
|
| 53 |
+
"dijet": 0.7,
|
| 54 |
+
"bb": 0.55,
|
| 55 |
+
},
|
| 56 |
+
"diphoton_hlt": {
|
| 57 |
+
"diphoton": 1.0,
|
| 58 |
+
"dilepton_ee": 0.05,
|
| 59 |
+
"dilepton_mumu": 0.05,
|
| 60 |
+
"four_lepton": 0.1,
|
| 61 |
+
"dijet": 0.05,
|
| 62 |
+
"bb": 0.05,
|
| 63 |
+
},
|
| 64 |
+
"dilepton_hlt": {
|
| 65 |
+
"diphoton": 0.05,
|
| 66 |
+
"dilepton_ee": 1.0,
|
| 67 |
+
"dilepton_mumu": 1.0,
|
| 68 |
+
"four_lepton": 0.85,
|
| 69 |
+
"dijet": 0.05,
|
| 70 |
+
"bb": 0.05,
|
| 71 |
+
},
|
| 72 |
+
"jet_hlt": {
|
| 73 |
+
"diphoton": 0.1,
|
| 74 |
+
"dilepton_ee": 0.1,
|
| 75 |
+
"dilepton_mumu": 0.1,
|
| 76 |
+
"four_lepton": 0.1,
|
| 77 |
+
"dijet": 1.0,
|
| 78 |
+
"bb": 0.85,
|
| 79 |
+
},
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# ── Beam-energy luminosity & cross-section scaling ───────────────────────
|
| 84 |
+
BEAM_SCALING: Dict[str, Dict[str, float]] = {
|
| 85 |
+
"7TeV": {"xsec_scale": 0.45, "cost_per_fb": 0.05, "days_per_fb": 0.6},
|
| 86 |
+
"8TeV": {"xsec_scale": 0.65, "cost_per_fb": 0.08, "days_per_fb": 0.7},
|
| 87 |
+
"13TeV": {"xsec_scale": 1.00, "cost_per_fb": 0.12, "days_per_fb": 0.8},
|
| 88 |
+
"14TeV": {"xsec_scale": 1.15, "cost_per_fb": 0.18, "days_per_fb": 0.9},
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _trigger_efficiency(trigger: Optional[str], channel: Optional[str]) -> float:
|
| 93 |
+
if not trigger or not channel:
|
| 94 |
+
return 0.0
|
| 95 |
+
table = TRIGGER_AFFINITY.get(trigger, {})
|
| 96 |
+
return float(table.get(channel, 0.1))
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class OutputGenerator:
|
| 100 |
+
"""Translates an action + latent state into a noisy observable artifact."""
|
| 101 |
+
|
| 102 |
+
def __init__(self, noise: NoiseModel):
|
| 103 |
+
self.noise = noise
|
| 104 |
+
|
| 105 |
+
# ── Public API ────────────────────────────────────────────────────
|
| 106 |
+
|
| 107 |
+
def generate(
|
| 108 |
+
self,
|
| 109 |
+
action: ExperimentAction,
|
| 110 |
+
state: FullLatentState,
|
| 111 |
+
step_index: int,
|
| 112 |
+
) -> IntermediateOutput:
|
| 113 |
+
a = action.action_type
|
| 114 |
+
|
| 115 |
+
if a == ActionType.CONFIGURE_BEAM:
|
| 116 |
+
return self._beam(action, state, step_index)
|
| 117 |
+
if a == ActionType.ALLOCATE_LUMINOSITY:
|
| 118 |
+
return self._luminosity(action, state, step_index)
|
| 119 |
+
if a == ActionType.SET_TRIGGER:
|
| 120 |
+
return self._trigger(action, state, step_index)
|
| 121 |
+
if a == ActionType.COLLECT_COLLISIONS:
|
| 122 |
+
return self._collect(action, state, step_index)
|
| 123 |
+
if a == ActionType.CALIBRATE_DETECTOR:
|
| 124 |
+
return self._calibrate(action, state, step_index)
|
| 125 |
+
if a == ActionType.RECONSTRUCT_TRACKS:
|
| 126 |
+
return self._reconstruct(action, state, step_index)
|
| 127 |
+
if a == ActionType.SELECT_CHANNEL:
|
| 128 |
+
return self._select_channel(action, state, step_index)
|
| 129 |
+
if a == ActionType.BUILD_INVARIANT_MASS:
|
| 130 |
+
return self._invariant_mass(action, state, step_index)
|
| 131 |
+
if a == ActionType.SUBTRACT_BACKGROUND:
|
| 132 |
+
return self._subtract_background(action, state, step_index)
|
| 133 |
+
if a == ActionType.FIT_RESONANCE:
|
| 134 |
+
return self._fit_resonance(action, state, step_index)
|
| 135 |
+
if a == ActionType.SCAN_BUMP:
|
| 136 |
+
return self._scan_bump(action, state, step_index)
|
| 137 |
+
if a == ActionType.MEASURE_ANGULAR:
|
| 138 |
+
return self._measure_angular(action, state, step_index)
|
| 139 |
+
if a == ActionType.ESTIMATE_SIGNIFICANCE:
|
| 140 |
+
return self._estimate_significance(action, state, step_index)
|
| 141 |
+
if a == ActionType.REQUEST_SYSTEMATICS:
|
| 142 |
+
return self._request_systematics(action, state, step_index)
|
| 143 |
+
if a == ActionType.REQUEST_THEORY_REVIEW:
|
| 144 |
+
return self._request_theory(action, state, step_index)
|
| 145 |
+
if a == ActionType.SUBMIT_DISCOVERY_CLAIM:
|
| 146 |
+
return self._submit_claim(action, state, step_index)
|
| 147 |
+
|
| 148 |
+
return self._failure(step_index, f"Unhandled action: {a}")
|
| 149 |
+
|
| 150 |
+
# ── helpers ────────────────────────────────────────────────────────
|
| 151 |
+
|
| 152 |
+
def _failure(self, step_index: int, msg: str) -> IntermediateOutput:
|
| 153 |
+
return IntermediateOutput(
|
| 154 |
+
output_type=OutputType.FAILURE_REPORT,
|
| 155 |
+
step_index=step_index,
|
| 156 |
+
success=False,
|
| 157 |
+
quality_score=0.0,
|
| 158 |
+
summary=msg,
|
| 159 |
+
warnings=[msg],
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
# ── DAQ (Data Acquisition) outputs ────────────────────────────────
|
| 163 |
+
|
| 164 |
+
def _beam(
|
| 165 |
+
self,
|
| 166 |
+
action: ExperimentAction,
|
| 167 |
+
state: FullLatentState,
|
| 168 |
+
step_index: int,
|
| 169 |
+
) -> IntermediateOutput:
|
| 170 |
+
beam = action.parameters.get("beam_energy") or state.selected_beam_energy or "13TeV"
|
| 171 |
+
scaling = BEAM_SCALING.get(beam, BEAM_SCALING["13TeV"])
|
| 172 |
+
return IntermediateOutput(
|
| 173 |
+
output_type=OutputType.BEAM_CONFIG,
|
| 174 |
+
step_index=step_index,
|
| 175 |
+
success=True,
|
| 176 |
+
quality_score=0.9,
|
| 177 |
+
summary=f"LHC configured at √s={beam}; effective xsec scale={scaling['xsec_scale']:.2f}.",
|
| 178 |
+
data={
|
| 179 |
+
"beam_energy": beam,
|
| 180 |
+
"xsec_scale": scaling["xsec_scale"],
|
| 181 |
+
"cost_per_fb_musd": scaling["cost_per_fb"],
|
| 182 |
+
"days_per_fb": scaling["days_per_fb"],
|
| 183 |
+
},
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
def _luminosity(
|
| 187 |
+
self,
|
| 188 |
+
action: ExperimentAction,
|
| 189 |
+
state: FullLatentState,
|
| 190 |
+
step_index: int,
|
| 191 |
+
) -> IntermediateOutput:
|
| 192 |
+
requested = float(action.parameters.get("luminosity_fb", 30.0))
|
| 193 |
+
granted = max(0.0, min(requested, state.resources.luminosity_remaining))
|
| 194 |
+
warnings: List[str] = []
|
| 195 |
+
if granted < requested:
|
| 196 |
+
warnings.append(
|
| 197 |
+
f"Luminosity capped: requested {requested:.1f} fb^-1, "
|
| 198 |
+
f"granted {granted:.1f} fb^-1."
|
| 199 |
+
)
|
| 200 |
+
return IntermediateOutput(
|
| 201 |
+
output_type=OutputType.LUMINOSITY_LOG,
|
| 202 |
+
step_index=step_index,
|
| 203 |
+
success=granted > 0,
|
| 204 |
+
quality_score=1.0 if granted > 0 else 0.0,
|
| 205 |
+
summary=f"Allocated {granted:.1f} fb^-1 of integrated luminosity.",
|
| 206 |
+
data={"luminosity_fb": granted, "requested_fb": requested},
|
| 207 |
+
warnings=warnings,
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
def _trigger(
|
| 211 |
+
self,
|
| 212 |
+
action: ExperimentAction,
|
| 213 |
+
state: FullLatentState,
|
| 214 |
+
step_index: int,
|
| 215 |
+
) -> IntermediateOutput:
|
| 216 |
+
trigger = action.parameters.get("trigger") or state.selected_trigger or "high_pt"
|
| 217 |
+
try:
|
| 218 |
+
TriggerType(trigger)
|
| 219 |
+
except ValueError:
|
| 220 |
+
return self._failure(step_index, f"Unknown trigger: {trigger}")
|
| 221 |
+
eff = state.detector.trigger_efficiency
|
| 222 |
+
return IntermediateOutput(
|
| 223 |
+
output_type=OutputType.TRIGGER_REPORT,
|
| 224 |
+
step_index=step_index,
|
| 225 |
+
success=True,
|
| 226 |
+
quality_score=eff,
|
| 227 |
+
summary=f"Trigger {trigger} armed; ε_trig={eff:.2f}.",
|
| 228 |
+
data={"trigger": trigger, "trigger_efficiency": eff},
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
def _collect(
|
| 232 |
+
self,
|
| 233 |
+
action: ExperimentAction,
|
| 234 |
+
state: FullLatentState,
|
| 235 |
+
step_index: int,
|
| 236 |
+
) -> IntermediateOutput:
|
| 237 |
+
beam = state.selected_beam_energy or "13TeV"
|
| 238 |
+
scaling = BEAM_SCALING.get(beam, BEAM_SCALING["13TeV"])
|
| 239 |
+
lumi_request = float(action.parameters.get("luminosity_fb", 0.0))
|
| 240 |
+
if lumi_request <= 0:
|
| 241 |
+
lumi_request = max(0.0, state.resources.luminosity_remaining * 0.2)
|
| 242 |
+
lumi = max(0.0, min(lumi_request, state.resources.luminosity_remaining))
|
| 243 |
+
if lumi <= 0:
|
| 244 |
+
return self._failure(step_index, "No luminosity remaining to collect.")
|
| 245 |
+
|
| 246 |
+
channel = state.selected_channel or state.particle.primary_channel
|
| 247 |
+
try:
|
| 248 |
+
DetectorChannel(channel)
|
| 249 |
+
except ValueError:
|
| 250 |
+
return self._failure(step_index, f"Invalid channel: {channel}")
|
| 251 |
+
|
| 252 |
+
trig = state.selected_trigger or "high_pt"
|
| 253 |
+
trig_eff = _trigger_efficiency(trig, channel)
|
| 254 |
+
reco_eff = state.detector.channel_efficiency.get(channel, 0.4)
|
| 255 |
+
if not state.detector.tracker_aligned and channel in {"dilepton_ee", "dilepton_mumu", "four_lepton"}:
|
| 256 |
+
reco_eff *= 0.7
|
| 257 |
+
if not state.detector.detector_calibrated and channel in {"diphoton"}:
|
| 258 |
+
reco_eff *= 0.8
|
| 259 |
+
|
| 260 |
+
br = state.particle.decay_branching.get(channel, 0.0)
|
| 261 |
+
eff_xsec = state.particle.cross_section_fb * scaling["xsec_scale"]
|
| 262 |
+
|
| 263 |
+
n_sig = self.noise.signal_yield(
|
| 264 |
+
cross_section_fb=eff_xsec,
|
| 265 |
+
luminosity_fb=lumi,
|
| 266 |
+
branching=br,
|
| 267 |
+
efficiency=reco_eff,
|
| 268 |
+
trigger_efficiency=trig_eff,
|
| 269 |
+
)
|
| 270 |
+
n_bg = self.noise.background_yield(
|
| 271 |
+
baseline_per_fb=BACKGROUND_PER_FB.get(channel, 1000.0),
|
| 272 |
+
luminosity_fb=lumi,
|
| 273 |
+
qcd_strength=state.detector.qcd_background_strength,
|
| 274 |
+
trigger_efficiency=trig_eff,
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
cost = lumi * scaling["cost_per_fb"]
|
| 278 |
+
days = lumi * scaling["days_per_fb"]
|
| 279 |
+
|
| 280 |
+
return IntermediateOutput(
|
| 281 |
+
output_type=OutputType.COLLISION_BATCH,
|
| 282 |
+
step_index=step_index,
|
| 283 |
+
success=True,
|
| 284 |
+
quality_score=float(np.clip(reco_eff * trig_eff + 0.1, 0.0, 1.0)),
|
| 285 |
+
summary=(
|
| 286 |
+
f"Collected {lumi:.1f} fb^-1 in {channel} with trigger {trig}: "
|
| 287 |
+
f"~{n_sig + n_bg} reconstructed events."
|
| 288 |
+
),
|
| 289 |
+
data={
|
| 290 |
+
"luminosity_fb": lumi,
|
| 291 |
+
"beam_energy": beam,
|
| 292 |
+
"channel": channel,
|
| 293 |
+
"trigger": trig,
|
| 294 |
+
"n_signal_candidates": int(n_sig),
|
| 295 |
+
"n_background_estimate": int(n_bg),
|
| 296 |
+
"cost_musd": cost,
|
| 297 |
+
"time_days": days,
|
| 298 |
+
"trigger_efficiency": trig_eff,
|
| 299 |
+
"reco_efficiency": reco_eff,
|
| 300 |
+
},
|
| 301 |
+
uncertainty=float(np.clip(0.05 + (1.0 - reco_eff) * 0.2, 0.0, 0.5)),
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
# ── Reconstruction outputs ────────────────────────────────────────
|
| 305 |
+
|
| 306 |
+
def _calibrate(
|
| 307 |
+
self,
|
| 308 |
+
action: ExperimentAction,
|
| 309 |
+
state: FullLatentState,
|
| 310 |
+
step_index: int,
|
| 311 |
+
) -> IntermediateOutput:
|
| 312 |
+
method = action.method or "ECAL_calibration"
|
| 313 |
+
improvement = self.noise.sample_qc_metric(0.5, 0.1, 0.0, 0.95)
|
| 314 |
+
return IntermediateOutput(
|
| 315 |
+
output_type=OutputType.CALIBRATION_REPORT,
|
| 316 |
+
step_index=step_index,
|
| 317 |
+
success=True,
|
| 318 |
+
quality_score=0.9,
|
| 319 |
+
summary=f"Detector calibrated using {method}; resolution improved by {improvement*100:.1f}%.",
|
| 320 |
+
data={
|
| 321 |
+
"method": method,
|
| 322 |
+
"resolution_improvement": improvement,
|
| 323 |
+
},
|
| 324 |
+
uncertainty=0.05,
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
def _reconstruct(
|
| 328 |
+
self,
|
| 329 |
+
action: ExperimentAction,
|
| 330 |
+
state: FullLatentState,
|
| 331 |
+
step_index: int,
|
| 332 |
+
) -> IntermediateOutput:
|
| 333 |
+
method = action.method or "Athena"
|
| 334 |
+
return IntermediateOutput(
|
| 335 |
+
output_type=OutputType.RECONSTRUCTION,
|
| 336 |
+
step_index=step_index,
|
| 337 |
+
success=True,
|
| 338 |
+
quality_score=0.85,
|
| 339 |
+
summary=f"Tracks and physics objects reconstructed via {method}.",
|
| 340 |
+
data={"method": method},
|
| 341 |
+
uncertainty=0.05,
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
def _select_channel(
|
| 345 |
+
self,
|
| 346 |
+
action: ExperimentAction,
|
| 347 |
+
state: FullLatentState,
|
| 348 |
+
step_index: int,
|
| 349 |
+
) -> IntermediateOutput:
|
| 350 |
+
channel = action.parameters.get("channel") or state.selected_channel
|
| 351 |
+
if not channel:
|
| 352 |
+
return self._failure(step_index, "No channel specified.")
|
| 353 |
+
try:
|
| 354 |
+
DetectorChannel(channel)
|
| 355 |
+
except ValueError:
|
| 356 |
+
return self._failure(step_index, f"Unknown channel: {channel}")
|
| 357 |
+
return IntermediateOutput(
|
| 358 |
+
output_type=OutputType.CHANNEL_SELECTION,
|
| 359 |
+
step_index=step_index,
|
| 360 |
+
success=True,
|
| 361 |
+
quality_score=0.95,
|
| 362 |
+
summary=f"Analysis channel set to {channel}.",
|
| 363 |
+
data={"channel": channel},
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
# ── Analysis outputs ──────────────────────────────────────────────
|
| 367 |
+
|
| 368 |
+
def _invariant_mass(
|
| 369 |
+
self,
|
| 370 |
+
action: ExperimentAction,
|
| 371 |
+
state: FullLatentState,
|
| 372 |
+
step_index: int,
|
| 373 |
+
) -> IntermediateOutput:
|
| 374 |
+
if state.progress.n_events_collected <= 0:
|
| 375 |
+
return self._failure(step_index, "No collisions collected yet.")
|
| 376 |
+
window = action.parameters.get("mass_window_gev") or [50.0, 1000.0]
|
| 377 |
+
n_bins = int(action.parameters.get("n_bins", 40))
|
| 378 |
+
true_m = state.particle.mass_gev
|
| 379 |
+
in_window = window[0] <= true_m <= window[1]
|
| 380 |
+
n_sig = state.progress.n_signal_candidates if in_window else 0
|
| 381 |
+
hist = self.noise.histogram(
|
| 382 |
+
n_signal=n_sig,
|
| 383 |
+
n_background=state.progress.n_background_estimate,
|
| 384 |
+
true_mass_gev=true_m,
|
| 385 |
+
resolution_gev=state.detector.detector_resolution_gev,
|
| 386 |
+
window_lo_gev=window[0],
|
| 387 |
+
window_hi_gev=window[1],
|
| 388 |
+
n_bins=n_bins,
|
| 389 |
+
background_alpha=state.detector.background_shape_alpha,
|
| 390 |
+
)
|
| 391 |
+
return IntermediateOutput(
|
| 392 |
+
output_type=OutputType.INVARIANT_MASS_HIST,
|
| 393 |
+
step_index=step_index,
|
| 394 |
+
success=True,
|
| 395 |
+
quality_score=0.85 if in_window else 0.4,
|
| 396 |
+
summary=(
|
| 397 |
+
f"Invariant-mass histogram in [{window[0]:.0f}, {window[1]:.0f}] GeV "
|
| 398 |
+
f"with {n_bins} bins, total {sum(hist)} entries."
|
| 399 |
+
),
|
| 400 |
+
data={
|
| 401 |
+
"window_gev": window,
|
| 402 |
+
"bin_counts": hist,
|
| 403 |
+
"n_signal_in_window": n_sig,
|
| 404 |
+
"n_background_in_window": state.progress.n_background_estimate,
|
| 405 |
+
},
|
| 406 |
+
uncertainty=0.1,
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
def _subtract_background(
|
| 410 |
+
self,
|
| 411 |
+
action: ExperimentAction,
|
| 412 |
+
state: FullLatentState,
|
| 413 |
+
step_index: int,
|
| 414 |
+
) -> IntermediateOutput:
|
| 415 |
+
if not state.progress.invariant_mass_built:
|
| 416 |
+
return self._failure(step_index, "Build the invariant-mass histogram first.")
|
| 417 |
+
residual = self.noise.sample_qc_metric(0.05, 0.02, 0.0, 0.5)
|
| 418 |
+
return IntermediateOutput(
|
| 419 |
+
output_type=OutputType.BACKGROUND_SUBTRACTION,
|
| 420 |
+
step_index=step_index,
|
| 421 |
+
success=True,
|
| 422 |
+
quality_score=0.85,
|
| 423 |
+
summary=f"Smooth background subtracted; residual fraction ≈ {residual*100:.1f}%.",
|
| 424 |
+
data={"residual_fraction": residual},
|
| 425 |
+
uncertainty=0.08,
|
| 426 |
+
)
|
| 427 |
+
|
| 428 |
+
def _fit_resonance(
|
| 429 |
+
self,
|
| 430 |
+
action: ExperimentAction,
|
| 431 |
+
state: FullLatentState,
|
| 432 |
+
step_index: int,
|
| 433 |
+
) -> IntermediateOutput:
|
| 434 |
+
if not state.progress.background_subtracted and not state.progress.invariant_mass_built:
|
| 435 |
+
return self._failure(step_index, "Need a histogram (and ideally background subtraction) before fitting.")
|
| 436 |
+
n_sig = max(state.progress.n_signal_candidates, 1)
|
| 437 |
+
true_m = state.particle.mass_gev
|
| 438 |
+
scale = state.detector.energy_scale_offset
|
| 439 |
+
res = state.detector.detector_resolution_gev
|
| 440 |
+
m_fit = self.noise.fit_mass_estimate(true_m, n_sig, res, scale)
|
| 441 |
+
m_unc = self.noise.fit_mass_uncertainty(n_sig, res)
|
| 442 |
+
w_fit = max(0.001, abs(self.noise.jitter(state.particle.width_gev, 0.1 * res)))
|
| 443 |
+
return IntermediateOutput(
|
| 444 |
+
output_type=OutputType.FIT_RESULT,
|
| 445 |
+
step_index=step_index,
|
| 446 |
+
success=True,
|
| 447 |
+
quality_score=0.9,
|
| 448 |
+
summary=f"Resonance fit: m={m_fit:.2f} ± {m_unc:.2f} GeV, Γ≈{w_fit:.3f} GeV.",
|
| 449 |
+
data={
|
| 450 |
+
"fit_mass_gev": m_fit,
|
| 451 |
+
"fit_mass_unc_gev": m_unc,
|
| 452 |
+
"fit_width_gev": w_fit,
|
| 453 |
+
"n_signal_used": int(n_sig),
|
| 454 |
+
},
|
| 455 |
+
uncertainty=float(np.clip(m_unc / max(true_m, 1.0), 0.0, 1.0)),
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
def _scan_bump(
|
| 459 |
+
self,
|
| 460 |
+
action: ExperimentAction,
|
| 461 |
+
state: FullLatentState,
|
| 462 |
+
step_index: int,
|
| 463 |
+
) -> IntermediateOutput:
|
| 464 |
+
if state.progress.n_events_collected <= 0:
|
| 465 |
+
return self._failure(step_index, "Collect data before bump-hunting.")
|
| 466 |
+
true_m = state.particle.mass_gev
|
| 467 |
+
m_obs = self.noise.smear_mass(true_m, state.detector.detector_resolution_gev * 1.2)
|
| 468 |
+
return IntermediateOutput(
|
| 469 |
+
output_type=OutputType.BUMP_SCAN,
|
| 470 |
+
step_index=step_index,
|
| 471 |
+
success=True,
|
| 472 |
+
quality_score=0.7,
|
| 473 |
+
summary=f"Bump scan most-significant region near m≈{m_obs:.1f} GeV.",
|
| 474 |
+
data={"candidate_mass_gev": m_obs},
|
| 475 |
+
uncertainty=0.15,
|
| 476 |
+
)
|
| 477 |
+
|
| 478 |
+
def _measure_angular(
|
| 479 |
+
self,
|
| 480 |
+
action: ExperimentAction,
|
| 481 |
+
state: FullLatentState,
|
| 482 |
+
step_index: int,
|
| 483 |
+
) -> IntermediateOutput:
|
| 484 |
+
spin_truth = state.particle.spin
|
| 485 |
+
# Returns posterior over {0,1,2} biased by truth + noise
|
| 486 |
+
weights = np.array([0.1, 0.1, 0.1])
|
| 487 |
+
weights[spin_truth] += 0.6
|
| 488 |
+
weights += self.noise.rng.normal(0, 0.05, size=3)
|
| 489 |
+
weights = np.clip(weights, 0.01, None)
|
| 490 |
+
weights /= weights.sum()
|
| 491 |
+
return IntermediateOutput(
|
| 492 |
+
output_type=OutputType.ANGULAR_RESULT,
|
| 493 |
+
step_index=step_index,
|
| 494 |
+
success=True,
|
| 495 |
+
quality_score=0.8,
|
| 496 |
+
summary=(
|
| 497 |
+
"Angular distribution favours spin-"
|
| 498 |
+
f"{int(np.argmax(weights))} ({weights.max():.2f} posterior)."
|
| 499 |
+
),
|
| 500 |
+
data={
|
| 501 |
+
"spin_posterior": weights.tolist(),
|
| 502 |
+
"favoured_spin": int(np.argmax(weights)),
|
| 503 |
+
"parity_estimate": state.particle.parity,
|
| 504 |
+
},
|
| 505 |
+
uncertainty=float(1.0 - weights.max()),
|
| 506 |
+
)
|
| 507 |
+
|
| 508 |
+
def _estimate_significance(
|
| 509 |
+
self,
|
| 510 |
+
action: ExperimentAction,
|
| 511 |
+
state: FullLatentState,
|
| 512 |
+
step_index: int,
|
| 513 |
+
) -> IntermediateOutput:
|
| 514 |
+
n_sig = state.progress.n_signal_candidates
|
| 515 |
+
n_bg = state.progress.n_background_estimate
|
| 516 |
+
nuisance = 0.0
|
| 517 |
+
if not state.progress.systematics_requested:
|
| 518 |
+
nuisance += 0.15
|
| 519 |
+
if not state.progress.detector_calibrated:
|
| 520 |
+
nuisance += 0.10
|
| 521 |
+
z = self.noise.asimov_significance(n_sig, n_bg, nuisance_inflation=nuisance)
|
| 522 |
+
return IntermediateOutput(
|
| 523 |
+
output_type=OutputType.SIGNIFICANCE,
|
| 524 |
+
step_index=step_index,
|
| 525 |
+
success=True,
|
| 526 |
+
quality_score=0.9,
|
| 527 |
+
summary=f"Estimated local significance Z = {z:.2f} σ.",
|
| 528 |
+
data={
|
| 529 |
+
"significance_sigma": z,
|
| 530 |
+
"n_signal": int(n_sig),
|
| 531 |
+
"n_background": int(n_bg),
|
| 532 |
+
"nuisance_inflation": nuisance,
|
| 533 |
+
},
|
| 534 |
+
uncertainty=float(np.clip(0.05 + nuisance, 0.0, 0.5)),
|
| 535 |
+
)
|
| 536 |
+
|
| 537 |
+
# ── Meta outputs ──────────────────────────────────────────────────
|
| 538 |
+
|
| 539 |
+
def _request_systematics(
|
| 540 |
+
self,
|
| 541 |
+
action: ExperimentAction,
|
| 542 |
+
state: FullLatentState,
|
| 543 |
+
step_index: int,
|
| 544 |
+
) -> IntermediateOutput:
|
| 545 |
+
method = action.method or "Luminosity_calibration"
|
| 546 |
+
return IntermediateOutput(
|
| 547 |
+
output_type=OutputType.SYSTEMATICS_REPORT,
|
| 548 |
+
step_index=step_index,
|
| 549 |
+
success=True,
|
| 550 |
+
quality_score=0.85,
|
| 551 |
+
summary=f"Systematics study via {method}; nuisance band tightened.",
|
| 552 |
+
data={"method": method},
|
| 553 |
+
uncertainty=0.04,
|
| 554 |
+
)
|
| 555 |
+
|
| 556 |
+
def _request_theory(
|
| 557 |
+
self,
|
| 558 |
+
action: ExperimentAction,
|
| 559 |
+
state: FullLatentState,
|
| 560 |
+
step_index: int,
|
| 561 |
+
) -> IntermediateOutput:
|
| 562 |
+
return IntermediateOutput(
|
| 563 |
+
output_type=OutputType.THEORY_REVIEW,
|
| 564 |
+
step_index=step_index,
|
| 565 |
+
success=True,
|
| 566 |
+
quality_score=0.7,
|
| 567 |
+
summary="Theory review: candidate consistent with Standard-Model-extension scalar / vector hypotheses.",
|
| 568 |
+
data={"hypotheses": ["BSM scalar", "BSM vector", "SM background fluctuation"]},
|
| 569 |
+
uncertainty=0.2,
|
| 570 |
+
)
|
| 571 |
+
|
| 572 |
+
def _submit_claim(
|
| 573 |
+
self,
|
| 574 |
+
action: ExperimentAction,
|
| 575 |
+
state: FullLatentState,
|
| 576 |
+
step_index: int,
|
| 577 |
+
) -> IntermediateOutput:
|
| 578 |
+
claim: Dict[str, Any] = action.parameters.get("claim") or {}
|
| 579 |
+
return IntermediateOutput(
|
| 580 |
+
output_type=OutputType.DISCOVERY_CLAIM,
|
| 581 |
+
step_index=step_index,
|
| 582 |
+
success=True,
|
| 583 |
+
quality_score=1.0,
|
| 584 |
+
summary="Discovery claim submitted for grading.",
|
| 585 |
+
data=claim,
|
| 586 |
+
)
|
server/simulator/transition.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pure-function transition engine.
|
| 2 |
+
|
| 3 |
+
Given a (latent_state, action, generated_output) triple, produces the next
|
| 4 |
+
latent state plus the deltas needed for the agent-visible observation. The
|
| 5 |
+
``TransitionEngine`` does **not** generate randomness directly; it consumes
|
| 6 |
+
artifacts from the ``OutputGenerator``.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
from typing import Dict
|
| 13 |
+
|
| 14 |
+
from models import (
|
| 15 |
+
ActionType,
|
| 16 |
+
ExperimentAction,
|
| 17 |
+
IntermediateOutput,
|
| 18 |
+
OutputType,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
from .latent_state import FullLatentState
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# Per-action default cost in (millions of USD, days, compute hours)
|
| 25 |
+
ACTION_COSTS: Dict[ActionType, Dict[str, float]] = {
|
| 26 |
+
ActionType.CONFIGURE_BEAM: {"musd": 0.10, "days": 0.5, "compute": 0.1},
|
| 27 |
+
ActionType.ALLOCATE_LUMINOSITY: {"musd": 0.05, "days": 0.2, "compute": 0.0},
|
| 28 |
+
ActionType.SET_TRIGGER: {"musd": 0.05, "days": 0.1, "compute": 0.0},
|
| 29 |
+
ActionType.COLLECT_COLLISIONS: {"musd": 0.00, "days": 0.0, "compute": 1.0}, # main cost is in luminosity
|
| 30 |
+
ActionType.CALIBRATE_DETECTOR: {"musd": 0.20, "days": 1.0, "compute": 1.5},
|
| 31 |
+
ActionType.RECONSTRUCT_TRACKS: {"musd": 0.15, "days": 0.8, "compute": 5.0},
|
| 32 |
+
ActionType.SELECT_CHANNEL: {"musd": 0.00, "days": 0.05, "compute": 0.0},
|
| 33 |
+
ActionType.BUILD_INVARIANT_MASS: {"musd": 0.05, "days": 0.3, "compute": 1.0},
|
| 34 |
+
ActionType.SUBTRACT_BACKGROUND: {"musd": 0.05, "days": 0.3, "compute": 0.5},
|
| 35 |
+
ActionType.FIT_RESONANCE: {"musd": 0.10, "days": 0.4, "compute": 0.5},
|
| 36 |
+
ActionType.SCAN_BUMP: {"musd": 0.05, "days": 0.2, "compute": 0.5},
|
| 37 |
+
ActionType.MEASURE_ANGULAR: {"musd": 0.10, "days": 0.4, "compute": 0.5},
|
| 38 |
+
ActionType.ESTIMATE_SIGNIFICANCE: {"musd": 0.05, "days": 0.1, "compute": 0.2},
|
| 39 |
+
ActionType.REQUEST_SYSTEMATICS: {"musd": 0.30, "days": 1.5, "compute": 1.0},
|
| 40 |
+
ActionType.REQUEST_THEORY_REVIEW: {"musd": 0.05, "days": 0.5, "compute": 0.0},
|
| 41 |
+
ActionType.SUBMIT_DISCOVERY_CLAIM:{"musd": 0.0, "days": 0.1, "compute": 0.0},
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def compute_action_cost(action: ExperimentAction, output: IntermediateOutput) -> Dict[str, float]:
|
| 46 |
+
"""Return realised (musd, days, compute_hours, luminosity_fb) for this action."""
|
| 47 |
+
base = ACTION_COSTS.get(action.action_type, {"musd": 0.0, "days": 0.0, "compute": 0.0})
|
| 48 |
+
musd = float(base.get("musd", 0.0))
|
| 49 |
+
days = float(base.get("days", 0.0))
|
| 50 |
+
compute = float(base.get("compute", 0.0))
|
| 51 |
+
lumi_fb = 0.0
|
| 52 |
+
|
| 53 |
+
data = output.data or {}
|
| 54 |
+
if action.action_type == ActionType.COLLECT_COLLISIONS:
|
| 55 |
+
lumi_fb = float(data.get("luminosity_fb", 0.0))
|
| 56 |
+
musd += float(data.get("cost_musd", 0.0))
|
| 57 |
+
days += float(data.get("time_days", 0.0))
|
| 58 |
+
|
| 59 |
+
return {
|
| 60 |
+
"musd": musd,
|
| 61 |
+
"days": days,
|
| 62 |
+
"compute_hours": compute,
|
| 63 |
+
"luminosity_fb": lumi_fb,
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@dataclass
|
| 68 |
+
class TransitionResult:
|
| 69 |
+
next_state: FullLatentState
|
| 70 |
+
realised_cost: Dict[str, float]
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class TransitionEngine:
|
| 74 |
+
"""Applies an action's output to evolve the latent state."""
|
| 75 |
+
|
| 76 |
+
def step(
|
| 77 |
+
self,
|
| 78 |
+
state: FullLatentState,
|
| 79 |
+
action: ExperimentAction,
|
| 80 |
+
output: IntermediateOutput,
|
| 81 |
+
) -> TransitionResult:
|
| 82 |
+
# We mutate the live state in place, then return it. This is fine
|
| 83 |
+
# because the environment owns the only reference.
|
| 84 |
+
cost = compute_action_cost(action, output)
|
| 85 |
+
state.resources.budget_used_musd += cost["musd"]
|
| 86 |
+
state.resources.time_used_days += cost["days"]
|
| 87 |
+
state.resources.compute_hours_used += cost["compute_hours"]
|
| 88 |
+
state.resources.luminosity_used_fb += cost["luminosity_fb"]
|
| 89 |
+
|
| 90 |
+
if not output.success:
|
| 91 |
+
state.step_count += 1
|
| 92 |
+
return TransitionResult(next_state=state, realised_cost=cost)
|
| 93 |
+
|
| 94 |
+
a = action.action_type
|
| 95 |
+
data = output.data or {}
|
| 96 |
+
|
| 97 |
+
if a == ActionType.CONFIGURE_BEAM:
|
| 98 |
+
beam = data.get("beam_energy")
|
| 99 |
+
state.selected_beam_energy = beam
|
| 100 |
+
state.progress.beam_configured = True
|
| 101 |
+
|
| 102 |
+
elif a == ActionType.ALLOCATE_LUMINOSITY:
|
| 103 |
+
state.progress.luminosity_allocated = True
|
| 104 |
+
|
| 105 |
+
elif a == ActionType.SET_TRIGGER:
|
| 106 |
+
trig = data.get("trigger")
|
| 107 |
+
state.selected_trigger = trig
|
| 108 |
+
state.progress.trigger_set = True
|
| 109 |
+
|
| 110 |
+
elif a == ActionType.COLLECT_COLLISIONS:
|
| 111 |
+
state.progress.collisions_collected = True
|
| 112 |
+
state.progress.n_events_collected += int(
|
| 113 |
+
data.get("n_signal_candidates", 0)
|
| 114 |
+
) + int(data.get("n_background_estimate", 0))
|
| 115 |
+
state.progress.n_signal_candidates += int(data.get("n_signal_candidates", 0))
|
| 116 |
+
state.progress.n_background_estimate += int(data.get("n_background_estimate", 0))
|
| 117 |
+
state.progress.best_channel = data.get("channel") or state.progress.best_channel
|
| 118 |
+
state.progress.best_beam_energy = (
|
| 119 |
+
data.get("beam_energy") or state.progress.best_beam_energy
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
elif a == ActionType.CALIBRATE_DETECTOR:
|
| 123 |
+
state.progress.detector_calibrated = True
|
| 124 |
+
state.detector.detector_calibrated = True
|
| 125 |
+
improvement = float(data.get("resolution_improvement", 0.0))
|
| 126 |
+
state.detector.detector_resolution_gev = max(
|
| 127 |
+
0.05,
|
| 128 |
+
state.detector.detector_resolution_gev * (1.0 - improvement),
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
elif a == ActionType.RECONSTRUCT_TRACKS:
|
| 132 |
+
state.progress.tracks_reconstructed = True
|
| 133 |
+
state.detector.tracker_aligned = True
|
| 134 |
+
|
| 135 |
+
elif a == ActionType.SELECT_CHANNEL:
|
| 136 |
+
channel = data.get("channel")
|
| 137 |
+
if channel:
|
| 138 |
+
state.selected_channel = channel
|
| 139 |
+
state.progress.channel_selected = True
|
| 140 |
+
|
| 141 |
+
elif a == ActionType.BUILD_INVARIANT_MASS:
|
| 142 |
+
state.progress.invariant_mass_built = True
|
| 143 |
+
|
| 144 |
+
elif a == ActionType.SUBTRACT_BACKGROUND:
|
| 145 |
+
state.progress.background_subtracted = True
|
| 146 |
+
|
| 147 |
+
elif a == ActionType.FIT_RESONANCE:
|
| 148 |
+
state.progress.resonance_fitted = True
|
| 149 |
+
m = float(data.get("fit_mass_gev", 0.0))
|
| 150 |
+
unc = float(data.get("fit_mass_unc_gev", 0.0))
|
| 151 |
+
w = float(data.get("fit_width_gev", 0.0))
|
| 152 |
+
if m > 0:
|
| 153 |
+
state.candidate_masses_gev.append(m)
|
| 154 |
+
state.candidate_significances.append(0.0)
|
| 155 |
+
state.progress.best_fit_mass_gev = m
|
| 156 |
+
state.progress.best_fit_width_gev = w
|
| 157 |
+
|
| 158 |
+
elif a == ActionType.SCAN_BUMP:
|
| 159 |
+
state.progress.bump_scanned = True
|
| 160 |
+
cm = float(data.get("candidate_mass_gev", 0.0))
|
| 161 |
+
if cm > 0:
|
| 162 |
+
state.candidate_masses_gev.append(cm)
|
| 163 |
+
state.candidate_significances.append(0.0)
|
| 164 |
+
|
| 165 |
+
elif a == ActionType.MEASURE_ANGULAR:
|
| 166 |
+
state.progress.angular_measured = True
|
| 167 |
+
|
| 168 |
+
elif a == ActionType.ESTIMATE_SIGNIFICANCE:
|
| 169 |
+
state.progress.significance_estimated = True
|
| 170 |
+
sig = float(data.get("significance_sigma", 0.0))
|
| 171 |
+
state.progress.best_significance_sigma = max(
|
| 172 |
+
state.progress.best_significance_sigma or 0.0, sig
|
| 173 |
+
)
|
| 174 |
+
if state.candidate_significances:
|
| 175 |
+
state.candidate_significances[-1] = sig
|
| 176 |
+
|
| 177 |
+
elif a == ActionType.REQUEST_SYSTEMATICS:
|
| 178 |
+
state.progress.systematics_requested = True
|
| 179 |
+
state.detector.energy_scale_uncertainty *= 0.6
|
| 180 |
+
state.detector.luminosity_uncertainty *= 0.7
|
| 181 |
+
|
| 182 |
+
elif a == ActionType.REQUEST_THEORY_REVIEW:
|
| 183 |
+
state.progress.theory_review_requested = True
|
| 184 |
+
|
| 185 |
+
elif a == ActionType.SUBMIT_DISCOVERY_CLAIM:
|
| 186 |
+
state.progress.claim_submitted = True
|
| 187 |
+
|
| 188 |
+
state.step_count += 1
|
| 189 |
+
return TransitionResult(next_state=state, realised_cost=cost)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
__all__ = [
|
| 193 |
+
"ACTION_COSTS",
|
| 194 |
+
"TransitionEngine",
|
| 195 |
+
"TransitionResult",
|
| 196 |
+
"compute_action_cost",
|
| 197 |
+
]
|
server/tasks/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Task generator: curated scenarios + procedural curriculum."""
|
| 2 |
+
|
| 3 |
+
from .scenarios import (
|
| 4 |
+
CURATED_SCENARIOS,
|
| 5 |
+
Scenario,
|
| 6 |
+
sample_scenario,
|
| 7 |
+
)
|
| 8 |
+
|
| 9 |
+
__all__ = ["CURATED_SCENARIOS", "Scenario", "sample_scenario"]
|
server/tasks/scenarios.py
ADDED
|
@@ -0,0 +1,422 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Built-in physics scenarios + procedural sampling.
|
| 2 |
+
|
| 3 |
+
Each scenario binds a hidden ``LatentParticle`` truth and a public
|
| 4 |
+
``TaskSpec`` (search window, available channels, resource budgets, expected
|
| 5 |
+
findings, paper references). Curated scenarios are inspired by famous LHC
|
| 6 |
+
discoveries; procedural ones randomise mass, channel, width and budgets to
|
| 7 |
+
build a curriculum.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
from dataclasses import dataclass
|
| 13 |
+
from typing import List, Optional
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
|
| 17 |
+
from models import (
|
| 18 |
+
DetectorChannel,
|
| 19 |
+
ExpectedFinding,
|
| 20 |
+
PaperReference,
|
| 21 |
+
TOOL_REGISTRY,
|
| 22 |
+
TaskSpec,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
from server.simulator.latent_state import (
|
| 26 |
+
DetectorState,
|
| 27 |
+
FullLatentState,
|
| 28 |
+
LatentParticle,
|
| 29 |
+
ResourceState,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@dataclass
|
| 34 |
+
class Scenario:
|
| 35 |
+
name: str
|
| 36 |
+
difficulty: str
|
| 37 |
+
task: TaskSpec
|
| 38 |
+
latent: FullLatentState
|
| 39 |
+
|
| 40 |
+
def fresh_latent(self) -> FullLatentState:
|
| 41 |
+
# Pydantic deep-copy so the env can mutate freely
|
| 42 |
+
return self.latent.model_copy(deep=True)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# ── Curated, story-driven scenarios ──────────────────────────────────────
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _higgs_like_scenario() -> Scenario:
|
| 49 |
+
particle = LatentParticle(
|
| 50 |
+
name="HiggsLike",
|
| 51 |
+
mass_gev=125.0,
|
| 52 |
+
width_gev=0.004,
|
| 53 |
+
spin=0,
|
| 54 |
+
parity="+",
|
| 55 |
+
cross_section_fb=55.0,
|
| 56 |
+
decay_branching={
|
| 57 |
+
"diphoton": 0.0023,
|
| 58 |
+
"dilepton_ee": 0.00003,
|
| 59 |
+
"dilepton_mumu": 0.00022,
|
| 60 |
+
"four_lepton": 0.000125,
|
| 61 |
+
"bb": 0.58,
|
| 62 |
+
"dijet": 0.30,
|
| 63 |
+
},
|
| 64 |
+
primary_channel="diphoton",
|
| 65 |
+
)
|
| 66 |
+
detector = DetectorState(
|
| 67 |
+
detector_resolution_gev=1.5,
|
| 68 |
+
pileup_mu=30.0,
|
| 69 |
+
trigger_efficiency=0.85,
|
| 70 |
+
)
|
| 71 |
+
resources = ResourceState(
|
| 72 |
+
budget_total_musd=120.0,
|
| 73 |
+
luminosity_total_fb=300.0,
|
| 74 |
+
time_limit_days=365.0,
|
| 75 |
+
)
|
| 76 |
+
latent = FullLatentState(
|
| 77 |
+
particle=particle,
|
| 78 |
+
detector=detector,
|
| 79 |
+
resources=resources,
|
| 80 |
+
rng_seed=125,
|
| 81 |
+
)
|
| 82 |
+
task = TaskSpec(
|
| 83 |
+
problem_statement=(
|
| 84 |
+
"An anomalous excess at ~125 GeV is rumoured in early 13 TeV runs. "
|
| 85 |
+
"Plan a campaign to confirm or refute a Standard-Model Higgs-like scalar. "
|
| 86 |
+
"Pick channels, allocate luminosity, fit, and submit a calibrated discovery claim."
|
| 87 |
+
),
|
| 88 |
+
target_collider="LHC",
|
| 89 |
+
mass_search_window_gev=[100.0, 200.0],
|
| 90 |
+
budget_limit_musd=120.0,
|
| 91 |
+
luminosity_budget_fb=300.0,
|
| 92 |
+
time_limit_days=365.0,
|
| 93 |
+
prior_observations=[
|
| 94 |
+
"Earlier Tevatron data shows a mild diphoton excess near 125 GeV.",
|
| 95 |
+
"ATLAS/CMS rumour mills suggest a 4ℓ excess at low mass.",
|
| 96 |
+
],
|
| 97 |
+
success_criteria=[
|
| 98 |
+
"Identify a resonance within 1 GeV of the truth.",
|
| 99 |
+
"Reach ≥5σ local significance.",
|
| 100 |
+
"Submit confidence consistent with calibration.",
|
| 101 |
+
],
|
| 102 |
+
paper_references=[
|
| 103 |
+
PaperReference(
|
| 104 |
+
title="Observation of a new particle in the search for the SM Higgs boson",
|
| 105 |
+
arxiv_id="1207.7214",
|
| 106 |
+
doi="10.1016/j.physletb.2012.08.020",
|
| 107 |
+
),
|
| 108 |
+
],
|
| 109 |
+
expected_findings=[
|
| 110 |
+
ExpectedFinding(finding="Diphoton resonance at ~125 GeV", category="discovery"),
|
| 111 |
+
ExpectedFinding(finding="Spin-0, even parity", category="property"),
|
| 112 |
+
],
|
| 113 |
+
difficulty="medium",
|
| 114 |
+
available_tools=list(TOOL_REGISTRY.keys()),
|
| 115 |
+
)
|
| 116 |
+
return Scenario(name="higgs_like_125", difficulty="medium", task=task, latent=latent)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def _hidden_zprime_scenario() -> Scenario:
|
| 120 |
+
particle = LatentParticle(
|
| 121 |
+
name="ZPrime",
|
| 122 |
+
mass_gev=600.0,
|
| 123 |
+
width_gev=18.0,
|
| 124 |
+
spin=1,
|
| 125 |
+
parity="-",
|
| 126 |
+
cross_section_fb=12.0,
|
| 127 |
+
decay_branching={
|
| 128 |
+
"diphoton": 0.0,
|
| 129 |
+
"dilepton_ee": 0.04,
|
| 130 |
+
"dilepton_mumu": 0.04,
|
| 131 |
+
"four_lepton": 0.0,
|
| 132 |
+
"bb": 0.20,
|
| 133 |
+
"dijet": 0.70,
|
| 134 |
+
},
|
| 135 |
+
primary_channel="dilepton_mumu",
|
| 136 |
+
)
|
| 137 |
+
detector = DetectorState(
|
| 138 |
+
detector_resolution_gev=8.0,
|
| 139 |
+
pileup_mu=45.0,
|
| 140 |
+
trigger_efficiency=0.78,
|
| 141 |
+
qcd_background_strength=1.2,
|
| 142 |
+
)
|
| 143 |
+
resources = ResourceState(
|
| 144 |
+
budget_total_musd=140.0,
|
| 145 |
+
luminosity_total_fb=200.0,
|
| 146 |
+
time_limit_days=400.0,
|
| 147 |
+
)
|
| 148 |
+
latent = FullLatentState(
|
| 149 |
+
particle=particle, detector=detector, resources=resources, rng_seed=600,
|
| 150 |
+
)
|
| 151 |
+
task = TaskSpec(
|
| 152 |
+
problem_statement=(
|
| 153 |
+
"Run-2 dilepton spectra hint at a high-mass excess. Hunt for a heavy "
|
| 154 |
+
"Z'-like vector resonance and characterise spin-1, parity-odd hypothesis."
|
| 155 |
+
),
|
| 156 |
+
mass_search_window_gev=[300.0, 1500.0],
|
| 157 |
+
budget_limit_musd=140.0,
|
| 158 |
+
luminosity_budget_fb=200.0,
|
| 159 |
+
time_limit_days=400.0,
|
| 160 |
+
prior_observations=[
|
| 161 |
+
"High-pT dilepton tail shows a 2.7σ shoulder near 600 GeV.",
|
| 162 |
+
"Dijet smooth-fit residuals consistent with the same window.",
|
| 163 |
+
],
|
| 164 |
+
success_criteria=[
|
| 165 |
+
"Identify a high-mass dilepton/dijet resonance.",
|
| 166 |
+
"Constrain spin to be vector (1).",
|
| 167 |
+
"Report calibrated mass within 5% and ≥4σ significance.",
|
| 168 |
+
],
|
| 169 |
+
paper_references=[
|
| 170 |
+
PaperReference(
|
| 171 |
+
title="Search for high-mass dilepton resonances at the LHC",
|
| 172 |
+
arxiv_id="1903.06248",
|
| 173 |
+
),
|
| 174 |
+
],
|
| 175 |
+
expected_findings=[
|
| 176 |
+
ExpectedFinding(finding="Heavy Z'-like dilepton resonance", category="discovery"),
|
| 177 |
+
ExpectedFinding(finding="Spin-1, parity-odd", category="property"),
|
| 178 |
+
],
|
| 179 |
+
difficulty="hard",
|
| 180 |
+
available_tools=list(TOOL_REGISTRY.keys()),
|
| 181 |
+
)
|
| 182 |
+
return Scenario(name="hidden_zprime_600", difficulty="hard", task=task, latent=latent)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def _diboson_resonance_scenario() -> Scenario:
|
| 186 |
+
particle = LatentParticle(
|
| 187 |
+
name="Graviton",
|
| 188 |
+
mass_gev=750.0,
|
| 189 |
+
width_gev=45.0,
|
| 190 |
+
spin=2,
|
| 191 |
+
parity="+",
|
| 192 |
+
cross_section_fb=6.0,
|
| 193 |
+
decay_branching={
|
| 194 |
+
"diphoton": 0.06,
|
| 195 |
+
"dilepton_ee": 0.005,
|
| 196 |
+
"dilepton_mumu": 0.005,
|
| 197 |
+
"four_lepton": 0.001,
|
| 198 |
+
"bb": 0.15,
|
| 199 |
+
"dijet": 0.70,
|
| 200 |
+
},
|
| 201 |
+
primary_channel="diphoton",
|
| 202 |
+
)
|
| 203 |
+
detector = DetectorState(
|
| 204 |
+
detector_resolution_gev=12.0,
|
| 205 |
+
pileup_mu=50.0,
|
| 206 |
+
trigger_efficiency=0.80,
|
| 207 |
+
)
|
| 208 |
+
resources = ResourceState(
|
| 209 |
+
budget_total_musd=110.0,
|
| 210 |
+
luminosity_total_fb=180.0,
|
| 211 |
+
time_limit_days=350.0,
|
| 212 |
+
)
|
| 213 |
+
latent = FullLatentState(
|
| 214 |
+
particle=particle, detector=detector, resources=resources, rng_seed=750,
|
| 215 |
+
)
|
| 216 |
+
task = TaskSpec(
|
| 217 |
+
problem_statement=(
|
| 218 |
+
"A faint γγ excess at 750 GeV stirred the field briefly in 2015-2016. "
|
| 219 |
+
"Re-investigate with the modern luminosity budget and decide if it is "
|
| 220 |
+
"real or a fluctuation."
|
| 221 |
+
),
|
| 222 |
+
mass_search_window_gev=[400.0, 1200.0],
|
| 223 |
+
budget_limit_musd=110.0,
|
| 224 |
+
luminosity_budget_fb=180.0,
|
| 225 |
+
time_limit_days=350.0,
|
| 226 |
+
prior_observations=[
|
| 227 |
+
"Public CMS/ATLAS data show a 2-3σ diphoton bump near 750 GeV.",
|
| 228 |
+
"Theory papers proposed graviton, scalar singlet, and SM-fluctuation explanations.",
|
| 229 |
+
],
|
| 230 |
+
success_criteria=[
|
| 231 |
+
"Decide between discovery and fluctuation with calibrated confidence.",
|
| 232 |
+
],
|
| 233 |
+
paper_references=[
|
| 234 |
+
PaperReference(
|
| 235 |
+
title="Search for resonant production of high-mass diphoton pairs",
|
| 236 |
+
arxiv_id="1606.04093",
|
| 237 |
+
),
|
| 238 |
+
],
|
| 239 |
+
expected_findings=[
|
| 240 |
+
ExpectedFinding(finding="Possible diphoton resonance near 750 GeV", category="discovery"),
|
| 241 |
+
],
|
| 242 |
+
difficulty="hard",
|
| 243 |
+
available_tools=list(TOOL_REGISTRY.keys()),
|
| 244 |
+
)
|
| 245 |
+
return Scenario(name="diphoton_750", difficulty="hard", task=task, latent=latent)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def _easy_diphoton_scenario() -> Scenario:
|
| 249 |
+
"""Generous budgets, narrow scalar, single obvious channel."""
|
| 250 |
+
particle = LatentParticle(
|
| 251 |
+
name="EasyScalar",
|
| 252 |
+
mass_gev=160.0,
|
| 253 |
+
width_gev=0.5,
|
| 254 |
+
spin=0,
|
| 255 |
+
parity="+",
|
| 256 |
+
cross_section_fb=120.0,
|
| 257 |
+
decay_branching={
|
| 258 |
+
"diphoton": 0.05,
|
| 259 |
+
"dilepton_ee": 0.001,
|
| 260 |
+
"dilepton_mumu": 0.005,
|
| 261 |
+
"four_lepton": 0.0001,
|
| 262 |
+
"bb": 0.50,
|
| 263 |
+
"dijet": 0.30,
|
| 264 |
+
},
|
| 265 |
+
primary_channel="diphoton",
|
| 266 |
+
)
|
| 267 |
+
detector = DetectorState(
|
| 268 |
+
detector_resolution_gev=2.0,
|
| 269 |
+
pileup_mu=20.0,
|
| 270 |
+
trigger_efficiency=0.9,
|
| 271 |
+
)
|
| 272 |
+
resources = ResourceState(
|
| 273 |
+
budget_total_musd=200.0,
|
| 274 |
+
luminosity_total_fb=400.0,
|
| 275 |
+
time_limit_days=500.0,
|
| 276 |
+
)
|
| 277 |
+
latent = FullLatentState(
|
| 278 |
+
particle=particle, detector=detector, resources=resources, rng_seed=160,
|
| 279 |
+
)
|
| 280 |
+
task = TaskSpec(
|
| 281 |
+
problem_statement=(
|
| 282 |
+
"Tutorial scenario: discover a narrow scalar that decays cleanly to "
|
| 283 |
+
"two photons. Resources are abundant; focus on running a clean pipeline."
|
| 284 |
+
),
|
| 285 |
+
mass_search_window_gev=[80.0, 300.0],
|
| 286 |
+
budget_limit_musd=200.0,
|
| 287 |
+
luminosity_budget_fb=400.0,
|
| 288 |
+
time_limit_days=500.0,
|
| 289 |
+
success_criteria=[
|
| 290 |
+
"Identify the diphoton peak and submit a calibrated 5σ claim.",
|
| 291 |
+
],
|
| 292 |
+
expected_findings=[
|
| 293 |
+
ExpectedFinding(finding="Diphoton scalar near 160 GeV", category="discovery"),
|
| 294 |
+
],
|
| 295 |
+
difficulty="easy",
|
| 296 |
+
available_tools=list(TOOL_REGISTRY.keys()),
|
| 297 |
+
)
|
| 298 |
+
return Scenario(name="easy_diphoton_160", difficulty="easy", task=task, latent=latent)
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
CURATED_SCENARIOS: List[Scenario] = [
|
| 302 |
+
_easy_diphoton_scenario(),
|
| 303 |
+
_higgs_like_scenario(),
|
| 304 |
+
_hidden_zprime_scenario(),
|
| 305 |
+
_diboson_resonance_scenario(),
|
| 306 |
+
]
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
# ── Procedural sampler ───────────────────────────────────────────────────
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
_DIFFICULTY_TIERS = {
|
| 313 |
+
"easy": {"mass_lo": 90.0, "mass_hi": 250.0, "xsec_lo": 80.0, "xsec_hi": 150.0, "res": 1.5, "budget": 200.0, "lumi": 400.0},
|
| 314 |
+
"medium": {"mass_lo": 100.0, "mass_hi": 600.0, "xsec_lo": 25.0, "xsec_hi": 80.0, "res": 3.0, "budget": 150.0, "lumi": 300.0},
|
| 315 |
+
"hard": {"mass_lo": 250.0, "mass_hi": 1500.0, "xsec_lo": 5.0, "xsec_hi": 25.0, "res": 8.0, "budget": 110.0, "lumi": 200.0},
|
| 316 |
+
}
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def _procedural_scenario(difficulty: str, rng: np.random.Generator) -> Scenario:
|
| 320 |
+
tier = _DIFFICULTY_TIERS.get(difficulty, _DIFFICULTY_TIERS["medium"])
|
| 321 |
+
mass = float(rng.uniform(tier["mass_lo"], tier["mass_hi"]))
|
| 322 |
+
xsec = float(rng.uniform(tier["xsec_lo"], tier["xsec_hi"]))
|
| 323 |
+
spin = int(rng.choice([0, 1, 2]))
|
| 324 |
+
parity = str(rng.choice(["+", "-"]))
|
| 325 |
+
primary = str(rng.choice([c.value for c in DetectorChannel]))
|
| 326 |
+
|
| 327 |
+
branching = {c.value: 0.001 for c in DetectorChannel}
|
| 328 |
+
branching[primary] = float(rng.uniform(0.02, 0.6))
|
| 329 |
+
# normalise so it sums to ~1
|
| 330 |
+
total = sum(branching.values())
|
| 331 |
+
branching = {k: v / total for k, v in branching.items()}
|
| 332 |
+
|
| 333 |
+
particle = LatentParticle(
|
| 334 |
+
name=f"Mystery_{int(mass)}GeV",
|
| 335 |
+
mass_gev=mass,
|
| 336 |
+
width_gev=float(rng.uniform(0.5, 30.0) if difficulty != "easy" else rng.uniform(0.05, 2.0)),
|
| 337 |
+
spin=spin,
|
| 338 |
+
parity=parity,
|
| 339 |
+
cross_section_fb=xsec,
|
| 340 |
+
decay_branching=branching,
|
| 341 |
+
primary_channel=primary,
|
| 342 |
+
)
|
| 343 |
+
detector = DetectorState(
|
| 344 |
+
detector_resolution_gev=tier["res"],
|
| 345 |
+
pileup_mu=float(rng.uniform(20.0, 60.0)),
|
| 346 |
+
trigger_efficiency=float(rng.uniform(0.7, 0.92)),
|
| 347 |
+
qcd_background_strength=float(rng.uniform(0.8, 1.3)),
|
| 348 |
+
)
|
| 349 |
+
resources = ResourceState(
|
| 350 |
+
budget_total_musd=tier["budget"],
|
| 351 |
+
luminosity_total_fb=tier["lumi"],
|
| 352 |
+
time_limit_days=float(rng.uniform(300.0, 500.0)),
|
| 353 |
+
)
|
| 354 |
+
latent = FullLatentState(
|
| 355 |
+
particle=particle, detector=detector, resources=resources,
|
| 356 |
+
rng_seed=int(rng.integers(1, 1_000_000)),
|
| 357 |
+
)
|
| 358 |
+
window_lo = max(50.0, mass - 200.0)
|
| 359 |
+
window_hi = mass + 300.0
|
| 360 |
+
task = TaskSpec(
|
| 361 |
+
problem_statement=(
|
| 362 |
+
f"Procedural ({difficulty}): a hidden resonance lives somewhere in "
|
| 363 |
+
f"[{window_lo:.0f}, {window_hi:.0f}] GeV. Discover and characterise it."
|
| 364 |
+
),
|
| 365 |
+
mass_search_window_gev=[window_lo, window_hi],
|
| 366 |
+
budget_limit_musd=tier["budget"],
|
| 367 |
+
luminosity_budget_fb=tier["lumi"],
|
| 368 |
+
time_limit_days=resources.time_limit_days,
|
| 369 |
+
difficulty=difficulty,
|
| 370 |
+
available_tools=list(TOOL_REGISTRY.keys()),
|
| 371 |
+
success_criteria=[
|
| 372 |
+
"Discover the hidden resonance with a calibrated mass and channel.",
|
| 373 |
+
],
|
| 374 |
+
)
|
| 375 |
+
return Scenario(
|
| 376 |
+
name=f"procedural_{difficulty}_{int(mass)}",
|
| 377 |
+
difficulty=difficulty,
|
| 378 |
+
task=task,
|
| 379 |
+
latent=latent,
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def sample_scenario(
|
| 384 |
+
*,
|
| 385 |
+
difficulty: Optional[str] = None,
|
| 386 |
+
name: Optional[str] = None,
|
| 387 |
+
seed: Optional[int] = None,
|
| 388 |
+
) -> Scenario:
|
| 389 |
+
rng = np.random.default_rng(seed)
|
| 390 |
+
|
| 391 |
+
if name:
|
| 392 |
+
for s in CURATED_SCENARIOS:
|
| 393 |
+
if s.name == name:
|
| 394 |
+
fresh = Scenario(
|
| 395 |
+
name=s.name,
|
| 396 |
+
difficulty=s.difficulty,
|
| 397 |
+
task=s.task,
|
| 398 |
+
latent=s.fresh_latent(),
|
| 399 |
+
)
|
| 400 |
+
if seed is not None:
|
| 401 |
+
fresh.latent.rng_seed = int(seed)
|
| 402 |
+
return fresh
|
| 403 |
+
|
| 404 |
+
if difficulty in {"easy", "medium", "hard"}:
|
| 405 |
+
# mix curated + procedural
|
| 406 |
+
curated_pool = [s for s in CURATED_SCENARIOS if s.difficulty == difficulty]
|
| 407 |
+
if curated_pool and rng.random() < 0.4:
|
| 408 |
+
picked = curated_pool[int(rng.integers(0, len(curated_pool)))]
|
| 409 |
+
return Scenario(
|
| 410 |
+
name=picked.name,
|
| 411 |
+
difficulty=picked.difficulty,
|
| 412 |
+
task=picked.task,
|
| 413 |
+
latent=picked.fresh_latent(),
|
| 414 |
+
)
|
| 415 |
+
return _procedural_scenario(difficulty, rng)
|
| 416 |
+
|
| 417 |
+
# default: random difficulty
|
| 418 |
+
diff = str(rng.choice(["easy", "medium", "hard"]))
|
| 419 |
+
return _procedural_scenario(diff, rng)
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
__all__ = ["CURATED_SCENARIOS", "Scenario", "sample_scenario"]
|
space/__init__.py
ADDED
|
File without changes
|
space/env/Dockerfile
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CERNenv environment Space (Docker, CPU)
|
| 2 |
+
FROM python:3.11-slim
|
| 3 |
+
|
| 4 |
+
ENV PYTHONUNBUFFERED=1 \
|
| 5 |
+
PIP_NO_CACHE_DIR=1 \
|
| 6 |
+
PYTHONPATH=/home/user/app
|
| 7 |
+
|
| 8 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 9 |
+
git curl ca-certificates build-essential \
|
| 10 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 11 |
+
|
| 12 |
+
RUN useradd -ms /bin/bash user
|
| 13 |
+
USER user
|
| 14 |
+
WORKDIR /home/user/app
|
| 15 |
+
|
| 16 |
+
COPY --chown=user:user space/env/requirements.txt /home/user/app/space-env-requirements.txt
|
| 17 |
+
RUN python -m pip install --upgrade pip && \
|
| 18 |
+
python -m pip install --user -r /home/user/app/space-env-requirements.txt
|
| 19 |
+
|
| 20 |
+
COPY --chown=user:user . /home/user/app
|
| 21 |
+
|
| 22 |
+
EXPOSE 7860
|
| 23 |
+
|
| 24 |
+
CMD ["python", "-m", "uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
|
space/env/README.md
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: CERNenv
|
| 3 |
+
emoji: ⚛️
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: indigo
|
| 6 |
+
sdk: docker
|
| 7 |
+
suggested_hardware: cpu-basic
|
| 8 |
+
pinned: false
|
| 9 |
+
license: bsd-3-clause
|
| 10 |
+
short_description: LHC particle-discovery RL environment
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
# CERNenv — LHC Discovery RL Environment
|
| 14 |
+
|
| 15 |
+
OpenEnv-compatible reinforcement-learning environment that simulates an
|
| 16 |
+
LHC (Large Hadron Collider) analysis. An LLM (Large Language Model) agent
|
| 17 |
+
configures the beam, allocates luminosity, picks a decay channel and
|
| 18 |
+
trigger, runs reconstruction, fits an invariant-mass spectrum, estimates
|
| 19 |
+
significance, and finally submits a structured discovery claim that is
|
| 20 |
+
graded against a hidden ground-truth particle.
|
| 21 |
+
|
| 22 |
+
The Space exposes the standard OpenEnv HTTP + WebSocket API:
|
| 23 |
+
|
| 24 |
+
* `GET /health` — liveness
|
| 25 |
+
* `GET /schema` — action / observation / state JSON schemas
|
| 26 |
+
* `POST /reset` — start a new episode (`{ "seed": 7, "scenario": "easy_diphoton_160" }`)
|
| 27 |
+
* `POST /step` — execute one action
|
| 28 |
+
* `GET /state` — current `CernState`
|
| 29 |
+
* `WS /ws` — persistent session (recommended for multi-step rollouts)
|
| 30 |
+
|
| 31 |
+
## Quickstart (Python client)
|
| 32 |
+
|
| 33 |
+
```python
|
| 34 |
+
import asyncio
|
| 35 |
+
from openenv.core import EnvClient
|
| 36 |
+
from huggingface_hub import constants
|
| 37 |
+
|
| 38 |
+
# replace with your space id
|
| 39 |
+
SPACE = "YOUR_HF_USERNAME/cernenv"
|
| 40 |
+
|
| 41 |
+
# (option A) connect to the running Space directly
|
| 42 |
+
import websockets
|
| 43 |
+
async def main():
|
| 44 |
+
async with EnvClient.from_env(SPACE) as env: # uses websockets under the hood
|
| 45 |
+
result = await env.reset(seed=7, scenario="easy_diphoton_160")
|
| 46 |
+
...
|
| 47 |
+
|
| 48 |
+
asyncio.run(main())
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
For training, see the companion **CERNenv Trainer** Space.
|
space/env/requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
numpy>=1.24.0
|
| 2 |
+
scipy>=1.10.0
|
| 3 |
+
pydantic>=2.0.0
|
| 4 |
+
fastapi>=0.110.0
|
| 5 |
+
uvicorn>=0.27.0
|
| 6 |
+
git+https://github.com/meta-pytorch/OpenEnv.git
|
space/training/Dockerfile
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CERNenv trainer Space (Docker, A100)
|
| 2 |
+
FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04
|
| 3 |
+
|
| 4 |
+
ENV DEBIAN_FRONTEND=noninteractive \
|
| 5 |
+
PYTHONUNBUFFERED=1 \
|
| 6 |
+
PIP_NO_CACHE_DIR=1 \
|
| 7 |
+
HF_HOME=/home/user/.cache/huggingface \
|
| 8 |
+
TRANSFORMERS_CACHE=/home/user/.cache/huggingface/transformers \
|
| 9 |
+
PYTHONPATH=/home/user/app
|
| 10 |
+
|
| 11 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 12 |
+
python3.11 python3.11-venv python3.11-dev python3-pip \
|
| 13 |
+
git curl ca-certificates build-essential \
|
| 14 |
+
&& rm -rf /var/lib/apt/lists/* \
|
| 15 |
+
&& ln -sf /usr/bin/python3.11 /usr/local/bin/python \
|
| 16 |
+
&& ln -sf /usr/bin/python3.11 /usr/local/bin/python3
|
| 17 |
+
|
| 18 |
+
RUN useradd -ms /bin/bash user
|
| 19 |
+
USER user
|
| 20 |
+
ENV PATH="/home/user/.local/bin:${PATH}"
|
| 21 |
+
WORKDIR /home/user/app
|
| 22 |
+
|
| 23 |
+
COPY --chown=user:user space/training/requirements.txt /home/user/app/space-training-requirements.txt
|
| 24 |
+
RUN python -m pip install --upgrade pip && \
|
| 25 |
+
python -m pip install --user -r /home/user/app/space-training-requirements.txt
|
| 26 |
+
|
| 27 |
+
COPY --chown=user:user . /home/user/app
|
| 28 |
+
|
| 29 |
+
EXPOSE 7860
|
| 30 |
+
|
| 31 |
+
CMD ["python", "-m", "uvicorn", "space.training.app:app", "--host", "0.0.0.0", "--port", "7860"]
|
space/training/README.md
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: CERNenv Trainer
|
| 3 |
+
emoji: ⚛️
|
| 4 |
+
colorFrom: indigo
|
| 5 |
+
colorTo: pink
|
| 6 |
+
sdk: docker
|
| 7 |
+
suggested_hardware: a100-large
|
| 8 |
+
suggested_storage: medium
|
| 9 |
+
pinned: false
|
| 10 |
+
license: bsd-3-clause
|
| 11 |
+
short_description: GRPO trainer for CERNenv (Unsloth + LoRA, A100)
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# CERNenv Trainer (Hugging Face Space, A100)
|
| 15 |
+
|
| 16 |
+
Fine-tunes a small instruction-tuned LLM (Large Language Model) to act as
|
| 17 |
+
an LHC (Large Hadron Collider) physicist inside the **CERNenv** OpenEnv
|
| 18 |
+
environment using **GRPO** (Group-Relative Policy Optimization),
|
| 19 |
+
**Unsloth**, and **LoRA** (Low-Rank Adaptation).
|
| 20 |
+
|
| 21 |
+
## Hardware
|
| 22 |
+
- Recommended: **A100 large (80 GB)**
|
| 23 |
+
- Minimum: T4 / L4 (will use a smaller model + fewer episodes)
|
| 24 |
+
|
| 25 |
+
## Required Space secrets
|
| 26 |
+
| Secret | Purpose |
|
| 27 |
+
| --- | --- |
|
| 28 |
+
| `HF_TOKEN` | Hugging Face token with `write` access for model push |
|
| 29 |
+
| `HF_USERNAME` | Hub username, used as the default model-repo owner |
|
| 30 |
+
|
| 31 |
+
## Optional environment variables
|
| 32 |
+
| Variable | Default | Notes |
|
| 33 |
+
| --- | --- | --- |
|
| 34 |
+
| `MODEL_NAME` | `unsloth/Qwen2.5-3B-Instruct` | Any chat model Unsloth supports |
|
| 35 |
+
| `TOTAL_EPISODES` | `400` | Prompts × generations rollouts |
|
| 36 |
+
| `DIFFICULTY` | `easy` | `easy` / `medium` / `hard` |
|
| 37 |
+
| `MAX_STEPS` | `18` | Steps per episode |
|
| 38 |
+
| `NUM_GENERATIONS` | `4` | GRPO group size |
|
| 39 |
+
| `OUTPUT_DIR` | `runs/unsloth-grpo` | LoRA adapter output |
|
| 40 |
+
| `PUSH_REPO` | `${HF_USERNAME}/cernenv-grpo-qwen2.5-3b` | Hub repo for adapters |
|
| 41 |
+
| `AUTOSTART` | `0` | Set to `1` to start training on Space boot |
|
| 42 |
+
|
| 43 |
+
## How to use
|
| 44 |
+
|
| 45 |
+
This Space exposes a tiny FastAPI control panel:
|
| 46 |
+
- `GET /` — status + current run info
|
| 47 |
+
- `POST /train` — start / restart a training run
|
| 48 |
+
- `GET /logs` — live tail of `training.log`
|
| 49 |
+
- `GET /metrics` — reward + success-rate snapshots
|
| 50 |
+
|
| 51 |
+
Click **"Start training"** in the UI, or set `AUTOSTART=1` in the Space variables to kick off immediately on boot.
|
| 52 |
+
|
| 53 |
+
When training finishes, the LoRA adapters are pushed to `PUSH_REPO`.
|
| 54 |
+
|
| 55 |
+
## Local equivalent
|
| 56 |
+
|
| 57 |
+
The same training run is reproducible locally with:
|
| 58 |
+
|
| 59 |
+
```bash
|
| 60 |
+
PYTHONPATH=. python -m training.training_unsloth \
|
| 61 |
+
--model_name unsloth/Qwen2.5-3B-Instruct \
|
| 62 |
+
--difficulty easy --total_episodes 400 --max_steps 18 \
|
| 63 |
+
--output_dir runs/unsloth-grpo
|
| 64 |
+
```
|
space/training/__init__.py
ADDED
|
File without changes
|
space/training/app.py
ADDED
|
@@ -0,0 +1,412 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI control panel for the CERNenv trainer Space.
|
| 2 |
+
|
| 3 |
+
Endpoints:
|
| 4 |
+
GET / → status page (HTML)
|
| 5 |
+
GET /status → JSON status of the current training run
|
| 6 |
+
GET /metrics → JSON snapshot of reward / success rate
|
| 7 |
+
GET /logs → tail of the training log
|
| 8 |
+
POST /train → start (or restart) a training run
|
| 9 |
+
GET /health → liveness probe
|
| 10 |
+
|
| 11 |
+
Designed to run on a Hugging Face Space with `sdk: docker`. Heavy training
|
| 12 |
+
work runs in a background thread so the HTTP server stays responsive.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import json
|
| 18 |
+
import logging
|
| 19 |
+
import os
|
| 20 |
+
import subprocess
|
| 21 |
+
import sys
|
| 22 |
+
import threading
|
| 23 |
+
import time
|
| 24 |
+
from datetime import datetime, timezone
|
| 25 |
+
from pathlib import Path
|
| 26 |
+
from typing import Any, Dict, Optional
|
| 27 |
+
|
| 28 |
+
from fastapi import FastAPI, HTTPException
|
| 29 |
+
from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
|
| 33 |
+
logger = logging.getLogger(__name__)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _resolve_repo_root() -> Path:
|
| 37 |
+
env_root = os.environ.get("CERNENV_ROOT")
|
| 38 |
+
candidates = []
|
| 39 |
+
if env_root:
|
| 40 |
+
candidates.append(Path(env_root))
|
| 41 |
+
candidates.extend([
|
| 42 |
+
Path("/home/user/app"),
|
| 43 |
+
Path(__file__).resolve().parent.parent.parent,
|
| 44 |
+
])
|
| 45 |
+
for p in candidates:
|
| 46 |
+
try:
|
| 47 |
+
if p.exists():
|
| 48 |
+
return p.resolve()
|
| 49 |
+
except OSError:
|
| 50 |
+
continue
|
| 51 |
+
return candidates[-1].resolve()
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
REPO_ROOT = _resolve_repo_root()
|
| 55 |
+
LOG_DIR = REPO_ROOT / "training" / "runs"
|
| 56 |
+
try:
|
| 57 |
+
LOG_DIR.mkdir(parents=True, exist_ok=True)
|
| 58 |
+
except OSError as exc: # pragma: no cover - read-only filesystem fallback
|
| 59 |
+
logger.warning("could not create %s (%s); using /tmp", LOG_DIR, exc)
|
| 60 |
+
LOG_DIR = Path("/tmp/cernenv-runs")
|
| 61 |
+
LOG_DIR.mkdir(parents=True, exist_ok=True)
|
| 62 |
+
LOG_FILE = LOG_DIR / "training.log"
|
| 63 |
+
METRICS_FILE = REPO_ROOT / "training" / "plots" / "metrics_summary.json"
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _env(name: str, default: str) -> str:
|
| 67 |
+
return os.environ.get(name, default)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
CONFIG = {
|
| 71 |
+
"model_name": _env("MODEL_NAME", "unsloth/Qwen2.5-3B-Instruct"),
|
| 72 |
+
"difficulty": _env("DIFFICULTY", "easy"),
|
| 73 |
+
"total_episodes": int(_env("TOTAL_EPISODES", "400")),
|
| 74 |
+
"max_steps": int(_env("MAX_STEPS", "18")),
|
| 75 |
+
"num_generations": int(_env("NUM_GENERATIONS", "4")),
|
| 76 |
+
"output_dir": _env("OUTPUT_DIR", "training/runs/unsloth-grpo"),
|
| 77 |
+
"hf_username": _env("HF_USERNAME", "YOUR_HF_USERNAME"),
|
| 78 |
+
"push_repo": _env(
|
| 79 |
+
"PUSH_REPO",
|
| 80 |
+
f"{_env('HF_USERNAME', 'YOUR_HF_USERNAME')}/cernenv-grpo-qwen2.5-3b",
|
| 81 |
+
),
|
| 82 |
+
"autostart": _env("AUTOSTART", "0") == "1",
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# ── Run state ────────────────────────────────────────────────────────────
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class RunState:
|
| 90 |
+
def __init__(self) -> None:
|
| 91 |
+
self.lock = threading.Lock()
|
| 92 |
+
self.thread: Optional[threading.Thread] = None
|
| 93 |
+
self.process: Optional[subprocess.Popen] = None
|
| 94 |
+
self.status: str = "idle" # idle | running | finished | failed
|
| 95 |
+
self.started_at: Optional[str] = None
|
| 96 |
+
self.finished_at: Optional[str] = None
|
| 97 |
+
self.last_error: Optional[str] = None
|
| 98 |
+
self.last_config: Dict[str, Any] = {}
|
| 99 |
+
|
| 100 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 101 |
+
with self.lock:
|
| 102 |
+
return {
|
| 103 |
+
"status": self.status,
|
| 104 |
+
"started_at": self.started_at,
|
| 105 |
+
"finished_at": self.finished_at,
|
| 106 |
+
"last_error": self.last_error,
|
| 107 |
+
"last_config": self.last_config,
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
STATE = RunState()
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
# ── Training pipeline ────────────────────────────────────────────────────
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def _stream_subprocess(cmd: list[str], log_handle) -> int:
|
| 118 |
+
log_handle.write(f"\n$ {' '.join(cmd)}\n")
|
| 119 |
+
log_handle.flush()
|
| 120 |
+
proc = subprocess.Popen(
|
| 121 |
+
cmd,
|
| 122 |
+
cwd=str(REPO_ROOT),
|
| 123 |
+
stdout=subprocess.PIPE,
|
| 124 |
+
stderr=subprocess.STDOUT,
|
| 125 |
+
bufsize=1,
|
| 126 |
+
universal_newlines=True,
|
| 127 |
+
env={**os.environ, "PYTHONPATH": str(REPO_ROOT)},
|
| 128 |
+
)
|
| 129 |
+
STATE.process = proc
|
| 130 |
+
assert proc.stdout is not None
|
| 131 |
+
for line in proc.stdout:
|
| 132 |
+
log_handle.write(line)
|
| 133 |
+
log_handle.flush()
|
| 134 |
+
rc = proc.wait()
|
| 135 |
+
log_handle.write(f"[exit code {rc}]\n")
|
| 136 |
+
log_handle.flush()
|
| 137 |
+
STATE.process = None
|
| 138 |
+
return rc
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def _training_pipeline(config: Dict[str, Any]) -> None:
|
| 142 |
+
started = datetime.now(timezone.utc).isoformat()
|
| 143 |
+
with STATE.lock:
|
| 144 |
+
STATE.status = "running"
|
| 145 |
+
STATE.started_at = started
|
| 146 |
+
STATE.finished_at = None
|
| 147 |
+
STATE.last_error = None
|
| 148 |
+
STATE.last_config = dict(config)
|
| 149 |
+
|
| 150 |
+
LOG_FILE.parent.mkdir(parents=True, exist_ok=True)
|
| 151 |
+
with open(LOG_FILE, "a") as log:
|
| 152 |
+
log.write(f"\n=== Training started {started} ===\n")
|
| 153 |
+
log.write(json.dumps(config, indent=2) + "\n")
|
| 154 |
+
log.flush()
|
| 155 |
+
try:
|
| 156 |
+
output_dir = config["output_dir"]
|
| 157 |
+
difficulty = config["difficulty"]
|
| 158 |
+
max_steps = str(config["max_steps"])
|
| 159 |
+
episodes = str(config["total_episodes"])
|
| 160 |
+
num_gens = str(config["num_generations"])
|
| 161 |
+
model_name = config["model_name"]
|
| 162 |
+
push_repo = config["push_repo"]
|
| 163 |
+
eval_pre = "training/runs/eval_pre_train.jsonl"
|
| 164 |
+
eval_post = "training/runs/eval_post_train.jsonl"
|
| 165 |
+
plots_dir = "training/plots"
|
| 166 |
+
|
| 167 |
+
log.write("\n--- baseline (heuristic / oracle / random) ---\n")
|
| 168 |
+
log.flush()
|
| 169 |
+
for agent in ("random", "heuristic", "oracle"):
|
| 170 |
+
_stream_subprocess(
|
| 171 |
+
[
|
| 172 |
+
sys.executable, "-m", "scripts.run_agent",
|
| 173 |
+
"--agent", agent, "--difficulty", difficulty,
|
| 174 |
+
"--episodes", "3", "--quiet",
|
| 175 |
+
],
|
| 176 |
+
log,
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
log.write("\n--- pre-train evaluation ---\n")
|
| 180 |
+
log.flush()
|
| 181 |
+
rc = _stream_subprocess(
|
| 182 |
+
[
|
| 183 |
+
sys.executable, "-m", "training.evaluate",
|
| 184 |
+
"--model_name", model_name,
|
| 185 |
+
"--difficulty", difficulty,
|
| 186 |
+
"--episodes", "16",
|
| 187 |
+
"--max_steps", max_steps,
|
| 188 |
+
"--tag", "pre_train",
|
| 189 |
+
"--out", eval_pre,
|
| 190 |
+
],
|
| 191 |
+
log,
|
| 192 |
+
)
|
| 193 |
+
if rc != 0:
|
| 194 |
+
raise RuntimeError(f"pre-train eval failed (rc={rc})")
|
| 195 |
+
|
| 196 |
+
log.write("\n--- GRPO training ---\n")
|
| 197 |
+
log.flush()
|
| 198 |
+
rc = _stream_subprocess(
|
| 199 |
+
[
|
| 200 |
+
sys.executable, "-m", "training.training_unsloth",
|
| 201 |
+
"--model_name", model_name,
|
| 202 |
+
"--difficulty", difficulty,
|
| 203 |
+
"--total_episodes", episodes,
|
| 204 |
+
"--max_steps", max_steps,
|
| 205 |
+
"--num_generations", num_gens,
|
| 206 |
+
"--output_dir", output_dir,
|
| 207 |
+
],
|
| 208 |
+
log,
|
| 209 |
+
)
|
| 210 |
+
if rc != 0:
|
| 211 |
+
raise RuntimeError(f"training failed (rc={rc})")
|
| 212 |
+
|
| 213 |
+
log.write("\n--- post-train evaluation ---\n")
|
| 214 |
+
log.flush()
|
| 215 |
+
rc = _stream_subprocess(
|
| 216 |
+
[
|
| 217 |
+
sys.executable, "-m", "training.evaluate",
|
| 218 |
+
"--model_name", model_name,
|
| 219 |
+
"--adapter_dir", output_dir,
|
| 220 |
+
"--difficulty", difficulty,
|
| 221 |
+
"--episodes", "16",
|
| 222 |
+
"--max_steps", max_steps,
|
| 223 |
+
"--tag", "post_train",
|
| 224 |
+
"--out", eval_post,
|
| 225 |
+
],
|
| 226 |
+
log,
|
| 227 |
+
)
|
| 228 |
+
if rc != 0:
|
| 229 |
+
raise RuntimeError(f"post-train eval failed (rc={rc})")
|
| 230 |
+
|
| 231 |
+
log.write("\n--- plots ---\n")
|
| 232 |
+
log.flush()
|
| 233 |
+
_stream_subprocess(
|
| 234 |
+
[
|
| 235 |
+
sys.executable, "-m", "training.plots",
|
| 236 |
+
"--pre", eval_pre,
|
| 237 |
+
"--post", eval_post,
|
| 238 |
+
"--out_dir", plots_dir,
|
| 239 |
+
],
|
| 240 |
+
log,
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
if os.environ.get("HF_TOKEN"):
|
| 244 |
+
log.write("\n--- push adapters to Hub ---\n")
|
| 245 |
+
log.flush()
|
| 246 |
+
_stream_subprocess(
|
| 247 |
+
[
|
| 248 |
+
sys.executable, "-m", "scripts.push_to_hub", "model",
|
| 249 |
+
"--adapter_dir", output_dir,
|
| 250 |
+
"--repo_id", push_repo,
|
| 251 |
+
"--base_model", model_name,
|
| 252 |
+
],
|
| 253 |
+
log,
|
| 254 |
+
)
|
| 255 |
+
else:
|
| 256 |
+
log.write("\n[skip] HF_TOKEN not set — not pushing to Hub\n")
|
| 257 |
+
log.flush()
|
| 258 |
+
|
| 259 |
+
with STATE.lock:
|
| 260 |
+
STATE.status = "finished"
|
| 261 |
+
except Exception as exc:
|
| 262 |
+
logger.exception("training pipeline failed")
|
| 263 |
+
with STATE.lock:
|
| 264 |
+
STATE.status = "failed"
|
| 265 |
+
STATE.last_error = str(exc)
|
| 266 |
+
finally:
|
| 267 |
+
finished = datetime.now(timezone.utc).isoformat()
|
| 268 |
+
log.write(f"\n=== Training ended {finished} ===\n")
|
| 269 |
+
log.flush()
|
| 270 |
+
with STATE.lock:
|
| 271 |
+
STATE.finished_at = finished
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def _start_training(config: Dict[str, Any]) -> None:
|
| 275 |
+
with STATE.lock:
|
| 276 |
+
if STATE.status == "running":
|
| 277 |
+
raise RuntimeError("a training run is already in progress")
|
| 278 |
+
STATE.thread = threading.Thread(
|
| 279 |
+
target=_training_pipeline,
|
| 280 |
+
args=(config,),
|
| 281 |
+
name="cernenv-trainer",
|
| 282 |
+
daemon=True,
|
| 283 |
+
)
|
| 284 |
+
STATE.thread.start()
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
# ── FastAPI app ──────────────────────────────────────────────────────────
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
app = FastAPI(title="CERNenv Trainer", version="0.1.0")
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
_HTML = """\
|
| 294 |
+
<!doctype html>
|
| 295 |
+
<html lang=en>
|
| 296 |
+
<head>
|
| 297 |
+
<meta charset=utf-8>
|
| 298 |
+
<title>CERNenv Trainer</title>
|
| 299 |
+
<style>
|
| 300 |
+
body {{ font-family: ui-sans-serif, system-ui, sans-serif; margin: 2rem auto; max-width: 760px; color:#111 }}
|
| 301 |
+
h1 {{ margin-bottom: 0 }}
|
| 302 |
+
.muted {{ color:#666 }}
|
| 303 |
+
pre {{ background:#0e1116; color:#e6edf3; padding:1rem; border-radius:6px; overflow-x:auto; max-height:50vh }}
|
| 304 |
+
button {{ font-size:1rem; padding:.6rem 1rem; border-radius:6px; border:1px solid #888; background:#fff; cursor:pointer }}
|
| 305 |
+
.pill {{ display:inline-block; padding:.1rem .5rem; border-radius:999px; background:#eef; color:#225 }}
|
| 306 |
+
.ok {{ background:#dfd; color:#272 }}
|
| 307 |
+
.fail {{ background:#fdd; color:#822 }}
|
| 308 |
+
.run {{ background:#fdf6d8; color:#774 }}
|
| 309 |
+
table {{ border-collapse:collapse; }}
|
| 310 |
+
td {{ padding:.2rem .8rem .2rem 0; }}
|
| 311 |
+
</style>
|
| 312 |
+
</head>
|
| 313 |
+
<body>
|
| 314 |
+
<h1>⚛️ CERNenv Trainer</h1>
|
| 315 |
+
<p class=muted>GRPO + Unsloth + LoRA on the CERNenv LHC discovery environment.</p>
|
| 316 |
+
|
| 317 |
+
<h3>Status: <span id=status class=pill>?</span></h3>
|
| 318 |
+
<table id=meta></table>
|
| 319 |
+
|
| 320 |
+
<p>
|
| 321 |
+
<button onclick="startRun()">▶ Start training</button>
|
| 322 |
+
<button onclick="refresh()">↻ Refresh</button>
|
| 323 |
+
</p>
|
| 324 |
+
|
| 325 |
+
<h3>Logs (tail)</h3>
|
| 326 |
+
<pre id=logs>loading…</pre>
|
| 327 |
+
|
| 328 |
+
<script>
|
| 329 |
+
async function refresh() {{
|
| 330 |
+
const s = await fetch('/status').then(r => r.json());
|
| 331 |
+
const pill = document.getElementById('status');
|
| 332 |
+
pill.textContent = s.status;
|
| 333 |
+
pill.className = 'pill ' + ({{idle:'',running:'run',finished:'ok',failed:'fail'}}[s.status] || '');
|
| 334 |
+
|
| 335 |
+
const meta = document.getElementById('meta');
|
| 336 |
+
meta.innerHTML = '';
|
| 337 |
+
for (const [k, v] of Object.entries({{
|
| 338 |
+
started_at: s.started_at, finished_at: s.finished_at, error: s.last_error,
|
| 339 |
+
...(s.last_config || {{}}),
|
| 340 |
+
}})) {{
|
| 341 |
+
if (v == null || v === '') continue;
|
| 342 |
+
const tr = document.createElement('tr');
|
| 343 |
+
tr.innerHTML = `<td><b>${{k}}</b></td><td><code>${{v}}</code></td>`;
|
| 344 |
+
meta.appendChild(tr);
|
| 345 |
+
}}
|
| 346 |
+
|
| 347 |
+
const logs = await fetch('/logs?tail=200').then(r => r.text());
|
| 348 |
+
document.getElementById('logs').textContent = logs || '(no logs yet)';
|
| 349 |
+
}}
|
| 350 |
+
async function startRun() {{
|
| 351 |
+
await fetch('/train', {{method:'POST'}});
|
| 352 |
+
setTimeout(refresh, 500);
|
| 353 |
+
}}
|
| 354 |
+
refresh();
|
| 355 |
+
setInterval(refresh, 5000);
|
| 356 |
+
</script>
|
| 357 |
+
</body>
|
| 358 |
+
</html>
|
| 359 |
+
"""
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
@app.get("/", response_class=HTMLResponse)
|
| 363 |
+
def index() -> HTMLResponse:
|
| 364 |
+
return HTMLResponse(_HTML)
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
@app.get("/health")
|
| 368 |
+
def health() -> Dict[str, str]:
|
| 369 |
+
return {"status": "ok"}
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
@app.get("/status")
|
| 373 |
+
def status() -> JSONResponse:
|
| 374 |
+
return JSONResponse(STATE.to_dict())
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
@app.get("/metrics")
|
| 378 |
+
def metrics() -> JSONResponse:
|
| 379 |
+
if METRICS_FILE.exists():
|
| 380 |
+
try:
|
| 381 |
+
return JSONResponse(json.loads(METRICS_FILE.read_text()))
|
| 382 |
+
except Exception:
|
| 383 |
+
return JSONResponse({"error": "metrics file unreadable"}, status_code=500)
|
| 384 |
+
return JSONResponse({"pre": None, "post": None})
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
@app.get("/logs", response_class=PlainTextResponse)
|
| 388 |
+
def logs(tail: int = 400) -> PlainTextResponse:
|
| 389 |
+
if not LOG_FILE.exists():
|
| 390 |
+
return PlainTextResponse("")
|
| 391 |
+
text = LOG_FILE.read_text()
|
| 392 |
+
lines = text.splitlines()
|
| 393 |
+
return PlainTextResponse("\n".join(lines[-max(tail, 1):]))
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
@app.post("/train")
|
| 397 |
+
def train() -> JSONResponse:
|
| 398 |
+
try:
|
| 399 |
+
_start_training(dict(CONFIG))
|
| 400 |
+
except RuntimeError as exc:
|
| 401 |
+
raise HTTPException(status_code=409, detail=str(exc))
|
| 402 |
+
return JSONResponse({"status": "started", "config": CONFIG})
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
@app.on_event("startup")
|
| 406 |
+
def _maybe_autostart() -> None:
|
| 407 |
+
if CONFIG["autostart"]:
|
| 408 |
+
try:
|
| 409 |
+
_start_training(dict(CONFIG))
|
| 410 |
+
logger.info("autostarted training run")
|
| 411 |
+
except RuntimeError as exc:
|
| 412 |
+
logger.warning("autostart skipped: %s", exc)
|
space/training/requirements.txt
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
--extra-index-url https://download.pytorch.org/whl/cu121
|
| 2 |
+
torch==2.4.0
|
| 3 |
+
unsloth
|
| 4 |
+
unsloth_zoo
|
| 5 |
+
transformers>=4.44.0
|
| 6 |
+
trl>=0.9.0
|
| 7 |
+
peft>=0.10.0
|
| 8 |
+
accelerate>=1.0.0
|
| 9 |
+
datasets>=2.18.0
|
| 10 |
+
bitsandbytes>=0.43.0
|
| 11 |
+
matplotlib>=3.8.0
|
| 12 |
+
numpy>=1.24.0
|
| 13 |
+
scipy>=1.10.0
|
| 14 |
+
pydantic>=2.0.0
|
| 15 |
+
fastapi>=0.110.0
|
| 16 |
+
uvicorn>=0.27.0
|
| 17 |
+
huggingface_hub>=0.24.0
|
| 18 |
+
git+https://github.com/meta-pytorch/OpenEnv.git
|
training/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Training utilities: rollout collection, GRPO/PPO training, evaluation."""
|
training/colab_train_unsloth.ipynb
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# CERNenv — Unsloth + LoRA + GRPO training\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"Trains a small instruction-tuned LLM (Large Language Model) to act as an LHC (Large Hadron Collider) physicist inside the **CERNenv** OpenEnv environment, using **GRPO** (Group-Relative Policy Optimization) with **Unsloth** + **LoRA** (Low-Rank Adaptation).\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"Runs on:\n",
|
| 12 |
+
"- a **Hugging Face Space** with an A100 GPU (recommended)\n",
|
| 13 |
+
"- Google **Colab** (T4 / L4) as a fallback\n",
|
| 14 |
+
"\n",
|
| 15 |
+
"Outputs:\n",
|
| 16 |
+
"- LoRA adapters at `runs/unsloth-grpo`\n",
|
| 17 |
+
"- Reward / success-rate curves at `training/plots/`\n",
|
| 18 |
+
"- Final adapters pushed to your Hugging Face Hub repo"
|
| 19 |
+
]
|
| 20 |
+
},
|
| 21 |
+
{
|
| 22 |
+
"cell_type": "markdown",
|
| 23 |
+
"metadata": {},
|
| 24 |
+
"source": [
|
| 25 |
+
"## 1. Environment setup"
|
| 26 |
+
]
|
| 27 |
+
},
|
| 28 |
+
{
|
| 29 |
+
"cell_type": "code",
|
| 30 |
+
"execution_count": null,
|
| 31 |
+
"metadata": {},
|
| 32 |
+
"outputs": [],
|
| 33 |
+
"source": [
|
| 34 |
+
"%%capture\n",
|
| 35 |
+
"import sys, os\n",
|
| 36 |
+
"IN_COLAB = 'google.colab' in sys.modules\n",
|
| 37 |
+
"IN_HF_SPACE = os.environ.get('SPACE_ID') is not None\n",
|
| 38 |
+
"print('Colab:', IN_COLAB, '| HF Space:', IN_HF_SPACE)\n",
|
| 39 |
+
"\n",
|
| 40 |
+
"if IN_COLAB:\n",
|
| 41 |
+
" !git clone https://github.com/YOUR_HF_USERNAME/CERNenv.git\n",
|
| 42 |
+
" %cd CERNenv\n",
|
| 43 |
+
"elif IN_HF_SPACE:\n",
|
| 44 |
+
" %cd /home/user/app\n",
|
| 45 |
+
"else:\n",
|
| 46 |
+
" pass\n",
|
| 47 |
+
"\n",
|
| 48 |
+
"!pip install -q -r requirements-unsloth.txt"
|
| 49 |
+
]
|
| 50 |
+
},
|
| 51 |
+
{
|
| 52 |
+
"cell_type": "code",
|
| 53 |
+
"execution_count": null,
|
| 54 |
+
"metadata": {},
|
| 55 |
+
"outputs": [],
|
| 56 |
+
"source": [
|
| 57 |
+
"import os, json, subprocess, sys\n",
|
| 58 |
+
"from pathlib import Path\n",
|
| 59 |
+
"import torch\n",
|
| 60 |
+
"print('CUDA:', torch.cuda.is_available(), torch.cuda.get_device_name(0) if torch.cuda.is_available() else None)\n",
|
| 61 |
+
"Path('training/plots').mkdir(parents=True, exist_ok=True)\n",
|
| 62 |
+
"Path('training/runs').mkdir(parents=True, exist_ok=True)"
|
| 63 |
+
]
|
| 64 |
+
},
|
| 65 |
+
{
|
| 66 |
+
"cell_type": "markdown",
|
| 67 |
+
"metadata": {},
|
| 68 |
+
"source": [
|
| 69 |
+
"## 2. Hugging Face authentication\n",
|
| 70 |
+
"\n",
|
| 71 |
+
"On a Space, set the `HF_TOKEN` Space-secret. Locally / on Colab, paste a token below. The token must have **write** access to your model repo."
|
| 72 |
+
]
|
| 73 |
+
},
|
| 74 |
+
{
|
| 75 |
+
"cell_type": "code",
|
| 76 |
+
"execution_count": null,
|
| 77 |
+
"metadata": {},
|
| 78 |
+
"outputs": [],
|
| 79 |
+
"source": [
|
| 80 |
+
"from huggingface_hub import login\n",
|
| 81 |
+
"HF_TOKEN = os.environ.get('HF_TOKEN')\n",
|
| 82 |
+
"if HF_TOKEN:\n",
|
| 83 |
+
" login(HF_TOKEN)\n",
|
| 84 |
+
" print('logged in via HF_TOKEN env var')\n",
|
| 85 |
+
"else:\n",
|
| 86 |
+
" from getpass import getpass\n",
|
| 87 |
+
" login(getpass('Paste HF token: '))"
|
| 88 |
+
]
|
| 89 |
+
},
|
| 90 |
+
{
|
| 91 |
+
"cell_type": "markdown",
|
| 92 |
+
"metadata": {},
|
| 93 |
+
"source": [
|
| 94 |
+
"## 3. Configure the run"
|
| 95 |
+
]
|
| 96 |
+
},
|
| 97 |
+
{
|
| 98 |
+
"cell_type": "code",
|
| 99 |
+
"execution_count": null,
|
| 100 |
+
"metadata": {},
|
| 101 |
+
"outputs": [],
|
| 102 |
+
"source": [
|
| 103 |
+
"HF_USERNAME = os.environ.get('HF_USERNAME', 'YOUR_HF_USERNAME')\n",
|
| 104 |
+
"MODEL_NAME = os.environ.get('MODEL_NAME', 'unsloth/Qwen2.5-3B-Instruct')\n",
|
| 105 |
+
"TOTAL_EPISODES = int(os.environ.get('TOTAL_EPISODES', '400'))\n",
|
| 106 |
+
"DIFFICULTY = os.environ.get('DIFFICULTY', 'easy')\n",
|
| 107 |
+
"MAX_STEPS = int(os.environ.get('MAX_STEPS', '18'))\n",
|
| 108 |
+
"OUTPUT_DIR = os.environ.get('OUTPUT_DIR', 'training/runs/unsloth-grpo')\n",
|
| 109 |
+
"PUSH_REPO = os.environ.get('PUSH_REPO', f'{HF_USERNAME}/cernenv-grpo-qwen2.5-3b')\n",
|
| 110 |
+
"print({'model': MODEL_NAME, 'episodes': TOTAL_EPISODES, 'difficulty': DIFFICULTY,\n",
|
| 111 |
+
" 'max_steps': MAX_STEPS, 'out': OUTPUT_DIR, 'repo': PUSH_REPO})"
|
| 112 |
+
]
|
| 113 |
+
},
|
| 114 |
+
{
|
| 115 |
+
"cell_type": "markdown",
|
| 116 |
+
"metadata": {},
|
| 117 |
+
"source": [
|
| 118 |
+
"## 4. Quick sanity check: heuristic vs random baseline\n",
|
| 119 |
+
"\n",
|
| 120 |
+
"Before training, confirm the environment + reward signal are working."
|
| 121 |
+
]
|
| 122 |
+
},
|
| 123 |
+
{
|
| 124 |
+
"cell_type": "code",
|
| 125 |
+
"execution_count": null,
|
| 126 |
+
"metadata": {},
|
| 127 |
+
"outputs": [],
|
| 128 |
+
"source": [
|
| 129 |
+
"!PYTHONPATH=. python -m scripts.run_agent --agent random --difficulty $DIFFICULTY --episodes 3 --quiet\n",
|
| 130 |
+
"!PYTHONPATH=. python -m scripts.run_agent --agent heuristic --difficulty $DIFFICULTY --episodes 3 --quiet\n",
|
| 131 |
+
"!PYTHONPATH=. python -m scripts.run_agent --agent oracle --difficulty $DIFFICULTY --episodes 3 --quiet"
|
| 132 |
+
]
|
| 133 |
+
},
|
| 134 |
+
{
|
| 135 |
+
"cell_type": "markdown",
|
| 136 |
+
"metadata": {},
|
| 137 |
+
"source": [
|
| 138 |
+
"## 5. Pre-training evaluation (zero-shot LLM)"
|
| 139 |
+
]
|
| 140 |
+
},
|
| 141 |
+
{
|
| 142 |
+
"cell_type": "code",
|
| 143 |
+
"execution_count": null,
|
| 144 |
+
"metadata": {},
|
| 145 |
+
"outputs": [],
|
| 146 |
+
"source": [
|
| 147 |
+
"!PYTHONPATH=. python -m training.evaluate \\\n",
|
| 148 |
+
" --model_name $MODEL_NAME \\\n",
|
| 149 |
+
" --difficulty $DIFFICULTY \\\n",
|
| 150 |
+
" --episodes 16 \\\n",
|
| 151 |
+
" --max_steps $MAX_STEPS \\\n",
|
| 152 |
+
" --tag pre_train \\\n",
|
| 153 |
+
" --out training/runs/eval_pre_train.jsonl"
|
| 154 |
+
]
|
| 155 |
+
},
|
| 156 |
+
{
|
| 157 |
+
"cell_type": "markdown",
|
| 158 |
+
"metadata": {},
|
| 159 |
+
"source": [
|
| 160 |
+
"## 6. Train with Unsloth + LoRA + GRPO"
|
| 161 |
+
]
|
| 162 |
+
},
|
| 163 |
+
{
|
| 164 |
+
"cell_type": "code",
|
| 165 |
+
"execution_count": null,
|
| 166 |
+
"metadata": {},
|
| 167 |
+
"outputs": [],
|
| 168 |
+
"source": [
|
| 169 |
+
"!PYTHONPATH=. python -m training.training_unsloth \\\n",
|
| 170 |
+
" --model_name $MODEL_NAME \\\n",
|
| 171 |
+
" --difficulty $DIFFICULTY \\\n",
|
| 172 |
+
" --total_episodes $TOTAL_EPISODES \\\n",
|
| 173 |
+
" --max_steps $MAX_STEPS \\\n",
|
| 174 |
+
" --num_generations 4 \\\n",
|
| 175 |
+
" --output_dir $OUTPUT_DIR"
|
| 176 |
+
]
|
| 177 |
+
},
|
| 178 |
+
{
|
| 179 |
+
"cell_type": "markdown",
|
| 180 |
+
"metadata": {},
|
| 181 |
+
"source": [
|
| 182 |
+
"## 7. Post-training evaluation"
|
| 183 |
+
]
|
| 184 |
+
},
|
| 185 |
+
{
|
| 186 |
+
"cell_type": "code",
|
| 187 |
+
"execution_count": null,
|
| 188 |
+
"metadata": {},
|
| 189 |
+
"outputs": [],
|
| 190 |
+
"source": [
|
| 191 |
+
"!PYTHONPATH=. python -m training.evaluate \\\n",
|
| 192 |
+
" --model_name $MODEL_NAME \\\n",
|
| 193 |
+
" --adapter_dir $OUTPUT_DIR \\\n",
|
| 194 |
+
" --difficulty $DIFFICULTY \\\n",
|
| 195 |
+
" --episodes 16 \\\n",
|
| 196 |
+
" --max_steps $MAX_STEPS \\\n",
|
| 197 |
+
" --tag post_train \\\n",
|
| 198 |
+
" --out training/runs/eval_post_train.jsonl"
|
| 199 |
+
]
|
| 200 |
+
},
|
| 201 |
+
{
|
| 202 |
+
"cell_type": "markdown",
|
| 203 |
+
"metadata": {},
|
| 204 |
+
"source": [
|
| 205 |
+
"## 8. Plot before / after"
|
| 206 |
+
]
|
| 207 |
+
},
|
| 208 |
+
{
|
| 209 |
+
"cell_type": "code",
|
| 210 |
+
"execution_count": null,
|
| 211 |
+
"metadata": {},
|
| 212 |
+
"outputs": [],
|
| 213 |
+
"source": [
|
| 214 |
+
"!PYTHONPATH=. python -m training.plots \\\n",
|
| 215 |
+
" --pre training/runs/eval_pre_train.jsonl \\\n",
|
| 216 |
+
" --post training/runs/eval_post_train.jsonl \\\n",
|
| 217 |
+
" --out_dir training/plots"
|
| 218 |
+
]
|
| 219 |
+
},
|
| 220 |
+
{
|
| 221 |
+
"cell_type": "markdown",
|
| 222 |
+
"metadata": {},
|
| 223 |
+
"source": [
|
| 224 |
+
"## 9. Push trained adapters to the Hugging Face Hub"
|
| 225 |
+
]
|
| 226 |
+
},
|
| 227 |
+
{
|
| 228 |
+
"cell_type": "code",
|
| 229 |
+
"execution_count": null,
|
| 230 |
+
"metadata": {},
|
| 231 |
+
"outputs": [],
|
| 232 |
+
"source": [
|
| 233 |
+
"!PYTHONPATH=. python -m scripts.push_to_hub model \\\n",
|
| 234 |
+
" --adapter_dir $OUTPUT_DIR \\\n",
|
| 235 |
+
" --repo_id $PUSH_REPO \\\n",
|
| 236 |
+
" --base_model $MODEL_NAME"
|
| 237 |
+
]
|
| 238 |
+
},
|
| 239 |
+
{
|
| 240 |
+
"cell_type": "markdown",
|
| 241 |
+
"metadata": {},
|
| 242 |
+
"source": [
|
| 243 |
+
"Done. Reward + success-rate plots live in `training/plots/`, model adapters at `OUTPUT_DIR`, and a copy is pushed to `PUSH_REPO`."
|
| 244 |
+
]
|
| 245 |
+
}
|
| 246 |
+
],
|
| 247 |
+
"metadata": {
|
| 248 |
+
"kernelspec": {
|
| 249 |
+
"display_name": "Python 3",
|
| 250 |
+
"language": "python",
|
| 251 |
+
"name": "python3"
|
| 252 |
+
},
|
| 253 |
+
"language_info": {
|
| 254 |
+
"name": "python",
|
| 255 |
+
"version": "3.11"
|
| 256 |
+
}
|
| 257 |
+
},
|
| 258 |
+
"nbformat": 4,
|
| 259 |
+
"nbformat_minor": 5
|
| 260 |
+
}
|
training/evaluate.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Evaluate an LLM (with optional LoRA adapters) on CERNenv.
|
| 2 |
+
|
| 3 |
+
Usage:
|
| 4 |
+
python -m training.evaluate --model_name unsloth/Qwen2.5-3B-Instruct \\
|
| 5 |
+
--difficulty easy --episodes 16 --tag pre_train \\
|
| 6 |
+
--out training/runs/eval_pre_train.jsonl
|
| 7 |
+
|
| 8 |
+
python -m training.evaluate --model_name unsloth/Qwen2.5-3B-Instruct \\
|
| 9 |
+
--adapter_dir training/runs/unsloth-grpo --difficulty easy \\
|
| 10 |
+
--episodes 16 --tag post_train --out training/runs/eval_post_train.jsonl
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
import json
|
| 17 |
+
import logging
|
| 18 |
+
import os
|
| 19 |
+
from dataclasses import asdict
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from typing import Any, Dict, List, Optional
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _build_generate_fn(
|
| 29 |
+
*,
|
| 30 |
+
model_name: str,
|
| 31 |
+
adapter_dir: Optional[str],
|
| 32 |
+
use_unsloth: bool,
|
| 33 |
+
max_seq_length: int,
|
| 34 |
+
):
|
| 35 |
+
if use_unsloth:
|
| 36 |
+
from unsloth import FastLanguageModel # type: ignore
|
| 37 |
+
|
| 38 |
+
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 39 |
+
model_name=model_name,
|
| 40 |
+
max_seq_length=max_seq_length,
|
| 41 |
+
load_in_4bit=True,
|
| 42 |
+
fast_inference=True,
|
| 43 |
+
)
|
| 44 |
+
if adapter_dir:
|
| 45 |
+
model.load_adapter(adapter_dir)
|
| 46 |
+
FastLanguageModel.for_inference(model)
|
| 47 |
+
else:
|
| 48 |
+
import torch
|
| 49 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 50 |
+
|
| 51 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 52 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 53 |
+
model_name,
|
| 54 |
+
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
| 55 |
+
device_map="auto" if torch.cuda.is_available() else None,
|
| 56 |
+
)
|
| 57 |
+
if adapter_dir:
|
| 58 |
+
from peft import PeftModel # type: ignore
|
| 59 |
+
model = PeftModel.from_pretrained(model, adapter_dir)
|
| 60 |
+
|
| 61 |
+
if tokenizer.pad_token is None:
|
| 62 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 63 |
+
|
| 64 |
+
def prompt_fn(chat: List[Dict[str, str]]) -> str:
|
| 65 |
+
return tokenizer.apply_chat_template(
|
| 66 |
+
chat, add_generation_prompt=True, tokenize=False
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
def generate_fn(prompt: str, config) -> str:
|
| 70 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 71 |
+
outputs = model.generate(
|
| 72 |
+
**inputs,
|
| 73 |
+
max_new_tokens=config.max_new_tokens,
|
| 74 |
+
do_sample=True,
|
| 75 |
+
temperature=config.temperature,
|
| 76 |
+
top_p=config.top_p,
|
| 77 |
+
pad_token_id=tokenizer.pad_token_id,
|
| 78 |
+
)
|
| 79 |
+
gen = outputs[0][inputs["input_ids"].shape[1]:]
|
| 80 |
+
return tokenizer.decode(gen, skip_special_tokens=True)
|
| 81 |
+
|
| 82 |
+
return prompt_fn, generate_fn
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def main() -> None: # pragma: no cover
|
| 86 |
+
parser = argparse.ArgumentParser()
|
| 87 |
+
parser.add_argument("--model_name", required=True)
|
| 88 |
+
parser.add_argument("--adapter_dir", default=None)
|
| 89 |
+
parser.add_argument("--scenario", default=None)
|
| 90 |
+
parser.add_argument("--difficulty", choices=["easy", "medium", "hard"], default="easy")
|
| 91 |
+
parser.add_argument("--episodes", type=int, default=16)
|
| 92 |
+
parser.add_argument("--seed", type=int, default=1000)
|
| 93 |
+
parser.add_argument("--max_steps", type=int, default=18)
|
| 94 |
+
parser.add_argument("--max_seq_length", type=int, default=2048)
|
| 95 |
+
parser.add_argument("--no_unsloth", action="store_true")
|
| 96 |
+
parser.add_argument("--tag", default="eval")
|
| 97 |
+
parser.add_argument("--out", required=True)
|
| 98 |
+
args = parser.parse_args()
|
| 99 |
+
|
| 100 |
+
from server.environment import CERNCollisionEnvironment
|
| 101 |
+
from training.llm_agent import LLMAgentConfig
|
| 102 |
+
from training.rollouts import collect_episode, save_episodes_jsonl
|
| 103 |
+
|
| 104 |
+
use_unsloth = not args.no_unsloth
|
| 105 |
+
try:
|
| 106 |
+
prompt_fn, generate_fn = _build_generate_fn(
|
| 107 |
+
model_name=args.model_name,
|
| 108 |
+
adapter_dir=args.adapter_dir,
|
| 109 |
+
use_unsloth=use_unsloth,
|
| 110 |
+
max_seq_length=args.max_seq_length,
|
| 111 |
+
)
|
| 112 |
+
except ImportError as exc:
|
| 113 |
+
logger.warning("Unsloth not available (%s); falling back to transformers.", exc)
|
| 114 |
+
prompt_fn, generate_fn = _build_generate_fn(
|
| 115 |
+
model_name=args.model_name,
|
| 116 |
+
adapter_dir=args.adapter_dir,
|
| 117 |
+
use_unsloth=False,
|
| 118 |
+
max_seq_length=args.max_seq_length,
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
env = CERNCollisionEnvironment(max_steps=args.max_steps)
|
| 122 |
+
cfg = LLMAgentConfig()
|
| 123 |
+
|
| 124 |
+
episodes = []
|
| 125 |
+
for ep in range(args.episodes):
|
| 126 |
+
seed = args.seed + ep
|
| 127 |
+
rec = collect_episode(
|
| 128 |
+
env=env,
|
| 129 |
+
seed=seed,
|
| 130 |
+
scenario=args.scenario,
|
| 131 |
+
difficulty=args.difficulty,
|
| 132 |
+
prompt_fn=prompt_fn,
|
| 133 |
+
generate_fn=generate_fn,
|
| 134 |
+
config=cfg,
|
| 135 |
+
)
|
| 136 |
+
episodes.append(rec)
|
| 137 |
+
logger.info(
|
| 138 |
+
"[%s][%d/%d] reward=%+.3f discovered=%s mass=%s channel=%s",
|
| 139 |
+
args.tag, ep + 1, args.episodes,
|
| 140 |
+
rec.cumulative_reward, rec.discovered, rec.correct_mass, rec.correct_channel,
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
Path(args.out).parent.mkdir(parents=True, exist_ok=True)
|
| 144 |
+
save_episodes_jsonl(episodes, args.out)
|
| 145 |
+
|
| 146 |
+
rewards = [e.cumulative_reward for e in episodes]
|
| 147 |
+
success = sum(1 for e in episodes if e.discovered) / len(episodes)
|
| 148 |
+
logger.info("[%s] mean_reward=%.3f success_rate=%.2f", args.tag, sum(rewards) / len(rewards), success)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
if __name__ == "__main__": # pragma: no cover
|
| 152 |
+
main()
|
training/llm_agent.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""LLM (Large Language Model) agent that picks the next CERNenv action.
|
| 2 |
+
|
| 3 |
+
The agent renders an observation as a short prompt, asks the LLM for a
|
| 4 |
+
JSON-formatted ``ExperimentAction``, validates the response, and falls back
|
| 5 |
+
to a safe default action if parsing fails. This is the unit shared by
|
| 6 |
+
evaluation and the GRPO (Group-Relative Policy Optimization) training loop.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import json
|
| 12 |
+
import re
|
| 13 |
+
from dataclasses import dataclass
|
| 14 |
+
from typing import Any, Dict, List, Optional
|
| 15 |
+
|
| 16 |
+
from models import (
|
| 17 |
+
ActionType,
|
| 18 |
+
CollisionObservation,
|
| 19 |
+
ExperimentAction,
|
| 20 |
+
build_agent_observation_context,
|
| 21 |
+
build_agent_system_prompt,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
_VALID_ACTIONS = {a.value for a in ActionType}
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
|
| 29 |
+
class LLMAgentConfig:
|
| 30 |
+
"""Knobs for prompt formatting and decoding."""
|
| 31 |
+
|
| 32 |
+
max_history_steps: int = 6
|
| 33 |
+
temperature: float = 0.7
|
| 34 |
+
max_new_tokens: int = 256
|
| 35 |
+
top_p: float = 0.95
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def render_history(obs: CollisionObservation, max_steps: int) -> str:
|
| 39 |
+
if not obs.pipeline_history:
|
| 40 |
+
return " (none yet — pick a starting action)"
|
| 41 |
+
lines: List[str] = []
|
| 42 |
+
history = obs.pipeline_history[-max_steps:]
|
| 43 |
+
for rec in history:
|
| 44 |
+
success = "OK" if rec.success else "FAIL"
|
| 45 |
+
lines.append(
|
| 46 |
+
f" step {rec.step_index:>2} {rec.action_type.value:<24} {success}: {rec.output_summary[:80]}"
|
| 47 |
+
)
|
| 48 |
+
return "\n".join(lines)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def render_resources(obs: CollisionObservation) -> str:
|
| 52 |
+
r = obs.resource_usage
|
| 53 |
+
return (
|
| 54 |
+
f"budget {r.budget_remaining_musd:.1f}/{r.budget_remaining_musd + r.budget_used_musd:.1f} M$ left, "
|
| 55 |
+
f"luminosity {r.luminosity_remaining_fb:.1f}/{r.luminosity_remaining_fb + r.luminosity_used_fb:.1f} fb^-1 left, "
|
| 56 |
+
f"time {r.time_remaining_days:.0f}/{r.time_remaining_days + r.time_used_days:.0f} days left"
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def render_user_prompt(
|
| 61 |
+
obs: CollisionObservation,
|
| 62 |
+
config: LLMAgentConfig = LLMAgentConfig(),
|
| 63 |
+
) -> str:
|
| 64 |
+
parts: List[str] = []
|
| 65 |
+
parts.append("Task:")
|
| 66 |
+
parts.append(" " + obs.task.problem_statement.strip())
|
| 67 |
+
parts.append("")
|
| 68 |
+
parts.append("Public state:")
|
| 69 |
+
parts.append(" " + build_agent_observation_context(obs).replace("\n", "\n "))
|
| 70 |
+
parts.append("")
|
| 71 |
+
parts.append("Resources:")
|
| 72 |
+
parts.append(" " + render_resources(obs))
|
| 73 |
+
parts.append("")
|
| 74 |
+
parts.append("Recent steps:")
|
| 75 |
+
parts.append(render_history(obs, max_steps=config.max_history_steps))
|
| 76 |
+
if obs.rule_violations:
|
| 77 |
+
parts.append("")
|
| 78 |
+
parts.append("Last-step violations: " + ", ".join(obs.rule_violations))
|
| 79 |
+
parts.append("")
|
| 80 |
+
parts.append("Choose ONE next action and respond with a single JSON object.")
|
| 81 |
+
return "\n".join(parts)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def build_chat(
|
| 85 |
+
obs: CollisionObservation,
|
| 86 |
+
config: LLMAgentConfig = LLMAgentConfig(),
|
| 87 |
+
) -> List[Dict[str, str]]:
|
| 88 |
+
return [
|
| 89 |
+
{"role": "system", "content": build_agent_system_prompt()},
|
| 90 |
+
{"role": "user", "content": render_user_prompt(obs, config)},
|
| 91 |
+
]
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# ── Robust JSON extraction ───────────────────────────────────────────────
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
_JSON_RE = re.compile(r"\{[\s\S]*\}")
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def extract_first_json(text: str) -> Optional[Dict[str, Any]]:
|
| 101 |
+
"""Return the first parseable JSON object found inside ``text``."""
|
| 102 |
+
if not text:
|
| 103 |
+
return None
|
| 104 |
+
m = _JSON_RE.search(text)
|
| 105 |
+
if not m:
|
| 106 |
+
return None
|
| 107 |
+
candidate = m.group(0)
|
| 108 |
+
try:
|
| 109 |
+
return json.loads(candidate)
|
| 110 |
+
except json.JSONDecodeError:
|
| 111 |
+
# Try a relaxed pass: trim trailing commas
|
| 112 |
+
cleaned = re.sub(r",\s*([}\]])", r"\1", candidate)
|
| 113 |
+
try:
|
| 114 |
+
return json.loads(cleaned)
|
| 115 |
+
except json.JSONDecodeError:
|
| 116 |
+
return None
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def parse_action(text: str) -> Optional[ExperimentAction]:
|
| 120 |
+
payload = extract_first_json(text)
|
| 121 |
+
if payload is None:
|
| 122 |
+
return None
|
| 123 |
+
action_type = payload.get("action_type")
|
| 124 |
+
if action_type not in _VALID_ACTIONS:
|
| 125 |
+
return None
|
| 126 |
+
try:
|
| 127 |
+
return ExperimentAction(
|
| 128 |
+
action_type=ActionType(action_type),
|
| 129 |
+
method=payload.get("method") or None,
|
| 130 |
+
parameters=payload.get("parameters") or {},
|
| 131 |
+
justification=payload.get("justification"),
|
| 132 |
+
confidence=float(payload.get("confidence", 0.5) or 0.5),
|
| 133 |
+
)
|
| 134 |
+
except Exception:
|
| 135 |
+
return None
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def safe_default_action(obs: CollisionObservation) -> ExperimentAction:
|
| 139 |
+
"""Picks the next sensible scripted step when the LLM output is invalid."""
|
| 140 |
+
prog = obs.pipeline_history
|
| 141 |
+
flags = {a.value: False for a in ActionType}
|
| 142 |
+
for rec in prog:
|
| 143 |
+
if rec.success:
|
| 144 |
+
flags[rec.action_type.value] = True
|
| 145 |
+
|
| 146 |
+
if not flags[ActionType.CONFIGURE_BEAM.value]:
|
| 147 |
+
return ExperimentAction(
|
| 148 |
+
action_type=ActionType.CONFIGURE_BEAM,
|
| 149 |
+
parameters={"beam_energy": "13TeV"},
|
| 150 |
+
justification="default fallback",
|
| 151 |
+
)
|
| 152 |
+
if not flags[ActionType.SELECT_CHANNEL.value]:
|
| 153 |
+
return ExperimentAction(
|
| 154 |
+
action_type=ActionType.SELECT_CHANNEL,
|
| 155 |
+
parameters={"channel": obs.task.available_channels[0] if obs.task.available_channels else "diphoton"},
|
| 156 |
+
justification="default fallback",
|
| 157 |
+
)
|
| 158 |
+
if not flags[ActionType.SET_TRIGGER.value]:
|
| 159 |
+
return ExperimentAction(
|
| 160 |
+
action_type=ActionType.SET_TRIGGER,
|
| 161 |
+
parameters={"trigger": "diphoton_hlt"},
|
| 162 |
+
justification="default fallback",
|
| 163 |
+
)
|
| 164 |
+
if not flags[ActionType.ALLOCATE_LUMINOSITY.value]:
|
| 165 |
+
return ExperimentAction(
|
| 166 |
+
action_type=ActionType.ALLOCATE_LUMINOSITY,
|
| 167 |
+
parameters={"luminosity_fb": 50.0},
|
| 168 |
+
justification="default fallback",
|
| 169 |
+
)
|
| 170 |
+
if not flags[ActionType.COLLECT_COLLISIONS.value]:
|
| 171 |
+
return ExperimentAction(
|
| 172 |
+
action_type=ActionType.COLLECT_COLLISIONS,
|
| 173 |
+
parameters={"luminosity_fb": 50.0},
|
| 174 |
+
justification="default fallback",
|
| 175 |
+
)
|
| 176 |
+
if not flags[ActionType.RECONSTRUCT_TRACKS.value]:
|
| 177 |
+
return ExperimentAction(
|
| 178 |
+
action_type=ActionType.RECONSTRUCT_TRACKS,
|
| 179 |
+
justification="default fallback",
|
| 180 |
+
)
|
| 181 |
+
if not flags[ActionType.BUILD_INVARIANT_MASS.value]:
|
| 182 |
+
return ExperimentAction(
|
| 183 |
+
action_type=ActionType.BUILD_INVARIANT_MASS,
|
| 184 |
+
parameters={"mass_window_gev": obs.task.mass_search_window_gev},
|
| 185 |
+
justification="default fallback",
|
| 186 |
+
)
|
| 187 |
+
if not flags[ActionType.FIT_RESONANCE.value]:
|
| 188 |
+
return ExperimentAction(
|
| 189 |
+
action_type=ActionType.FIT_RESONANCE,
|
| 190 |
+
method="ROOT_RooFit",
|
| 191 |
+
justification="default fallback",
|
| 192 |
+
)
|
| 193 |
+
if not flags[ActionType.ESTIMATE_SIGNIFICANCE.value]:
|
| 194 |
+
return ExperimentAction(
|
| 195 |
+
action_type=ActionType.ESTIMATE_SIGNIFICANCE,
|
| 196 |
+
method="Asimov_significance",
|
| 197 |
+
justification="default fallback",
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
mass = obs.candidate_masses_gev[-1] if obs.candidate_masses_gev else 125.0
|
| 201 |
+
return ExperimentAction(
|
| 202 |
+
action_type=ActionType.SUBMIT_DISCOVERY_CLAIM,
|
| 203 |
+
parameters={
|
| 204 |
+
"claim": {
|
| 205 |
+
"mass_estimate_gev": mass,
|
| 206 |
+
"mass_uncertainty_gev": 1.0,
|
| 207 |
+
"significance_sigma": obs.cumulative_significance,
|
| 208 |
+
"decay_channel": obs.selected_channel or "diphoton",
|
| 209 |
+
"spin_hypothesis": 0,
|
| 210 |
+
"parity": "+",
|
| 211 |
+
"confidence": 0.7,
|
| 212 |
+
}
|
| 213 |
+
},
|
| 214 |
+
justification="default fallback claim",
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
__all__ = [
|
| 219 |
+
"LLMAgentConfig",
|
| 220 |
+
"build_chat",
|
| 221 |
+
"extract_first_json",
|
| 222 |
+
"parse_action",
|
| 223 |
+
"render_history",
|
| 224 |
+
"render_resources",
|
| 225 |
+
"render_user_prompt",
|
| 226 |
+
"safe_default_action",
|
| 227 |
+
]
|
training/plots.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Plot before/after evaluation curves and reward breakdowns.
|
| 2 |
+
|
| 3 |
+
Reads two JSONL evaluation files (typically ``eval_pre_train.jsonl`` and
|
| 4 |
+
``eval_post_train.jsonl``) produced by ``training.evaluate`` and writes
|
| 5 |
+
publication-ready PNGs (Portable Network Graphics) under ``--out_dir``.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import argparse
|
| 11 |
+
import json
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import Any, Dict, List
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _load(path: str) -> List[Dict[str, Any]]:
|
| 17 |
+
eps = []
|
| 18 |
+
with open(path) as f:
|
| 19 |
+
for line in f:
|
| 20 |
+
line = line.strip()
|
| 21 |
+
if line:
|
| 22 |
+
eps.append(json.loads(line))
|
| 23 |
+
return eps
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _summarise(eps: List[Dict[str, Any]]) -> Dict[str, float]:
|
| 27 |
+
if not eps:
|
| 28 |
+
return {"mean": 0.0, "success_rate": 0.0, "mass_acc": 0.0, "channel_acc": 0.0}
|
| 29 |
+
rewards = [float(e.get("cumulative_reward") or 0.0) for e in eps]
|
| 30 |
+
return {
|
| 31 |
+
"mean": sum(rewards) / len(rewards),
|
| 32 |
+
"success_rate": sum(1 for e in eps if e.get("discovered")) / len(eps),
|
| 33 |
+
"mass_acc": sum(1 for e in eps if e.get("correct_mass")) / len(eps),
|
| 34 |
+
"channel_acc": sum(1 for e in eps if e.get("correct_channel")) / len(eps),
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def main() -> None: # pragma: no cover
|
| 39 |
+
parser = argparse.ArgumentParser()
|
| 40 |
+
parser.add_argument("--pre", required=True)
|
| 41 |
+
parser.add_argument("--post", required=True)
|
| 42 |
+
parser.add_argument("--out_dir", default="training/plots")
|
| 43 |
+
args = parser.parse_args()
|
| 44 |
+
|
| 45 |
+
import matplotlib
|
| 46 |
+
matplotlib.use("Agg")
|
| 47 |
+
import matplotlib.pyplot as plt
|
| 48 |
+
|
| 49 |
+
pre = _load(args.pre)
|
| 50 |
+
post = _load(args.post)
|
| 51 |
+
pre_stats = _summarise(pre)
|
| 52 |
+
post_stats = _summarise(post)
|
| 53 |
+
|
| 54 |
+
out = Path(args.out_dir)
|
| 55 |
+
out.mkdir(parents=True, exist_ok=True)
|
| 56 |
+
|
| 57 |
+
pre_rewards = [float(e.get("cumulative_reward") or 0.0) for e in pre]
|
| 58 |
+
post_rewards = [float(e.get("cumulative_reward") or 0.0) for e in post]
|
| 59 |
+
|
| 60 |
+
fig, ax = plt.subplots(figsize=(7, 4))
|
| 61 |
+
ax.hist(pre_rewards, bins=15, alpha=0.5, label=f"pre (μ={pre_stats['mean']:+.2f})")
|
| 62 |
+
ax.hist(post_rewards, bins=15, alpha=0.5, label=f"post (μ={post_stats['mean']:+.2f})")
|
| 63 |
+
ax.set_xlabel("episode cumulative reward")
|
| 64 |
+
ax.set_ylabel("episode count")
|
| 65 |
+
ax.set_title("CERNenv reward distribution: pre vs post training")
|
| 66 |
+
ax.legend()
|
| 67 |
+
fig.tight_layout()
|
| 68 |
+
fig.savefig(out / "reward_distribution.png", dpi=140)
|
| 69 |
+
plt.close(fig)
|
| 70 |
+
|
| 71 |
+
metrics = ["mean", "success_rate", "mass_acc", "channel_acc"]
|
| 72 |
+
pre_vals = [pre_stats[m] for m in metrics]
|
| 73 |
+
post_vals = [post_stats[m] for m in metrics]
|
| 74 |
+
x = list(range(len(metrics)))
|
| 75 |
+
fig, ax = plt.subplots(figsize=(7, 4))
|
| 76 |
+
ax.bar([i - 0.18 for i in x], pre_vals, width=0.36, label="pre")
|
| 77 |
+
ax.bar([i + 0.18 for i in x], post_vals, width=0.36, label="post")
|
| 78 |
+
ax.set_xticks(x)
|
| 79 |
+
ax.set_xticklabels(metrics)
|
| 80 |
+
ax.set_title("Mean reward & accuracy: pre vs post training")
|
| 81 |
+
ax.legend()
|
| 82 |
+
fig.tight_layout()
|
| 83 |
+
fig.savefig(out / "metrics_summary.png", dpi=140)
|
| 84 |
+
plt.close(fig)
|
| 85 |
+
|
| 86 |
+
with open(out / "metrics_summary.json", "w") as f:
|
| 87 |
+
json.dump({"pre": pre_stats, "post": post_stats}, f, indent=2)
|
| 88 |
+
|
| 89 |
+
print("wrote:", list(out.glob("*")))
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
if __name__ == "__main__": # pragma: no cover
|
| 93 |
+
main()
|
training/rollouts.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Rollout collector for LLM-driven CERNenv episodes.
|
| 2 |
+
|
| 3 |
+
Runs an LLM agent in-process against ``CERNCollisionEnvironment`` and
|
| 4 |
+
records full per-step trajectories: prompt, completion, parsed action,
|
| 5 |
+
reward, observation snapshot, and final episode summary.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import json
|
| 11 |
+
import logging
|
| 12 |
+
from dataclasses import asdict, dataclass, field
|
| 13 |
+
from typing import Any, Callable, Dict, List, Optional
|
| 14 |
+
|
| 15 |
+
from models import ActionType, CollisionObservation, ExperimentAction
|
| 16 |
+
from server.environment import CERNCollisionEnvironment
|
| 17 |
+
|
| 18 |
+
from .llm_agent import (
|
| 19 |
+
LLMAgentConfig,
|
| 20 |
+
build_chat,
|
| 21 |
+
parse_action,
|
| 22 |
+
safe_default_action,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
logger = logging.getLogger(__name__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
PromptFn = Callable[[List[Dict[str, str]]], str]
|
| 30 |
+
"""Callable: tokenizer-aware prompt formatter (e.g. apply_chat_template)."""
|
| 31 |
+
|
| 32 |
+
GenerateFn = Callable[[str, LLMAgentConfig], str]
|
| 33 |
+
"""Callable: actually run the LLM and return the raw completion string."""
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
|
| 37 |
+
class StepRecord:
|
| 38 |
+
step: int
|
| 39 |
+
prompt: str
|
| 40 |
+
completion: str
|
| 41 |
+
action: Dict[str, Any]
|
| 42 |
+
parsed_ok: bool
|
| 43 |
+
reward: float
|
| 44 |
+
done: bool
|
| 45 |
+
rule_violations: List[str]
|
| 46 |
+
observation_summary: Dict[str, Any] = field(default_factory=dict)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@dataclass
|
| 50 |
+
class EpisodeRecord:
|
| 51 |
+
seed: int
|
| 52 |
+
scenario: Optional[str]
|
| 53 |
+
difficulty: Optional[str]
|
| 54 |
+
truth: Optional[Dict[str, Any]]
|
| 55 |
+
total_reward: float
|
| 56 |
+
cumulative_reward: float
|
| 57 |
+
terminal_reward: Optional[float]
|
| 58 |
+
discovered: Optional[bool]
|
| 59 |
+
correct_mass: Optional[bool]
|
| 60 |
+
correct_channel: Optional[bool]
|
| 61 |
+
correct_spin: Optional[bool]
|
| 62 |
+
steps: List[StepRecord]
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _summarise_obs(obs: CollisionObservation) -> Dict[str, Any]:
|
| 66 |
+
return {
|
| 67 |
+
"step_index": obs.step_index,
|
| 68 |
+
"selected_channel": obs.selected_channel,
|
| 69 |
+
"selected_beam_energy": obs.selected_beam_energy,
|
| 70 |
+
"n_candidates": len(obs.candidate_masses_gev),
|
| 71 |
+
"best_significance": obs.cumulative_significance,
|
| 72 |
+
"budget_remaining_musd": obs.resource_usage.budget_remaining_musd,
|
| 73 |
+
"luminosity_remaining_fb": obs.resource_usage.luminosity_remaining_fb,
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def collect_episode(
|
| 78 |
+
*,
|
| 79 |
+
env: CERNCollisionEnvironment,
|
| 80 |
+
seed: int,
|
| 81 |
+
scenario: Optional[str],
|
| 82 |
+
difficulty: Optional[str],
|
| 83 |
+
prompt_fn: PromptFn,
|
| 84 |
+
generate_fn: GenerateFn,
|
| 85 |
+
config: LLMAgentConfig = LLMAgentConfig(),
|
| 86 |
+
max_steps: Optional[int] = None,
|
| 87 |
+
) -> EpisodeRecord:
|
| 88 |
+
obs = env.reset(seed=seed, scenario=scenario, difficulty=difficulty)
|
| 89 |
+
steps: List[StepRecord] = []
|
| 90 |
+
total_reward = 0.0
|
| 91 |
+
|
| 92 |
+
cap = max_steps or env.max_steps
|
| 93 |
+
while not obs.done and len(steps) < cap:
|
| 94 |
+
chat = build_chat(obs, config)
|
| 95 |
+
prompt = prompt_fn(chat)
|
| 96 |
+
completion = generate_fn(prompt, config)
|
| 97 |
+
|
| 98 |
+
action = parse_action(completion)
|
| 99 |
+
parsed_ok = action is not None
|
| 100 |
+
if action is None:
|
| 101 |
+
action = safe_default_action(obs)
|
| 102 |
+
|
| 103 |
+
next_obs = env.step(action)
|
| 104 |
+
reward = float(next_obs.reward or 0.0)
|
| 105 |
+
total_reward += reward
|
| 106 |
+
|
| 107 |
+
steps.append(
|
| 108 |
+
StepRecord(
|
| 109 |
+
step=obs.step_index,
|
| 110 |
+
prompt=prompt,
|
| 111 |
+
completion=completion,
|
| 112 |
+
action=action.model_dump(),
|
| 113 |
+
parsed_ok=parsed_ok,
|
| 114 |
+
reward=reward,
|
| 115 |
+
done=next_obs.done,
|
| 116 |
+
rule_violations=list(next_obs.rule_violations),
|
| 117 |
+
observation_summary=_summarise_obs(obs),
|
| 118 |
+
)
|
| 119 |
+
)
|
| 120 |
+
obs = next_obs
|
| 121 |
+
|
| 122 |
+
return EpisodeRecord(
|
| 123 |
+
seed=seed,
|
| 124 |
+
scenario=env.state.scenario_name,
|
| 125 |
+
difficulty=env.state.difficulty,
|
| 126 |
+
truth=env.hidden_truth(),
|
| 127 |
+
total_reward=total_reward,
|
| 128 |
+
cumulative_reward=float(env.state.cumulative_reward),
|
| 129 |
+
terminal_reward=env.state.terminal_reward,
|
| 130 |
+
discovered=env.state.discovered,
|
| 131 |
+
correct_mass=env.state.correct_mass,
|
| 132 |
+
correct_channel=env.state.correct_channel,
|
| 133 |
+
correct_spin=env.state.correct_spin,
|
| 134 |
+
steps=steps,
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def save_episodes_jsonl(episodes: List[EpisodeRecord], path: str) -> None:
|
| 139 |
+
with open(path, "w") as f:
|
| 140 |
+
for ep in episodes:
|
| 141 |
+
f.write(json.dumps(asdict(ep), default=str) + "\n")
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def load_episodes_jsonl(path: str) -> List[Dict[str, Any]]:
|
| 145 |
+
eps: List[Dict[str, Any]] = []
|
| 146 |
+
with open(path) as f:
|
| 147 |
+
for line in f:
|
| 148 |
+
line = line.strip()
|
| 149 |
+
if line:
|
| 150 |
+
eps.append(json.loads(line))
|
| 151 |
+
return eps
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
__all__ = [
|
| 155 |
+
"EpisodeRecord",
|
| 156 |
+
"StepRecord",
|
| 157 |
+
"collect_episode",
|
| 158 |
+
"save_episodes_jsonl",
|
| 159 |
+
"load_episodes_jsonl",
|
| 160 |
+
]
|
training/training_script.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""GRPO (Group-Relative Policy Optimization) training script for CERNenv.
|
| 2 |
+
|
| 3 |
+
Uses Hugging Face TRL (Transformer Reinforcement Learning) ``GRPOTrainer`` to
|
| 4 |
+
fine-tune a small instruction-tuned model on full episodes of the CERN
|
| 5 |
+
environment. Each ``query`` is a prompt sampled from a freshly-reset env;
|
| 6 |
+
the reward function rolls the model's response through the environment and
|
| 7 |
+
returns the per-step + (optional) terminal reward.
|
| 8 |
+
|
| 9 |
+
This script is intentionally CPU-friendly and self-contained. For
|
| 10 |
+
GPU-accelerated training with LoRA, prefer ``training_unsloth.py``.
|
| 11 |
+
|
| 12 |
+
Run:
|
| 13 |
+
python -m training.training_script \
|
| 14 |
+
--model_name HuggingFaceTB/SmolLM2-360M-Instruct \
|
| 15 |
+
--total_episodes 200 --max_steps 18 --output_dir training/grpo-output
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
import argparse
|
| 21 |
+
import logging
|
| 22 |
+
import math
|
| 23 |
+
import os
|
| 24 |
+
from dataclasses import dataclass
|
| 25 |
+
from typing import Any, Dict, List, Optional
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
from datasets import Dataset
|
| 29 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 30 |
+
|
| 31 |
+
from models import ExperimentAction
|
| 32 |
+
from server.environment import CERNCollisionEnvironment
|
| 33 |
+
from training.llm_agent import (
|
| 34 |
+
LLMAgentConfig,
|
| 35 |
+
build_chat,
|
| 36 |
+
parse_action,
|
| 37 |
+
safe_default_action,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
|
| 42 |
+
logger = logging.getLogger(__name__)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# ── Episode reward harness ───────────────────────────────────────────────
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@dataclass
|
| 49 |
+
class EpisodeContext:
|
| 50 |
+
"""Per-prompt reusable env + observation snapshot used by the reward fn."""
|
| 51 |
+
|
| 52 |
+
env: CERNCollisionEnvironment
|
| 53 |
+
seed: int
|
| 54 |
+
scenario: Optional[str]
|
| 55 |
+
difficulty: Optional[str]
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _stepwise_reward(
|
| 59 |
+
*,
|
| 60 |
+
completion_text: str,
|
| 61 |
+
ctx: EpisodeContext,
|
| 62 |
+
) -> float:
|
| 63 |
+
"""Roll the model's first response through one full episode and
|
| 64 |
+
return the cumulative reward (per-step + terminal).
|
| 65 |
+
|
| 66 |
+
The completion is interpreted as the first action only; subsequent
|
| 67 |
+
steps fall back to the safe default policy. This keeps the reward
|
| 68 |
+
bandwidth high for early-exploration training without requiring
|
| 69 |
+
multi-turn rollouts inside GRPO.
|
| 70 |
+
"""
|
| 71 |
+
env = ctx.env
|
| 72 |
+
obs = env.reset(seed=ctx.seed, scenario=ctx.scenario, difficulty=ctx.difficulty)
|
| 73 |
+
|
| 74 |
+
action = parse_action(completion_text) or safe_default_action(obs)
|
| 75 |
+
obs = env.step(action)
|
| 76 |
+
cumulative = float(obs.reward or 0.0)
|
| 77 |
+
|
| 78 |
+
while not obs.done:
|
| 79 |
+
fallback = safe_default_action(obs)
|
| 80 |
+
obs = env.step(fallback)
|
| 81 |
+
cumulative += float(obs.reward or 0.0)
|
| 82 |
+
|
| 83 |
+
return cumulative
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def _format_validity_bonus(completion_text: str) -> float:
|
| 87 |
+
return 0.5 if parse_action(completion_text) is not None else -0.5
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def make_reward_fn(ctx: EpisodeContext):
|
| 91 |
+
"""Return a TRL-compatible reward function (closes over ``ctx``)."""
|
| 92 |
+
|
| 93 |
+
def reward_fn(prompts: List[str], completions: List[str], **kwargs: Any) -> List[float]:
|
| 94 |
+
rewards: List[float] = []
|
| 95 |
+
for completion in completions:
|
| 96 |
+
r = _stepwise_reward(completion_text=completion, ctx=ctx)
|
| 97 |
+
r += _format_validity_bonus(completion)
|
| 98 |
+
rewards.append(float(r))
|
| 99 |
+
return rewards
|
| 100 |
+
|
| 101 |
+
return reward_fn
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# ── Prompt dataset ───────────────────────────────────────────────────────
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def build_dataset(
|
| 108 |
+
*,
|
| 109 |
+
tokenizer,
|
| 110 |
+
n_prompts: int,
|
| 111 |
+
seed: int,
|
| 112 |
+
scenario: Optional[str],
|
| 113 |
+
difficulty: Optional[str],
|
| 114 |
+
) -> Dataset:
|
| 115 |
+
env = CERNCollisionEnvironment()
|
| 116 |
+
prompts: List[str] = []
|
| 117 |
+
for i in range(n_prompts):
|
| 118 |
+
obs = env.reset(seed=seed + i, scenario=scenario, difficulty=difficulty)
|
| 119 |
+
chat = build_chat(obs)
|
| 120 |
+
prompt = tokenizer.apply_chat_template(
|
| 121 |
+
chat, add_generation_prompt=True, tokenize=False
|
| 122 |
+
)
|
| 123 |
+
prompts.append(prompt)
|
| 124 |
+
return Dataset.from_dict({"prompt": prompts})
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# ── Main ─────────────────────────────────────────────────────────────────
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def main() -> None: # pragma: no cover - training entrypoint
|
| 131 |
+
parser = argparse.ArgumentParser()
|
| 132 |
+
parser.add_argument("--model_name", default="HuggingFaceTB/SmolLM2-360M-Instruct")
|
| 133 |
+
parser.add_argument("--scenario", default=None)
|
| 134 |
+
parser.add_argument("--difficulty", choices=["easy", "medium", "hard"], default="easy")
|
| 135 |
+
parser.add_argument("--total_episodes", type=int, default=200)
|
| 136 |
+
parser.add_argument("--seed", type=int, default=42)
|
| 137 |
+
parser.add_argument("--max_steps", type=int, default=18)
|
| 138 |
+
parser.add_argument("--num_generations", type=int, default=4)
|
| 139 |
+
parser.add_argument("--learning_rate", type=float, default=1e-5)
|
| 140 |
+
parser.add_argument("--max_prompt_length", type=int, default=1024)
|
| 141 |
+
parser.add_argument("--max_completion_length", type=int, default=256)
|
| 142 |
+
parser.add_argument("--output_dir", default="training/grpo-output")
|
| 143 |
+
args = parser.parse_args()
|
| 144 |
+
|
| 145 |
+
try:
|
| 146 |
+
from trl import GRPOConfig, GRPOTrainer
|
| 147 |
+
except ImportError as exc: # pragma: no cover
|
| 148 |
+
raise SystemExit(
|
| 149 |
+
"TRL (Transformer Reinforcement Learning) is required: "
|
| 150 |
+
"pip install -r requirements-train.txt"
|
| 151 |
+
) from exc
|
| 152 |
+
|
| 153 |
+
logger.info("Loading tokenizer + model: %s", args.model_name)
|
| 154 |
+
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
|
| 155 |
+
if tokenizer.pad_token is None:
|
| 156 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 157 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 158 |
+
args.model_name,
|
| 159 |
+
torch_dtype=torch.float32,
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
logger.info("Building prompt dataset (%d prompts)", args.total_episodes)
|
| 163 |
+
dataset = build_dataset(
|
| 164 |
+
tokenizer=tokenizer,
|
| 165 |
+
n_prompts=args.total_episodes,
|
| 166 |
+
seed=args.seed,
|
| 167 |
+
scenario=args.scenario,
|
| 168 |
+
difficulty=args.difficulty,
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
env = CERNCollisionEnvironment(max_steps=args.max_steps)
|
| 172 |
+
ctx = EpisodeContext(
|
| 173 |
+
env=env,
|
| 174 |
+
seed=args.seed,
|
| 175 |
+
scenario=args.scenario,
|
| 176 |
+
difficulty=args.difficulty,
|
| 177 |
+
)
|
| 178 |
+
reward_fn = make_reward_fn(ctx)
|
| 179 |
+
|
| 180 |
+
cfg = GRPOConfig(
|
| 181 |
+
output_dir=args.output_dir,
|
| 182 |
+
per_device_train_batch_size=2,
|
| 183 |
+
gradient_accumulation_steps=2,
|
| 184 |
+
num_generations=args.num_generations,
|
| 185 |
+
learning_rate=args.learning_rate,
|
| 186 |
+
max_prompt_length=args.max_prompt_length,
|
| 187 |
+
max_completion_length=args.max_completion_length,
|
| 188 |
+
logging_steps=5,
|
| 189 |
+
save_steps=50,
|
| 190 |
+
seed=args.seed,
|
| 191 |
+
bf16=False,
|
| 192 |
+
fp16=False,
|
| 193 |
+
report_to=[],
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
trainer = GRPOTrainer(
|
| 197 |
+
model=model,
|
| 198 |
+
processing_class=tokenizer,
|
| 199 |
+
train_dataset=dataset,
|
| 200 |
+
reward_funcs=[reward_fn],
|
| 201 |
+
args=cfg,
|
| 202 |
+
)
|
| 203 |
+
logger.info("Starting GRPO training")
|
| 204 |
+
trainer.train()
|
| 205 |
+
trainer.save_model(args.output_dir)
|
| 206 |
+
tokenizer.save_pretrained(args.output_dir)
|
| 207 |
+
logger.info("Saved model to %s", args.output_dir)
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
if __name__ == "__main__": # pragma: no cover
|
| 211 |
+
main()
|
training/training_unsloth.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Unsloth + LoRA (Low-Rank Adaptation) GRPO training for CERNenv.
|
| 2 |
+
|
| 3 |
+
This is the recommended path for Colab / single-GPU runs because Unsloth's
|
| 4 |
+
fused kernels and 4-bit loading let us train 2B–8B models with limited VRAM.
|
| 5 |
+
|
| 6 |
+
Run on Colab:
|
| 7 |
+
!pip install -q unsloth unsloth_zoo trl peft datasets bitsandbytes
|
| 8 |
+
!python -m training.training_unsloth \
|
| 9 |
+
--model_name unsloth/Qwen2.5-3B-Instruct \
|
| 10 |
+
--total_episodes 400 --num_generations 4 --output_dir runs/unsloth-grpo
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
import logging
|
| 17 |
+
from typing import Any, List, Optional
|
| 18 |
+
|
| 19 |
+
from datasets import Dataset
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def main() -> None: # pragma: no cover - heavy GPU path
|
| 27 |
+
parser = argparse.ArgumentParser()
|
| 28 |
+
parser.add_argument("--model_name", default="unsloth/Qwen2.5-3B-Instruct")
|
| 29 |
+
parser.add_argument("--scenario", default=None)
|
| 30 |
+
parser.add_argument("--difficulty", choices=["easy", "medium", "hard"], default="easy")
|
| 31 |
+
parser.add_argument("--total_episodes", type=int, default=400)
|
| 32 |
+
parser.add_argument("--seed", type=int, default=42)
|
| 33 |
+
parser.add_argument("--max_steps", type=int, default=18)
|
| 34 |
+
parser.add_argument("--num_generations", type=int, default=4)
|
| 35 |
+
parser.add_argument("--max_prompt_length", type=int, default=2048)
|
| 36 |
+
parser.add_argument("--max_completion_length", type=int, default=384)
|
| 37 |
+
parser.add_argument("--learning_rate", type=float, default=5e-6)
|
| 38 |
+
parser.add_argument("--load_in_4bit", action="store_true", default=True)
|
| 39 |
+
parser.add_argument("--lora_rank", type=int, default=16)
|
| 40 |
+
parser.add_argument("--lora_alpha", type=int, default=16)
|
| 41 |
+
parser.add_argument("--output_dir", default="training/runs/unsloth-grpo")
|
| 42 |
+
args = parser.parse_args()
|
| 43 |
+
|
| 44 |
+
from unsloth import FastLanguageModel
|
| 45 |
+
from trl import GRPOConfig, GRPOTrainer
|
| 46 |
+
|
| 47 |
+
from server.environment import CERNCollisionEnvironment
|
| 48 |
+
from training.llm_agent import (
|
| 49 |
+
LLMAgentConfig,
|
| 50 |
+
build_chat,
|
| 51 |
+
parse_action,
|
| 52 |
+
safe_default_action,
|
| 53 |
+
)
|
| 54 |
+
from training.training_script import EpisodeContext, _format_validity_bonus, _stepwise_reward
|
| 55 |
+
|
| 56 |
+
logger.info("Loading Unsloth model: %s", args.model_name)
|
| 57 |
+
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 58 |
+
model_name=args.model_name,
|
| 59 |
+
max_seq_length=args.max_prompt_length + args.max_completion_length,
|
| 60 |
+
load_in_4bit=args.load_in_4bit,
|
| 61 |
+
fast_inference=True,
|
| 62 |
+
)
|
| 63 |
+
model = FastLanguageModel.get_peft_model(
|
| 64 |
+
model,
|
| 65 |
+
r=args.lora_rank,
|
| 66 |
+
lora_alpha=args.lora_alpha,
|
| 67 |
+
target_modules=[
|
| 68 |
+
"q_proj", "k_proj", "v_proj", "o_proj",
|
| 69 |
+
"gate_proj", "up_proj", "down_proj",
|
| 70 |
+
],
|
| 71 |
+
use_gradient_checkpointing="unsloth",
|
| 72 |
+
)
|
| 73 |
+
if tokenizer.pad_token is None:
|
| 74 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 75 |
+
|
| 76 |
+
# Build prompts
|
| 77 |
+
env = CERNCollisionEnvironment(max_steps=args.max_steps)
|
| 78 |
+
prompts: List[str] = []
|
| 79 |
+
for i in range(args.total_episodes):
|
| 80 |
+
obs = env.reset(seed=args.seed + i, scenario=args.scenario, difficulty=args.difficulty)
|
| 81 |
+
chat = build_chat(obs)
|
| 82 |
+
prompts.append(
|
| 83 |
+
tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)
|
| 84 |
+
)
|
| 85 |
+
dataset = Dataset.from_dict({"prompt": prompts})
|
| 86 |
+
|
| 87 |
+
ctx = EpisodeContext(
|
| 88 |
+
env=env, seed=args.seed,
|
| 89 |
+
scenario=args.scenario, difficulty=args.difficulty,
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
def reward_fn(prompts: List[str], completions: List[str], **kwargs: Any) -> List[float]:
|
| 93 |
+
rewards: List[float] = []
|
| 94 |
+
for completion in completions:
|
| 95 |
+
r = _stepwise_reward(completion_text=completion, ctx=ctx)
|
| 96 |
+
r += _format_validity_bonus(completion)
|
| 97 |
+
rewards.append(float(r))
|
| 98 |
+
return rewards
|
| 99 |
+
|
| 100 |
+
cfg = GRPOConfig(
|
| 101 |
+
output_dir=args.output_dir,
|
| 102 |
+
per_device_train_batch_size=1,
|
| 103 |
+
gradient_accumulation_steps=4,
|
| 104 |
+
num_generations=args.num_generations,
|
| 105 |
+
learning_rate=args.learning_rate,
|
| 106 |
+
max_prompt_length=args.max_prompt_length,
|
| 107 |
+
max_completion_length=args.max_completion_length,
|
| 108 |
+
logging_steps=5,
|
| 109 |
+
save_steps=50,
|
| 110 |
+
seed=args.seed,
|
| 111 |
+
bf16=True,
|
| 112 |
+
report_to=[],
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
trainer = GRPOTrainer(
|
| 116 |
+
model=model,
|
| 117 |
+
processing_class=tokenizer,
|
| 118 |
+
train_dataset=dataset,
|
| 119 |
+
reward_funcs=[reward_fn],
|
| 120 |
+
args=cfg,
|
| 121 |
+
)
|
| 122 |
+
logger.info("Starting Unsloth + LoRA GRPO training")
|
| 123 |
+
trainer.train()
|
| 124 |
+
trainer.save_model(args.output_dir)
|
| 125 |
+
tokenizer.save_pretrained(args.output_dir)
|
| 126 |
+
logger.info("Saved adapters to %s", args.output_dir)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
if __name__ == "__main__": # pragma: no cover
|
| 130 |
+
main()
|