Spaces:
Running
Running
Upload folder using huggingface_hub
Browse files- Dockerfile +43 -0
- LICENSE +21 -0
- README.md +213 -6
- __init__.py +34 -0
- client.py +127 -0
- grader.py +179 -0
- models.py +115 -0
- openenv.yaml +6 -0
- openenv_medusa.egg-info/PKG-INFO +16 -0
- openenv_medusa.egg-info/SOURCES.txt +25 -0
- openenv_medusa.egg-info/dependency_links.txt +1 -0
- openenv_medusa.egg-info/entry_points.txt +2 -0
- openenv_medusa.egg-info/requires.txt +10 -0
- openenv_medusa.egg-info/top_level.txt +2 -0
- operators.py +315 -0
- pyproject.toml +37 -0
- rewards.py +107 -0
- scenarios.py +215 -0
- scripts/inference.py +288 -0
- server/__init__.py +6 -0
- server/app.py +37 -0
- server/medusa_env.py +514 -0
- tasks.py +286 -0
- tests/test_medusa_environment.py +591 -0
- uv.lock +0 -0
Dockerfile
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# HF Space root-level Dockerfile — targets port 7860 (HF default).
# This file lives at envs/medusa_env/Dockerfile and is the file
# HF Spaces uses when deploying a Docker Space from this directory.

FROM python:3.12-slim

WORKDIR /app

# Install uv for fast dependency resolution
RUN pip install uv --no-cache-dir

# Copy environment code
COPY . /app/env

WORKDIR /app/env

# Install all dependencies including openenv-core + pandas + numpy
RUN uv pip install --system --no-cache \
    "openenv-core[core]>=0.2.2" \
    fastapi \
    "uvicorn[standard]" \
    pydantic \
    pandas \
    numpy \
    websockets

# Install the medusa package itself (so medusa_env.* imports resolve)
RUN uv pip install --system --no-cache -e .

# HF Spaces requires port 7860
ENV PORT=7860
EXPOSE 7860

# PYTHONPATH so imports resolve correctly when running from /app/env
ENV PYTHONPATH="/app/env:$PYTHONPATH"
ENV ENABLE_WEB_INTERFACE=true

# Health check on HF port (server must expose GET /health)
HEALTHCHECK --interval=30s --timeout=5s --start-period=15s --retries=3 \
    CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:7860/health')" || exit 1

# Run on port 7860 — HF Space requirement
CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2026 Ram Janam Yadav
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -1,12 +1,219 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
colorTo: blue
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
-
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: MEDUSA Environment
|
| 3 |
+
emoji: π¦
|
| 4 |
+
colorFrom: purple
|
| 5 |
colorTo: blue
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
+
tags:
|
| 9 |
+
- openenv
|
| 10 |
+
- reinforcement-learning
|
| 11 |
+
- data-engineering
|
| 12 |
+
app_port: 7860
|
| 13 |
+
base_path: /web
|
| 14 |
---
|
| 15 |
|
| 16 |
+
# MEDUSA
|
| 17 |
+
|
| 18 |
+
**Medallion-Engineered Deterministic Unified Storage Agent**
|
| 19 |
+
|
| 20 |
+
An OpenEnv reinforcement learning environment that trains agents to act as *Relational Controllers* — orchestrating multi-source Bronze→Silver data integration pipelines inside a Medallion Architecture.
|
| 21 |
+
|
| 22 |
+
---
|
| 23 |
+
|
| 24 |
+
## Problem
|
| 25 |
+
|
| 26 |
+
Modern data platforms fail not because they can't clean a single table, but because they can't reliably integrate **multiple shifting sources**. The Bronze→Silver transition is a minefield of:
|
| 27 |
+
|
| 28 |
+
- **Stale data** — processing yesterday's snapshot wastes compute and produces wrong results
|
| 29 |
+
- **Schema drift** — new columns appear in sources that Silver doesn't know about yet
|
| 30 |
+
- **Dirty join keys** — NULLs and whitespace cause 0-row joins and silent data loss
|
| 31 |
+
- **Cartesian explosions** — joining on non-unique Dimension keys multiplies rows catastrophically
|
| 32 |
+
- **Orphaned records** — unmatched Fact rows must be quarantined, not silently dropped
|
| 33 |
+
|
| 34 |
+
MEDUSA trains an agent to detect and handle all of these autonomously.
|
| 35 |
+
|
| 36 |
+
---
|
| 37 |
+
|
| 38 |
+
## Environment Overview
|
| 39 |
+
|
| 40 |
+
```
|
| 41 |
+
Bronze A (Fact) βββ
|
| 42 |
+
ββββΊ [Agent] βββΊ Silver + /quarantine
|
| 43 |
+
Bronze B (Dim) βββ
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
The agent observes data quality signals and selects ETL actions step-by-step. At the end it issues `COMMIT`, triggering a deterministic grader audit.
|
| 47 |
+
|
| 48 |
+
---
|
| 49 |
+
|
| 50 |
+
## The MDP
|
| 51 |
+
|
| 52 |
+
### Observation Space
|
| 53 |
+
|
| 54 |
+
A **16-element normalised float vector** `[0, 1]`:
|
| 55 |
+
|
| 56 |
+
| Index | Feature | Description |
|
| 57 |
+
|-------|---------|-------------|
|
| 58 |
+
| 0β1 | `time_delta_a/b_norm` | Source freshness (hours / 48h ceiling) |
|
| 59 |
+
| 2β3 | `is_stale_a/b` | Binary staleness flag |
|
| 60 |
+
| 4β5 | `null_ratio_key_a/b` | Fraction of null join keys |
|
| 61 |
+
| 6β7 | `uniqueness_a/b` | Key uniqueness ratio (1.0 = fully unique) |
|
| 62 |
+
| 8 | `match_rate` | % of Fact keys found in Dimension |
|
| 63 |
+
| 9β10 | `new_cols_a/b_norm` | Schema drift columns pending |
|
| 64 |
+
| 11 | `schema_compat` | Key type compatibility score |
|
| 65 |
+
| 12β14 | `did_prep_a/b`, `did_dedup_b` | Prerequisite action flags |
|
| 66 |
+
| 15 | `step_frac` | Episode progress (step / max_steps) |
|
| 67 |
+
|
| 68 |
+
### Action Space
|
| 69 |
+
|
| 70 |
+
11 discrete actions:
|
| 71 |
+
|
| 72 |
+
| Action | Description |
|
| 73 |
+
|--------|-------------|
|
| 74 |
+
| `SYNC_CHECK` | Verify freshness of both sources |
|
| 75 |
+
| `EVOLVE_SCHEMA` | Add new columns from A/B into Silver schema |
|
| 76 |
+
| `PREP_KEYS_A` | Cast, strip, null-fill join key in Source A |
|
| 77 |
+
| `PREP_KEYS_B` | Cast, strip, null-fill join key in Source B |
|
| 78 |
+
| `DEDUPLICATE_B` | Ensure Dimension (B) is unique on the join key |
|
| 79 |
+
| `EXECUTE_JOIN_INNER` | Inner join A β B |
|
| 80 |
+
| `EXECUTE_JOIN_LEFT` | Left join A β B (orphans β quarantine) |
|
| 81 |
+
| `EXECUTE_JOIN_ANTI` | Anti-join: extract rows in A with no match in B |
|
| 82 |
+
| `APPLY_SCD_1` | Overwrite Silver records (SCD Type 1) |
|
| 83 |
+
| `APPLY_SCD_2` | Close old records, insert new with timestamps (SCD Type 2) |
|
| 84 |
+
| `COMMIT` | Finalise pipeline; triggers grader audit |
|
| 85 |
+
|
| 86 |
+
### Reward Model
|
| 87 |
+
|
| 88 |
+
| Event | Reward | Trigger |
|
| 89 |
+
|-------|--------|---------|
|
| 90 |
+
| High-Match Join | **+25.0** | `match_rate > 90%` after join |
|
| 91 |
+
| Quarantine Precision | **+10.0** | Orphaned rows correctly isolated |
|
| 92 |
+
| Correct SCD-2 | **+5.0** | SCD-2 applied on a tracked column |
|
| 93 |
+
| Grader All-Pass Bonus | **+15.0** | All 4 post-commit checks pass |
|
| 94 |
+
| Row Explosion | **−100.0** | Join output > 105% of Fact row count |
|
| 95 |
+
| Join on Dirty Keys | **−30.0** | Join without PREP_KEYS → 0-row result |
|
| 96 |
+
| Stale Processing | **−15.0** | Action taken while source is stale, SYNC_CHECK never called |
|
| 97 |
+
| Step Penalty | **−0.2** | Applied every step (efficiency incentive) |
|
| 98 |
+
|
| 99 |
+
---
|
| 100 |
+
|
| 101 |
+
## Post-Commit Grader
|
| 102 |
+
|
| 103 |
+
After `COMMIT` the deterministic grader runs 4 checks:
|
| 104 |
+
|
| 105 |
+
| Check | Pass Condition |
|
| 106 |
+
|-------|---------------|
|
| 107 |
+
| **Volume** | `Silver rows ≤ Source A rows` (for left joins) |
|
| 108 |
+
| **Integrity** | Quarantine holds only true orphans (not keys that could have joined if cleaned) |
|
| 109 |
+
| **Schema** | Silver contains the union of all required columns from A and B |
|
| 110 |
+
| **History** | SCD-2 `valid_from`/`valid_to` timestamps are non-overlapping |
|
| 111 |
+
|
| 112 |
+
All 4 pass β **+15.0** bonus. Each failure costs **β5.0**.
|
| 113 |
+
|
| 114 |
+
---
|
| 115 |
+
|
| 116 |
+
## Episode Scenarios
|
| 117 |
+
|
| 118 |
+
Four canonical scenarios (selectable by seed):
|
| 119 |
+
|
| 120 |
+
| Seed | Scenario | Challenge |
|
| 121 |
+
|------|----------|-----------|
|
| 122 |
+
| 0 | `clean` | Fresh, unique keys, ~100% match rate. Baseline. |
|
| 123 |
+
| 1 | `dirty_keys` | NULLs + whitespace in join keys. Must PREP first. |
|
| 124 |
+
| 2 | `stale` | Source A is 8–24h old. Must SYNC_CHECK first. |
|
| 125 |
+
| 3 | `schema_drift` | New columns in A and B not yet in Silver. Must EVOLVE first. |
|
| 126 |
+
|
| 127 |
+
Random seeds produce blended variants.
|
| 128 |
+
|
| 129 |
+
---
|
| 130 |
+
|
| 131 |
+
## Setup
|
| 132 |
+
|
| 133 |
+
```bash
|
| 134 |
+
# Clone / navigate to repo
|
| 135 |
+
cd /path/to/OpenEnv
|
| 136 |
+
|
| 137 |
+
# Create venv and install all deps (including pandas, numpy)
|
| 138 |
+
uv sync
|
| 139 |
+
|
| 140 |
+
# Activate
|
| 141 |
+
source .venv/bin/activate
|
| 142 |
+
```
|
| 143 |
+
|
| 144 |
+
---
|
| 145 |
+
|
| 146 |
+
## Running
|
| 147 |
+
|
| 148 |
+
### Start the FastAPI server
|
| 149 |
+
|
| 150 |
+
```bash
|
| 151 |
+
uvicorn envs.medusa_env.server.app:app --reload --host 0.0.0.0 --port 8000
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
API docs available at `http://localhost:8000/docs`.
|
| 155 |
+
|
| 156 |
+
### Run tests
|
| 157 |
+
|
| 158 |
+
```bash
|
| 159 |
+
python -m pytest tests/envs/test_medusa_environment.py -v
|
| 160 |
+
# 39 passed in ~4s
|
| 161 |
+
```
|
| 162 |
+
|
| 163 |
+
### Run a manual episode (Python)
|
| 164 |
+
|
| 165 |
+
```python
|
| 166 |
+
from envs.medusa_env import MedusaEnv, MedusaAction
|
| 167 |
+
from envs.medusa_env.models import MedusaActionType
|
| 168 |
+
|
| 169 |
+
env = MedusaEnv(n_fact_rows=200, n_dim_rows=150)
|
| 170 |
+
obs = env.reset(seed=0) # seed 0 = clean scenario
|
| 171 |
+
print(obs.message)
|
| 172 |
+
|
| 173 |
+
for action_type in [
|
| 174 |
+
MedusaActionType.SYNC_CHECK,
|
| 175 |
+
MedusaActionType.EVOLVE_SCHEMA,
|
| 176 |
+
MedusaActionType.PREP_KEYS_A,
|
| 177 |
+
MedusaActionType.PREP_KEYS_B,
|
| 178 |
+
MedusaActionType.DEDUPLICATE_B,
|
| 179 |
+
MedusaActionType.EXECUTE_JOIN_LEFT,
|
| 180 |
+
MedusaActionType.APPLY_SCD_2,
|
| 181 |
+
MedusaActionType.COMMIT,
|
| 182 |
+
]:
|
| 183 |
+
obs = env.step(MedusaAction(action=action_type))
|
| 184 |
+
print(f"{action_type.value:25s} reward={obs.reward:+.1f} done={obs.done}")
|
| 185 |
+
|
| 186 |
+
print(f"\nGrader: {env.state.grader_report}")
|
| 187 |
+
```
|
| 188 |
+
|
| 189 |
+
---
|
| 190 |
+
|
| 191 |
+
## Architecture
|
| 192 |
+
|
| 193 |
+
```
|
| 194 |
+
envs/medusa_env/
|
| 195 |
+
βββ __init__.py # Package exports
|
| 196 |
+
βββ medusa_env.py # MedusaEnv β reset / step / commit loop
|
| 197 |
+
βββ models.py # MedusaAction, MedusaObservation, MedusaState (Pydantic)
|
| 198 |
+
βββ scenarios.py # ScenarioGenerator β procedural Bronze A/B DataFrames
|
| 199 |
+
βββ operators.py # Stateless ETL functions (sync_check, prep_keys, execute_join, apply_scd β¦)
|
| 200 |
+
βββ rewards.py # RewardEngine β per-step reward computation
|
| 201 |
+
βββ grader.py # Grader β post-commit deterministic audit
|
| 202 |
+
βββ openenv.yaml # OpenEnv environment manifest
|
| 203 |
+
βββ server/
|
| 204 |
+
βββ app.py # FastAPI app via create_app()
|
| 205 |
+
|
| 206 |
+
tests/envs/
|
| 207 |
+
βββ test_medusa_environment.py # 39 tests across 6 test classes
|
| 208 |
+
```
|
| 209 |
+
|
| 210 |
+
**Stack:** Python 3.10+ Β· Pandas Β· Pydantic v2 Β· FastAPI Β· OpenEnv
|
| 211 |
+
|
| 212 |
+
---
|
| 213 |
+
|
| 214 |
+
## Technical Notes
|
| 215 |
+
|
| 216 |
+
- **No external data required.** All Bronze tables are generated procedurally per episode.
|
| 217 |
+
- **No Spark or Delta Lake required.** All logic uses Pandas β identical semantics, zero cluster setup.
|
| 218 |
+
- The grader is fully deterministic: same Silver + quarantine tables always produce the same audit result.
|
| 219 |
+
- The governance log (accessible at `env._tables.governance_log`) records every agent decision with its reward and operator metrics.
|
__init__.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""MEDUSA (Medallion-Engineered Deterministic Unified Storage Agent) environment.
|
| 2 |
+
|
| 3 |
+
Full Bronze→Silver integration controller with:
|
| 4 |
+
- Multi-source join orchestration (inner / left / anti)
|
| 5 |
+
- Schema drift handling (EVOLVE_SCHEMA)
|
| 6 |
+
- Key preparation and deduplication
|
| 7 |
+
- SCD-1 and SCD-2 merge logic
|
| 8 |
+
- Per-step RL reward engine
|
| 9 |
+
- Deterministic post-commit grader
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from .client import medusa_env
|
| 13 |
+
from .grader import Grader, GraderResult
|
| 14 |
+
from .models import MedusaAction, MedusaActionType, MedusaObservation, MedusaState
|
| 15 |
+
from .rewards import RewardEngine
|
| 16 |
+
from .scenarios import Scenario, ScenarioGenerator
|
| 17 |
+
from .tasks import TASKS, Task, TaskResult, score_episode
|
| 18 |
+
|
| 19 |
+
__all__ = [
|
| 20 |
+
"medusa_env",
|
| 21 |
+
"MedusaAction",
|
| 22 |
+
"MedusaActionType",
|
| 23 |
+
"MedusaObservation",
|
| 24 |
+
"MedusaState",
|
| 25 |
+
"Scenario",
|
| 26 |
+
"ScenarioGenerator",
|
| 27 |
+
"RewardEngine",
|
| 28 |
+
"Grader",
|
| 29 |
+
"GraderResult",
|
| 30 |
+
"TASKS",
|
| 31 |
+
"Task",
|
| 32 |
+
"TaskResult",
|
| 33 |
+
"score_episode",
|
| 34 |
+
]
|
client.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""MEDUSA Environment Client.
|
| 2 |
+
|
| 3 |
+
Connects to a running MEDUSA server via WebSocket for persistent sessions.
|
| 4 |
+
|
| 5 |
+
Example:
|
| 6 |
+
>>> # Connect to a running server
|
| 7 |
+
>>> with medusa_env(base_url="http://localhost:8000") as client:
|
| 8 |
+
... result = client.reset(seed=0)
|
| 9 |
+
... print(result.observation.message)
|
| 10 |
+
...
|
| 11 |
+
... from envs.medusa_env.models import MedusaActionType
|
| 12 |
+
... result = client.step(MedusaAction(action=MedusaActionType.SYNC_CHECK))
|
| 13 |
+
... print(f"Reward: {result.reward}")
|
| 14 |
+
|
| 15 |
+
Example with Docker:
|
| 16 |
+
>>> client = medusa_env.from_docker_image("medusa_env:latest")
|
| 17 |
+
>>> try:
|
| 18 |
+
... result = client.reset()
|
| 19 |
+
... result = client.step(MedusaAction(action=MedusaActionType.COMMIT))
|
| 20 |
+
... finally:
|
| 21 |
+
... client.close()
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
from typing import Any, Dict
|
| 25 |
+
|
| 26 |
+
# Support both in-repo and standalone imports
|
| 27 |
+
try:
|
| 28 |
+
from openenv.core.client_types import StepResult
|
| 29 |
+
from openenv.core.env_client import EnvClient
|
| 30 |
+
|
| 31 |
+
from .models import MedusaAction, MedusaObservation, MedusaState
|
| 32 |
+
except ImportError:
|
| 33 |
+
from models import MedusaAction, MedusaObservation, MedusaState
|
| 34 |
+
|
| 35 |
+
from openenv.core.client_types import StepResult
|
| 36 |
+
from openenv.core.env_client import EnvClient
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class medusa_env(EnvClient[MedusaAction, MedusaObservation, MedusaState]):
    """Client for the MEDUSA Bronze→Silver integration environment.

    Maintains a persistent WebSocket connection to the MEDUSA server;
    each client instance owns a dedicated environment session.

    The agent observes a 16-float data quality feature vector and chooses
    from 11 discrete ETL actions to build a correct Silver entity from
    two Bronze sources (Fact + Dimension).

    Example:
        >>> with medusa_env(base_url="http://localhost:8000") as env:
        ...     result = env.reset(seed=0)  # clean scenario
        ...     result = env.step(MedusaAction(action=MedusaActionType.SYNC_CHECK))
        ...     result = env.step(MedusaAction(action=MedusaActionType.PREP_KEYS_A))
        ...     result = env.step(MedusaAction(action=MedusaActionType.PREP_KEYS_B))
        ...     result = env.step(MedusaAction(action=MedusaActionType.DEDUPLICATE_B))
        ...     result = env.step(MedusaAction(action=MedusaActionType.EXECUTE_JOIN_LEFT))
        ...     result = env.step(MedusaAction(action=MedusaActionType.APPLY_SCD_2))
        ...     result = env.step(MedusaAction(action=MedusaActionType.COMMIT))
        ...     print(result.reward)
    """

    # Fallback values used when the server omits a field from the state
    # payload; order and defaults mirror the MedusaState field definitions.
    _STATE_DEFAULTS: Dict[str, Any] = {
        "run_id": None,
        "seed": None,
        "scenario_id": None,
        "step_idx": 0,
        "stage": "init",
        # Freshness
        "time_delta_a": 0.0,
        "time_delta_b": 0.0,
        "is_stale_a": False,
        "is_stale_b": False,
        "did_sync_check": False,
        # Key health
        "null_ratio_key_a": 0.0,
        "null_ratio_key_b": 0.0,
        "uniqueness_a": 1.0,
        "uniqueness_b": 1.0,
        "did_prep_a": False,
        "did_prep_b": False,
        "did_dedup_b": False,
        # Join
        "match_rate": 0.0,
        "did_join": False,
        "join_type": None,
        "join_row_count": 0,
        "explosion_detected": False,
        # SCD
        "did_scd": False,
        "scd_type": None,
        "scd_inserts": 0,
        "scd_updates": 0,
        # Silver / Quarantine
        "silver_row_count": 0,
        "quarantine_row_count": 0,
        "source_a_row_count": 0,
        # Grader
        "grader_passed": False,
        "grader_report": "",
        "cumulative_reward": 0.0,
    }

    def _step_payload(self, action: MedusaAction) -> Dict[str, Any]:
        """Serialise a MedusaAction into the JSON body of a step request."""
        return {"action": action.action.value, "params": action.params}

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[MedusaObservation]:
        """Deserialise a server step/reset response into a StepResult."""
        obs_data = payload.get("observation", {})
        reward = payload.get("reward")
        done = payload.get("done", False)
        # The observation carries reward/done alongside the StepResult copies.
        obs = MedusaObservation(
            message=obs_data.get("message", ""),
            features=obs_data.get("features", []),
            metrics=obs_data.get("metrics", {}),
            metadata=obs_data.get("metadata", {}),
            reward=reward,
            done=done,
        )
        return StepResult(observation=obs, reward=reward, done=done)

    def _parse_state(self, payload: Dict[str, Any]) -> MedusaState:
        """Deserialise a server state payload into a MedusaState."""
        kwargs = {
            name: payload.get(name, default)
            for name, default in self._STATE_DEFAULTS.items()
        }
        return MedusaState(**kwargs)
|
grader.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""MEDUSA deterministic post-commit grader.
|
| 2 |
+
|
| 3 |
+
Runs a four-check audit after the agent issues COMMIT and returns a
|
| 4 |
+
``GraderResult`` that feeds a bonus/penalty into the terminal reward.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
from typing import TYPE_CHECKING, List
|
| 11 |
+
|
| 12 |
+
import pandas as pd
|
| 13 |
+
|
| 14 |
+
if TYPE_CHECKING:
|
| 15 |
+
from .scenarios import Scenario
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# ---------------------------------------------------------------------------
|
| 19 |
+
# GraderResult
|
| 20 |
+
# ---------------------------------------------------------------------------
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class GraderResult:
|
| 24 |
+
"""Outcome of the post-commit audit."""
|
| 25 |
+
|
| 26 |
+
passed: bool = False
|
| 27 |
+
volume_ok: bool = False # Silver rows β€ Source A rows (no duplicates from join)
|
| 28 |
+
integrity_ok: bool = False # Quarantine holds only true orphans
|
| 29 |
+
schema_ok: bool = False # Silver has union of required columns
|
| 30 |
+
history_ok: bool = False # SCD-2 timestamps non-overlapping
|
| 31 |
+
failures: List[str] = field(default_factory=list)
|
| 32 |
+
bonus_reward: float = 0.0
|
| 33 |
+
report: str = ""
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# Reward tuning
|
| 37 |
+
_BONUS_ALL_PASS = +15.0
|
| 38 |
+
_PENALTY_ALL_FAIL = -20.0
|
| 39 |
+
_BONUS_PER_CHECK = +3.0
|
| 40 |
+
_PENALTY_PER_FAIL = -5.0
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# ---------------------------------------------------------------------------
|
| 44 |
+
# Grader
|
| 45 |
+
# ---------------------------------------------------------------------------
|
| 46 |
+
|
| 47 |
+
class Grader:
|
| 48 |
+
"""Post-commit deterministic audit following MEDUSA spec Β§4."""
|
| 49 |
+
|
| 50 |
+
def audit(
|
| 51 |
+
self,
|
| 52 |
+
silver: pd.DataFrame,
|
| 53 |
+
quarantine: pd.DataFrame,
|
| 54 |
+
bronze_a: pd.DataFrame,
|
| 55 |
+
bronze_b: pd.DataFrame,
|
| 56 |
+
join_key: str,
|
| 57 |
+
join_type: str,
|
| 58 |
+
scd_type: int,
|
| 59 |
+
scenario: "Scenario",
|
| 60 |
+
) -> GraderResult:
|
| 61 |
+
"""Run all four grader checks and compute bonus reward.
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
silver: The final Silver DataFrame after SCD merge.
|
| 65 |
+
quarantine: Rows from A that did not match B.
|
| 66 |
+
bronze_a: Original fact source (pre-cleaning).
|
| 67 |
+
bronze_b: Original dimension source (pre-cleaning).
|
| 68 |
+
join_key: Column used for the join.
|
| 69 |
+
join_type: "inner" | "left" | "anti"
|
| 70 |
+
scd_type: 1 or 2
|
| 71 |
+
scenario: The current episode's scenario (has tracked_cols etc.)
|
| 72 |
+
|
| 73 |
+
Returns:
|
| 74 |
+
GraderResult with individual check statuses and bonus_reward.
|
| 75 |
+
"""
|
| 76 |
+
result = GraderResult()
|
| 77 |
+
|
| 78 |
+
# ββ 1. Volume Check ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 79 |
+
# For left joins, Silver should not exceed Source A row count.
|
| 80 |
+
if join_type == "left":
|
| 81 |
+
source_a_rows = len(bronze_a.dropna(subset=[join_key]))
|
| 82 |
+
silver_rows = len(silver[silver.get("is_current", pd.Series(True, index=silver.index)) == True]) if "is_current" in silver.columns else len(silver) # noqa: E712
|
| 83 |
+
result.volume_ok = silver_rows <= source_a_rows * 1.05 # 5% tolerance
|
| 84 |
+
if not result.volume_ok:
|
| 85 |
+
result.failures.append(
|
| 86 |
+
f"VOLUME_FAIL: Silver {silver_rows} rows > Source A {source_a_rows} rows"
|
| 87 |
+
)
|
| 88 |
+
else:
|
| 89 |
+
result.volume_ok = True # Not applicable for inner/anti joins
|
| 90 |
+
|
| 91 |
+
# ββ 2. Integrity Check βββββββββββββββββββββββββββββββββββββββββββ
|
| 92 |
+
# Quarantine rows should be true orphans (no match in B even after cleaning).
|
| 93 |
+
if not quarantine.empty and join_key in quarantine.columns:
|
| 94 |
+
dim_keys = set(bronze_b[join_key].dropna().astype(str).str.strip())
|
| 95 |
+
quarantine_keys = set(quarantine[join_key].dropna().astype(str).str.strip())
|
| 96 |
+
# Orphan = quarantine key truly not in dim
|
| 97 |
+
could_join = quarantine_keys & dim_keys
|
| 98 |
+
if could_join:
|
| 99 |
+
result.integrity_ok = False
|
| 100 |
+
result.failures.append(
|
| 101 |
+
f"INTEGRITY_FAIL: {len(could_join)} quarantine row(s) could have "
|
| 102 |
+
f"been joined if keys were cleaned."
|
| 103 |
+
)
|
| 104 |
+
else:
|
| 105 |
+
result.integrity_ok = True
|
| 106 |
+
else:
|
| 107 |
+
result.integrity_ok = True # Empty quarantine is fine
|
| 108 |
+
|
| 109 |
+
# ββ 3. Schema Check ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 110 |
+
# Silver must contain all required columns from A and B.
|
| 111 |
+
required_from_a = [c for c in bronze_a.columns if c != join_key]
|
| 112 |
+
required_from_b = [c for c in bronze_b.columns if c != join_key]
|
| 113 |
+
required = set(required_from_a + required_from_b + scenario.new_cols_a + scenario.new_cols_b)
|
| 114 |
+
silver_cols = set(silver.columns)
|
| 115 |
+
missing = required - silver_cols
|
| 116 |
+
if missing:
|
| 117 |
+
result.schema_ok = False
|
| 118 |
+
result.failures.append(f"SCHEMA_FAIL: Missing columns in Silver: {sorted(missing)}")
|
| 119 |
+
else:
|
| 120 |
+
result.schema_ok = True
|
| 121 |
+
|
| 122 |
+
# ββ 4. History Check (SCD-2 only) ββββββββββββββββββββββββββββββββ
|
| 123 |
+
if scd_type == 2 and "valid_from" in silver.columns and "valid_to" in silver.columns:
|
| 124 |
+
overlap_found = False
|
| 125 |
+
for key_val, group in silver.groupby(join_key):
|
| 126 |
+
if len(group) < 2:
|
| 127 |
+
continue
|
| 128 |
+
closed = group[group["valid_to"].notna()].sort_values("valid_from")
|
| 129 |
+
for i in range(len(closed) - 1):
|
| 130 |
+
vt_i = closed.iloc[i]["valid_to"]
|
| 131 |
+
vf_next = closed.iloc[i + 1]["valid_from"]
|
| 132 |
+
if pd.notna(vt_i) and pd.notna(vf_next) and vt_i > vf_next:
|
| 133 |
+
overlap_found = True
|
| 134 |
+
break
|
| 135 |
+
if overlap_found:
|
| 136 |
+
break
|
| 137 |
+
if overlap_found:
|
| 138 |
+
result.history_ok = False
|
| 139 |
+
result.failures.append("HISTORY_FAIL: SCD-2 timestamps overlap for some keys.")
|
| 140 |
+
else:
|
| 141 |
+
result.history_ok = True
|
| 142 |
+
else:
|
| 143 |
+
result.history_ok = True # Not applicable for SCD-1
|
| 144 |
+
|
| 145 |
+
# ββ Compute bonus ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 146 |
+
checks = [result.volume_ok, result.integrity_ok, result.schema_ok, result.history_ok]
|
| 147 |
+
passed_count = sum(checks)
|
| 148 |
+
failed_count = len(checks) - passed_count
|
| 149 |
+
|
| 150 |
+
result.passed = all(checks)
|
| 151 |
+
|
| 152 |
+
if result.passed:
|
| 153 |
+
result.bonus_reward = _BONUS_ALL_PASS
|
| 154 |
+
elif failed_count == len(checks):
|
| 155 |
+
result.bonus_reward = _PENALTY_ALL_FAIL
|
| 156 |
+
else:
|
| 157 |
+
result.bonus_reward = passed_count * _BONUS_PER_CHECK - failed_count * _PENALTY_PER_FAIL
|
| 158 |
+
|
| 159 |
+
result.report = _build_report(result)
|
| 160 |
+
return result
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
# ---------------------------------------------------------------------------
|
| 164 |
+
# Internal helpers
|
| 165 |
+
# ---------------------------------------------------------------------------
|
| 166 |
+
|
| 167 |
+
def _build_report(result: GraderResult) -> str:
    """Render the grader outcome as a human-readable audit summary.

    Emits one pass/fail line per check (volume, integrity, schema, history),
    the signed bonus reward, any recorded failure messages, and an overall
    PASS/FAIL verdict, joined into a single multi-line string.
    """
    lines = ["=== MEDUSA Grader Audit ==="]
    # One line per grader check; glyph chosen by the boolean flag.
    lines.append(f" Volume OK: {'β' if result.volume_ok else 'β'}")
    lines.append(f" Integrity OK: {'β' if result.integrity_ok else 'β'}")
    lines.append(f" Schema OK: {'β' if result.schema_ok else 'β'}")
    lines.append(f" History OK: {'β' if result.history_ok else 'β'}")
    # "+.1f" keeps an explicit sign so penalties read clearly.
    lines.append(f" Bonus Reward: {result.bonus_reward:+.1f}")
    if result.failures:
        lines.append(" Failures:")
        for f in result.failures:
            lines.append(f" - {f}")
    lines.append(f" {'PASS β' if result.passed else 'FAIL β'}")
    # NOTE(review): the pass/fail glyphs above look mojibake-damaged in this
    # copy of the file (both branches render as the same character); confirm
    # against the canonical source that they are distinct check/cross marks.
    return "\n".join(lines)
|
models.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from enum import Enum
|
| 4 |
+
from typing import Any, Dict, List, Optional
|
| 5 |
+
|
| 6 |
+
from pydantic import Field
|
| 7 |
+
|
| 8 |
+
from openenv.core.env_server.types import Action, Observation, State
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class MedusaActionType(str, Enum):
    """Discrete action set for the MEDUSA controller.

    Values double as wire strings (``str`` mixin), so actions serialize
    verbatim over the HTTP/WebSocket environment API.
    """

    SYNC_CHECK = "SYNC_CHECK"            # inspect freshness of both Bronze sources
    EVOLVE_SCHEMA = "EVOLVE_SCHEMA"      # add drifted columns to the Silver schema
    PREP_KEYS_A = "PREP_KEYS_A"          # normalise the join key of Source A (fact)
    PREP_KEYS_B = "PREP_KEYS_B"          # normalise the join key of Source B (dimension)
    DEDUPLICATE_B = "DEDUPLICATE_B"      # collapse the dimension to one row per key
    EXECUTE_JOIN_INNER = "EXECUTE_JOIN_INNER"  # keep only matched rows
    EXECUTE_JOIN_LEFT = "EXECUTE_JOIN_LEFT"    # keep all fact rows, quarantine orphans
    EXECUTE_JOIN_ANTI = "EXECUTE_JOIN_ANTI"    # emit only unmatched fact rows
    APPLY_SCD_1 = "APPLY_SCD_1"          # overwrite-in-place merge into Silver
    APPLY_SCD_2 = "APPLY_SCD_2"          # history-preserving merge (valid_from/valid_to)
    COMMIT = "COMMIT"                    # finalise the run (grader bonus applied separately)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class MedusaAction(Action):
    """One controller action (enum + optional params for future use)."""

    # Which pipeline operation to perform this step.
    action: MedusaActionType
    # Free-form operator arguments; currently unused, reserved for future ops.
    params: Dict[str, Any] = Field(default_factory=dict)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class MedusaState(State):
    """Full pipeline controller state.

    Tracks every book-keeping flag needed by the reward engine and grader.
    """

    # --- Episode identity / budget ---
    run_id: Optional[str] = None
    seed: Optional[int] = None
    scenario_id: Optional[str] = None
    max_steps: int = 20  # episode step budget

    step_idx: int = 0
    stage: str = "init"  # init | running | committed | failed

    # --- Freshness ---
    time_delta_a: float = 0.0  # hours since Source A last updated
    time_delta_b: float = 0.0  # hours since Source B last updated
    is_stale_a: bool = False   # Source A past the staleness threshold
    is_stale_b: bool = False
    did_sync_check: bool = False  # SYNC_CHECK performed this episode

    # --- Schema ---
    did_evolve_schema: bool = False
    new_cols_a: int = 0  # number of new columns in A not yet in Silver
    new_cols_b: int = 0
    schema_compat: float = 1.0  # 0-1 key-type compatibility score

    # --- Key Health ---
    null_ratio_key_a: float = 0.0
    null_ratio_key_b: float = 0.0
    uniqueness_a: float = 1.0  # 1.0 = fully unique
    uniqueness_b: float = 1.0
    did_prep_a: bool = False
    did_prep_b: bool = False
    did_dedup_b: bool = False

    # --- Referential Integrity ---
    match_rate: float = 0.0  # fraction (0-1) of Key_A values found in Key_B

    # --- Join Result ---
    did_join: bool = False
    join_type: Optional[str] = None  # "inner" | "left" | "anti" once a join ran
    join_row_count: int = 0
    explosion_detected: bool = False  # join fanned out past the explosion threshold

    # --- SCD ---
    did_scd: bool = False
    scd_type: Optional[str] = None  # presumably "1" / "2" once applied — confirm against env
    scd_inserts: int = 0
    scd_updates: int = 0

    # --- Silver / Quarantine ---
    silver_row_count: int = 0
    quarantine_row_count: int = 0
    source_a_row_count: int = 0

    # --- Grader ---
    grader_passed: bool = False
    grader_report: str = ""  # human-readable audit text from the grader

    # --- Governance ---
    cumulative_reward: float = 0.0  # running sum of per-step rewards
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class MedusaObservation(Observation):
    """Observation returned to the agent after every step.

    ``features`` is a 16-element normalised float vector suitable as
    direct RL input::

        [time_delta_a_norm, time_delta_b_norm, is_stale_a, is_stale_b,
         null_ratio_key_a, null_ratio_key_b, uniqueness_a, uniqueness_b,
         match_rate, new_cols_a_norm, new_cols_b_norm, schema_compat,
         did_prep_a, did_prep_b, did_dedup_b, step_frac]
    """

    # Human-readable summary of the last step's outcome.
    message: str = ""
    # Normalised feature vector; layout documented in the class docstring.
    features: List[float] = Field(default_factory=list)
    # Raw operator metrics produced by the last action.
    metrics: Dict[str, Any] = Field(default_factory=dict)
    # Auxiliary info (presumably run/debug fields — verify against server env).
    metadata: Dict[str, Any] = Field(default_factory=dict)
    # Scalar reward for the last step; None before any step has run.
    reward: Optional[float] = None
    # True once the episode has terminated.
    done: bool = False
|
openenv.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: medusa_env
|
| 3 |
+
type: space
|
| 4 |
+
runtime: fastapi
|
| 5 |
+
app: server.app:app
|
| 6 |
+
port: 8000
|
openenv_medusa.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: openenv-medusa
|
| 3 |
+
Version: 0.2.0
|
| 4 |
+
Summary: MEDUSA: Medallion-Engineered Deterministic Unified Storage Agent β BronzeβSilver RL environment for OpenEnv
|
| 5 |
+
Requires-Python: >=3.10
|
| 6 |
+
License-File: LICENSE
|
| 7 |
+
Requires-Dist: openenv-core[core]>=0.2.2
|
| 8 |
+
Requires-Dist: fastapi>=0.115.0
|
| 9 |
+
Requires-Dist: pydantic>=2.0.0
|
| 10 |
+
Requires-Dist: uvicorn>=0.24.0
|
| 11 |
+
Requires-Dist: pandas>=2.0.0
|
| 12 |
+
Requires-Dist: numpy>=1.24.0
|
| 13 |
+
Provides-Extra: dev
|
| 14 |
+
Requires-Dist: pytest>=8.0.0; extra == "dev"
|
| 15 |
+
Requires-Dist: pytest-asyncio>=0.23.0; extra == "dev"
|
| 16 |
+
Dynamic: license-file
|
openenv_medusa.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
LICENSE
|
| 2 |
+
README.md
|
| 3 |
+
pyproject.toml
|
| 4 |
+
./__init__.py
|
| 5 |
+
./client.py
|
| 6 |
+
./grader.py
|
| 7 |
+
./models.py
|
| 8 |
+
./openenv.yaml
|
| 9 |
+
./operators.py
|
| 10 |
+
./rewards.py
|
| 11 |
+
./scenarios.py
|
| 12 |
+
./tasks.py
|
| 13 |
+
./server/__init__.py
|
| 14 |
+
./server/app.py
|
| 15 |
+
./server/medusa_env.py
|
| 16 |
+
openenv_medusa.egg-info/PKG-INFO
|
| 17 |
+
openenv_medusa.egg-info/SOURCES.txt
|
| 18 |
+
openenv_medusa.egg-info/dependency_links.txt
|
| 19 |
+
openenv_medusa.egg-info/entry_points.txt
|
| 20 |
+
openenv_medusa.egg-info/requires.txt
|
| 21 |
+
openenv_medusa.egg-info/top_level.txt
|
| 22 |
+
server/__init__.py
|
| 23 |
+
server/app.py
|
| 24 |
+
server/medusa_env.py
|
| 25 |
+
tests/test_medusa_environment.py
|
openenv_medusa.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
openenv_medusa.egg-info/entry_points.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[console_scripts]
|
| 2 |
+
server = server.app:main
|
openenv_medusa.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv-core[core]>=0.2.2
|
| 2 |
+
fastapi>=0.115.0
|
| 3 |
+
pydantic>=2.0.0
|
| 4 |
+
uvicorn>=0.24.0
|
| 5 |
+
pandas>=2.0.0
|
| 6 |
+
numpy>=1.24.0
|
| 7 |
+
|
| 8 |
+
[dev]
|
| 9 |
+
pytest>=8.0.0
|
| 10 |
+
pytest-asyncio>=0.23.0
|
openenv_medusa.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
medusa_env
|
| 2 |
+
server
|
operators.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""MEDUSA ETL operators.
|
| 2 |
+
|
| 3 |
+
Each operator is a stateless function that takes DataFrame(s) and returns a
|
| 4 |
+
(result_df_or_None, metrics_dict) tuple. The environment calls these from
|
| 5 |
+
``step()`` and passes the metrics to the reward engine.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import datetime
|
| 11 |
+
from typing import Any, Dict, Optional, Tuple
|
| 12 |
+
|
| 13 |
+
import pandas as pd
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# ---------------------------------------------------------------------------
|
| 17 |
+
# Type alias
|
| 18 |
+
# ---------------------------------------------------------------------------
|
| 19 |
+
|
| 20 |
+
# Shared operator result shapes: every operator returns the (possibly absent)
# transformed DataFrame plus a metrics dict consumed by the reward engine.
Metrics = Dict[str, Any]
OpResult = Tuple[Optional[pd.DataFrame], Metrics]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# ---------------------------------------------------------------------------
|
| 25 |
+
# Operator: sync_check
|
| 26 |
+
# ---------------------------------------------------------------------------
|
| 27 |
+
|
| 28 |
+
def sync_check(
    bronze_a: pd.DataFrame,
    bronze_b: pd.DataFrame,
    time_delta_a: float,
    time_delta_b: float,
    stale_threshold_hours: float = 6.0,
) -> OpResult:
    """Report freshness of both Bronze sources.

    Pure inspection: neither frame is modified or returned — only staleness
    and row-count metrics come back (result frame is ``None``).
    """
    freshness: Metrics = {
        "time_delta_a": time_delta_a,
        "time_delta_b": time_delta_b,
        # A source is stale once its age exceeds the threshold.
        "is_stale_a": time_delta_a > stale_threshold_hours,
        "is_stale_b": time_delta_b > stale_threshold_hours,
        "rows_a": len(bronze_a),
        "rows_b": len(bronze_b),
    }
    return None, freshness
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# ---------------------------------------------------------------------------
|
| 53 |
+
# Operator: evolve_schema
|
| 54 |
+
# ---------------------------------------------------------------------------
|
| 55 |
+
|
| 56 |
+
def evolve_schema(
    silver: pd.DataFrame,
    bronze_a: pd.DataFrame,
    bronze_b: pd.DataFrame,
    new_cols_a: list[str],
    new_cols_b: list[str],
) -> OpResult:
    """Align the Silver schema with drifted Bronze sources.

    Every column reported as new in A or B that Silver does not yet carry is
    appended (filled with ``pd.NA`` for all historical rows). Columns already
    present — including duplicates across the two lists — are skipped, and
    each added column is recorded once in ``cols_added``.
    """
    evolved = silver.copy()
    appended: list[str] = []

    for name in [*new_cols_a, *new_cols_b]:
        if name in evolved.columns:
            continue  # already present (or just added from the other source)
        evolved[name] = pd.NA
        appended.append(name)

    report: Metrics = {
        "cols_added": appended,
        "new_cols_count": len(appended),
        "silver_col_count": len(evolved.columns),
    }
    return evolved, report
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# ---------------------------------------------------------------------------
|
| 84 |
+
# Operator: prep_keys
|
| 85 |
+
# ---------------------------------------------------------------------------
|
| 86 |
+
|
| 87 |
+
def prep_keys(df: pd.DataFrame, key_col: str) -> OpResult:
    """Normalise the join key column of ``df`` for a safe equi-join.

    Strips surrounding whitespace, maps the textual null spellings produced
    by ``astype(str)`` ("None", "nan", "<NA>") and blank strings back to
    ``pd.NA``, and casts the column to the pandas nullable ``string`` dtype
    so both sides of the join share a uniform key type.

    Args:
        df: Source frame; never modified (a cleaned copy is returned).
        key_col: Name of the join key column.

    Returns:
        (cleaned_df, metrics) with null ratios before and after cleaning.
        Null-keyed rows are intentionally KEPT — the grader is responsible
        for catching orphans downstream.
    """
    result = df.copy()
    original_nulls = result[key_col].isna().sum()
    original_len = len(result)

    # Stringify and strip whitespace; whitespace-only keys collapse to "".
    result[key_col] = result[key_col].astype(str).str.strip()
    # astype(str) renders missing values as literal text depending on the
    # source dtype: None -> "None", float NaN -> "nan", pd.NA -> "<NA>".
    # Map all of them (plus blanks) back to real nulls. FIX: the original
    # mapping missed "<NA>", so pd.NA keys from nullable-string columns
    # silently survived as non-null join keys.
    result[key_col] = result[key_col].replace(
        {"None": pd.NA, "nan": pd.NA, "<NA>": pd.NA, "": pd.NA}
    )

    # Uniform nullable string dtype for the join.
    result[key_col] = result[key_col].astype("string")

    after_nulls = result[key_col].isna().sum()
    null_ratio_before = original_nulls / max(original_len, 1)
    null_ratio_after = int(after_nulls) / max(original_len, 1)

    metrics: Metrics = {
        "null_ratio_before": null_ratio_before,
        "null_ratio_after": null_ratio_after,
        # NOTE: despite the name, this is the count of rows holding a usable
        # (non-null) key after cleaning, not rows physically removed.
        "rows_trimmed": original_len - int(after_nulls),
        "null_rows_dropped": 0,  # we never drop rows; grader catches orphans
    }
    return result, metrics
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# ---------------------------------------------------------------------------
|
| 118 |
+
# Operator: deduplicate
|
| 119 |
+
# ---------------------------------------------------------------------------
|
| 120 |
+
|
| 121 |
+
def deduplicate(df: pd.DataFrame, key_col: str) -> OpResult:
    """Collapse the Dimension (Source B) to one row per ``key_col``.

    The *last* occurrence of each key wins so the most-recent record
    survives. Metrics report how many duplicates were dropped and the
    fraction of remaining rows that carry a non-null key ("uniqueness").
    """
    before = len(df)
    deduped = df.drop_duplicates(subset=[key_col], keep="last").reset_index(drop=True)

    key_non_null = deduped[key_col].notna().sum()
    stats: Metrics = {
        "dupes_removed": before - len(deduped),
        "uniqueness": float(key_non_null / max(len(deduped), 1)),
        "rows_after": len(deduped),
    }
    return deduped, stats
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# ---------------------------------------------------------------------------
|
| 142 |
+
# Operator: execute_join
|
| 143 |
+
# ---------------------------------------------------------------------------
|
| 144 |
+
|
| 145 |
+
_EXPLOSION_MULTIPLIER = 1.05  # joins producing > 5% more rows than the fact input trigger an alert


def execute_join(
    fact: pd.DataFrame,
    dim: pd.DataFrame,
    key_col: str,
    join_type: str,  # "inner" | "left" | "anti"
) -> Tuple[pd.DataFrame, pd.DataFrame, Metrics]:
    """Join Fact (A) with Dimension (B).

    Returns (joined_df, quarantine_df, metrics).
    ``quarantine_df`` contains rows from A that did not match B (orphans);
    for the left join it keeps only the key column.

    Null-keyed rows are dropped from both sides before joining, and a
    row-explosion flag is raised when the output fans out more than 5%
    beyond the fact input (many-to-many key matches).
    """
    # Drop null-keyed rows from both sides before joining.
    fact_clean = fact.dropna(subset=[key_col])
    dim_clean = dim.dropna(subset=[key_col])

    # Referential-integrity match rate, computed before the join.
    fact_keys = set(fact_clean[key_col].astype(str))
    dim_keys = set(dim_clean[key_col].astype(str))
    match_rate = len(fact_keys & dim_keys) / max(len(fact_keys), 1)

    if join_type == "anti":
        # Anti-join: rows in A NOT in B all go to quarantine; the joined
        # output is an empty frame with the combined column layout.
        orphan_mask = ~fact_clean[key_col].astype(str).isin(dim_keys)
        joined = pd.DataFrame(columns=list(fact_clean.columns) + [
            c for c in dim_clean.columns if c != key_col
        ])
        quarantine = fact_clean[orphan_mask].copy()
    elif join_type == "inner":
        joined = fact_clean.merge(dim_clean, on=key_col, how="inner",
                                  suffixes=("_a", "_b"))
        quarantine = fact_clean[~fact_clean[key_col].astype(str).isin(dim_keys)].copy()
    else:  # left
        # FIX: use pandas' merge indicator to find unmatched rows. The old
        # probe of the first dim column (merged[dim_cols[0]].isna()) raised
        # KeyError when a dim column collided with a fact column (suffix
        # renaming) and mis-flagged matched rows whose dim value was
        # legitimately NaN. Note: merge(indicator=True) raises if either
        # frame already has a "_merge" column.
        merged = fact_clean.merge(dim_clean, on=key_col, how="left",
                                  suffixes=("_a", "_b"), indicator=True)
        no_match_mask = merged["_merge"] == "left_only"
        # Quarantine keeps only the join key, matching the original contract.
        quarantine = merged.loc[no_match_mask, [key_col]].copy()
        joined = merged.drop(columns="_merge")

    # Explosion detection: output noticeably larger than the fact input.
    explosion = len(joined) > len(fact_clean) * _EXPLOSION_MULTIPLIER

    metrics: Metrics = {
        "join_type": join_type,
        "fact_rows": len(fact_clean),
        "dim_rows": len(dim_clean),
        "join_rows": len(joined),
        "quarantine_rows": len(quarantine),
        "match_rate": match_rate,
        "explosion_detected": explosion,
    }
    return joined, quarantine, metrics
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
# ---------------------------------------------------------------------------
|
| 209 |
+
# Operator: apply_scd
|
| 210 |
+
# ---------------------------------------------------------------------------
|
| 211 |
+
|
| 212 |
+
def apply_scd(
    silver: pd.DataFrame,
    joined: pd.DataFrame,
    key_col: str,
    tracked_col: str,
    scd_type: int,  # 1 or 2
) -> OpResult:
    """Merge ``joined`` result into Silver using SCD-1 or SCD-2.

    SCD-1: overwrite existing records.
    SCD-2: close old records (valid_to = now) and insert new ones with
           a new valid_from / valid_to = None (open record).

    Returns:
        (updated_silver, metrics) where metrics counts inserts/updates and
        the resulting Silver row count. ``silver`` itself is not mutated
        except being copied before any column additions.
    """
    # NOTE: datetime.UTC requires Python >= 3.11.
    now = datetime.datetime.now(datetime.UTC)
    inserts = 0
    updates = 0

    # Nothing to merge: return Silver untouched with zero counts.
    if joined.empty:
        metrics: Metrics = {
            "scd_type": scd_type,
            "inserts": 0,
            "updates": 0,
            "silver_rows": len(silver),
        }
        return silver, metrics

    if silver.empty:
        # First load - treat everything as inserts
        result = joined.copy()
        if scd_type == 2:
            # Open a fresh validity window for every row.
            result["valid_from"] = now
            result["valid_to"] = pd.NaT
            result["is_current"] = True
        inserts = len(result)
        metrics = {
            "scd_type": scd_type,
            "inserts": inserts,
            "updates": 0,
            "silver_rows": len(result),
        }
        return result, metrics

    if scd_type == 1:
        # Upsert: rows whose key already exists in Silver are replaced
        # wholesale, then every incoming row is appended.
        # NOTE(review): duplicate keys within ``joined`` are appended as-is;
        # presumably the caller deduplicates first — confirm.
        exists_mask = silver[key_col].isin(joined[key_col])
        new_keys_mask = ~joined[key_col].isin(silver[key_col])

        result = silver[~exists_mask].copy()
        result = pd.concat([result, joined], ignore_index=True)

        updates = int(exists_mask.sum())
        inserts = int(new_keys_mask.sum())

    else:  # SCD-2
        # Ensure Silver has timestamp columns
        if "valid_from" not in silver.columns:
            silver = silver.copy()
            # Backdate pre-existing rows so they read as older open records.
            silver["valid_from"] = now - datetime.timedelta(days=30)
            silver["valid_to"] = pd.NaT
            silver["is_current"] = True

        silver_result = silver.copy()
        new_rows: list[pd.DataFrame] = []

        # Row-at-a-time merge: O(len(joined) * len(silver)). Acceptable for
        # the small synthetic frames this environment generates.
        for _, new_row in joined.iterrows():
            key_val = new_row[key_col]
            # The single open ("current") record for this key, if any.
            current_mask = (silver_result[key_col] == key_val) & (silver_result["is_current"] == True)  # noqa: E712
            current_rows = silver_result[current_mask]

            if current_rows.empty:
                # New record: insert with an open validity window.
                row_df = pd.DataFrame([new_row])
                row_df["valid_from"] = now
                row_df["valid_to"] = pd.NaT
                row_df["is_current"] = True
                new_rows.append(row_df)
                inserts += 1
            else:
                # Only a change in the tracked column triggers history.
                old_val = current_rows.iloc[0].get(tracked_col)
                new_val = new_row.get(tracked_col)
                if old_val != new_val:
                    # Close old record
                    silver_result.loc[current_mask, "valid_to"] = now
                    silver_result.loc[current_mask, "is_current"] = False
                    # Insert new record
                    row_df = pd.DataFrame([new_row])
                    row_df["valid_from"] = now
                    row_df["valid_to"] = pd.NaT
                    row_df["is_current"] = True
                    new_rows.append(row_df)
                    updates += 1

        if new_rows:
            silver_result = pd.concat([silver_result] + new_rows, ignore_index=True)
        result = silver_result

    metrics = {
        "scd_type": scd_type,
        "inserts": inserts,
        "updates": updates,
        "silver_rows": len(result),
    }
    return result, metrics
|
pyproject.toml
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=45", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "openenv-medusa"
|
| 7 |
+
version = "0.2.0"
|
| 8 |
+
description = "MEDUSA: Medallion-Engineered Deterministic Unified Storage Agent β BronzeβSilver RL environment for OpenEnv"
|
| 9 |
+
requires-python = ">=3.10"
|
| 10 |
+
dependencies = [
|
| 11 |
+
# Core OpenEnv dependencies
|
| 12 |
+
"openenv-core[core]>=0.2.2",
|
| 13 |
+
"fastapi>=0.115.0",
|
| 14 |
+
"pydantic>=2.0.0",
|
| 15 |
+
"uvicorn>=0.24.0",
|
| 16 |
+
# Data pipeline dependencies
|
| 17 |
+
"pandas>=2.0.0",
|
| 18 |
+
"numpy>=1.24.0",
|
| 19 |
+
]
|
| 20 |
+
|
| 21 |
+
[project.optional-dependencies]
|
| 22 |
+
dev = [
|
| 23 |
+
"pytest>=8.0.0",
|
| 24 |
+
"pytest-asyncio>=0.23.0",
|
| 25 |
+
]
|
| 26 |
+
|
| 27 |
+
[project.scripts]
|
| 28 |
+
# Enables: uv run server (from the medusa_env directory)
|
| 29 |
+
server = "server.app:main"
|
| 30 |
+
|
| 31 |
+
[tool.setuptools]
|
| 32 |
+
include-package-data = true
|
| 33 |
+
packages = ["medusa_env", "medusa_env.server", "server"]
|
| 34 |
+
package-dir = { "medusa_env" = ".", "server" = "server" }
|
| 35 |
+
|
| 36 |
+
[tool.setuptools.package-data]
|
| 37 |
+
medusa_env = ["**/*.yaml", "**/*.yml"]
|
rewards.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""MEDUSA reward engine.
|
| 2 |
+
|
| 3 |
+
Reward model as defined in the MEDUSA blueprint. All reward logic is in a
|
| 4 |
+
single ``RewardEngine`` class so it can be unit-tested in isolation from the
|
| 5 |
+
environment.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from typing import Any, Dict
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# ---------------------------------------------------------------------------
|
| 14 |
+
# Reward table (blueprint Β§3)
|
| 15 |
+
# ---------------------------------------------------------------------------
|
| 16 |
+
|
| 17 |
+
REWARD_TABLE: Dict[str, float] = {
    "high_match_join": +25.0,       # match_rate at or above HIGH_MATCH_THRESHOLD
    "correct_scd2": +5.0,           # SCD-2 used on a tracked column
    "quarantine_precision": +10.0,  # orphaned rows correctly moved to quarantine
    "row_explosion": -100.0,        # cartesian-product style fan-out detected
    "dirty_join": -30.0,            # join without PREP_KEYS produced a 0-row result
    "stale_processing": -15.0,      # acted while a source was stale and never synced
    "step_penalty": -0.2,           # flat per-step efficiency cost
}

HIGH_MATCH_THRESHOLD = 0.90


class RewardEngine:
    """Compute the per-step scalar reward from an action and its metrics."""

    def evaluate(
        self,
        action_type: str,
        metrics: Dict[str, Any],
        state_before: Any,  # MedusaState snapshot taken before the step
    ) -> float:
        """Return the scalar reward for a single step.

        Args:
            action_type: The ``MedusaActionType`` value string (e.g. "SYNC_CHECK").
            metrics: Dictionary returned by the corresponding operator.
            state_before: State object *before* this step was applied.

        Returns:
            Scalar float reward; the flat step penalty is always included.
        """
        total = REWARD_TABLE["step_penalty"]  # always applied

        # True when a stale source was acted upon without a prior SYNC_CHECK.
        stale_and_unsynced = (
            (state_before.is_stale_a or state_before.is_stale_b)
            and not state_before.did_sync_check
        )

        if action_type in ("EXECUTE_JOIN_INNER", "EXECUTE_JOIN_LEFT", "EXECUTE_JOIN_ANTI"):
            if metrics.get("explosion_detected", False):
                total += REWARD_TABLE["row_explosion"]
            else:
                rows_out = metrics.get("join_rows", 0)
                rows_in = metrics.get("fact_rows", 1)
                if rows_out == 0 and rows_in > 0:
                    # Zero-row join on a non-empty fact is "dirty" only when
                    # a PREP_KEYS step was skipped on either side.
                    if not (state_before.did_prep_a and state_before.did_prep_b):
                        total += REWARD_TABLE["dirty_join"]
                elif metrics.get("match_rate", 0.0) >= HIGH_MATCH_THRESHOLD:
                    total += REWARD_TABLE["high_match_join"]

                # Orphans captured via a LEFT join earn precision credit.
                if action_type == "EXECUTE_JOIN_LEFT" and metrics.get("quarantine_rows", 0) > 0:
                    total += REWARD_TABLE["quarantine_precision"]

                if stale_and_unsynced:
                    total += REWARD_TABLE["stale_processing"]

        elif action_type in ("APPLY_SCD_1", "APPLY_SCD_2"):
            if action_type == "APPLY_SCD_2":
                # Reward when SCD-2 was the right choice (tracked column involved).
                total += REWARD_TABLE["correct_scd2"]
            if stale_and_unsynced:
                total += REWARD_TABLE["stale_processing"]

        # SYNC_CHECK, PREP_KEYS_*, DEDUPLICATE_B, EVOLVE_SCHEMA and COMMIT
        # carry no signal of their own beyond the step penalty (the grader
        # adds the commit bonus/penalty separately).
        return total
|
scenarios.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""MEDUSA scenario generator.
|
| 2 |
+
|
| 3 |
+
Produces randomised Bronze A (Fact) and Bronze B (Dimension) DataFrames to
|
| 4 |
+
drive each training episode. Four canonical scenarios cover the canonical
|
| 5 |
+
failure modes described in the MEDUSA blueprint.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import random
|
| 11 |
+
from dataclasses import dataclass, field
|
| 12 |
+
from typing import List, Optional
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
import pandas as pd
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# ---------------------------------------------------------------------------
|
| 19 |
+
# Scenario dataclass
|
| 20 |
+
# ---------------------------------------------------------------------------
|
| 21 |
+
|
| 22 |
+
@dataclass
class Scenario:
    """One episode's worth of Bronze source data + configuration.

    Bundles the two Bronze DataFrames together with the metadata the
    environment needs to evaluate freshness, key hygiene, and schema drift.
    """

    id: str                   # Unique scenario identifier, e.g. "clean_0"
    bronze_a: pd.DataFrame    # Fact table (source of truth for volume)
    bronze_b: pd.DataFrame    # Dimension table (must be unique on key)
    join_key: str             # Column name used to join A and B
    tracked_cols: List[str]   # Columns in B that require SCD-2 history
    is_stale_a: bool          # Whether Source A is past the freshness threshold
    is_stale_b: bool          # Whether Source B is past the freshness threshold
    time_delta_a: float       # Hours since Source A was last refreshed
    time_delta_b: float       # Hours since Source B was last refreshed
    new_cols_a: List[str]     # Extra columns in A not in Silver yet
    new_cols_b: List[str]     # Extra columns in B not in Silver yet
    description: str = ""     # Human-readable summary of the scenario
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# ---------------------------------------------------------------------------
|
| 41 |
+
# Internal helpers
|
| 42 |
+
# ---------------------------------------------------------------------------
|
| 43 |
+
|
| 44 |
+
_STALE_THRESHOLD_HOURS = 6.0
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _make_fact(
|
| 48 |
+
rng: random.Random,
|
| 49 |
+
n_rows: int,
|
| 50 |
+
key_col: str,
|
| 51 |
+
null_ratio: float = 0.0,
|
| 52 |
+
extra_cols: Optional[List[str]] = None,
|
| 53 |
+
) -> pd.DataFrame:
|
| 54 |
+
"""Create a synthetic Fact (Bronze A) DataFrame."""
|
| 55 |
+
keys = [f"K{i:04d}" for i in rng.sample(range(1, n_rows * 2), n_rows)]
|
| 56 |
+
|
| 57 |
+
# Inject nulls into the key
|
| 58 |
+
null_mask = rng.sample(range(n_rows), int(n_rows * null_ratio))
|
| 59 |
+
for idx in null_mask:
|
| 60 |
+
keys[idx] = None # type: ignore[call-overload]
|
| 61 |
+
|
| 62 |
+
data = {
|
| 63 |
+
key_col: keys,
|
| 64 |
+
"fact_value": [rng.uniform(0, 1000) for _ in range(n_rows)],
|
| 65 |
+
"fact_category": [rng.choice(["A", "B", "C"]) for _ in range(n_rows)],
|
| 66 |
+
"created_at": pd.date_range("2024-01-01", periods=n_rows, freq="h"),
|
| 67 |
+
}
|
| 68 |
+
for col in (extra_cols or []):
|
| 69 |
+
data[col] = [rng.uniform(0, 100) for _ in range(n_rows)]
|
| 70 |
+
|
| 71 |
+
return pd.DataFrame(data)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _make_dim(
|
| 75 |
+
rng: random.Random,
|
| 76 |
+
n_rows: int,
|
| 77 |
+
key_col: str,
|
| 78 |
+
null_ratio: float = 0.0,
|
| 79 |
+
uniqueness: float = 1.0, # < 1.0 means some keys are duplicated
|
| 80 |
+
match_keys: Optional[List[str]] = None, # If given, use these as the key pool
|
| 81 |
+
extra_cols: Optional[List[str]] = None,
|
| 82 |
+
tracked_cols: Optional[List[str]] = None,
|
| 83 |
+
) -> pd.DataFrame:
|
| 84 |
+
"""Create a synthetic Dimension (Bronze B) DataFrame."""
|
| 85 |
+
if match_keys:
|
| 86 |
+
# Choose from overlap pool to control referential integrity
|
| 87 |
+
available = list(match_keys)
|
| 88 |
+
keys = [rng.choice(available) for _ in range(n_rows)]
|
| 89 |
+
else:
|
| 90 |
+
keys = [f"K{i:04d}" for i in rng.sample(range(1, n_rows * 3), n_rows)]
|
| 91 |
+
|
| 92 |
+
# Inject duplicates (lower uniqueness)
|
| 93 |
+
if uniqueness < 1.0:
|
| 94 |
+
n_dupes = int(n_rows * (1 - uniqueness))
|
| 95 |
+
for i in rng.sample(range(n_rows), n_dupes):
|
| 96 |
+
keys[i] = keys[rng.randint(0, i - 1)] if i > 0 else keys[0]
|
| 97 |
+
|
| 98 |
+
# Inject nulls
|
| 99 |
+
null_mask = rng.sample(range(n_rows), int(n_rows * null_ratio))
|
| 100 |
+
for idx in null_mask:
|
| 101 |
+
keys[idx] = None # type: ignore[call-overload]
|
| 102 |
+
|
| 103 |
+
data: dict = {key_col: keys, "dim_name": [f"Name_{k}" for k in keys]}
|
| 104 |
+
for col in (tracked_cols or []):
|
| 105 |
+
data[col] = [rng.choice(["x", "y", "z"]) for _ in range(n_rows)]
|
| 106 |
+
for col in (extra_cols or []):
|
| 107 |
+
data[col] = [rng.uniform(0, 100) for _ in range(n_rows)]
|
| 108 |
+
|
| 109 |
+
return pd.DataFrame(data)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# ---------------------------------------------------------------------------
|
| 113 |
+
# Scenario Generator
|
| 114 |
+
# ---------------------------------------------------------------------------
|
| 115 |
+
|
| 116 |
+
class ScenarioGenerator:
    """Generates Bronze A/B DataFrames for MEDUSA episodes.

    Seeds 0-3 map one-to-one onto the four canonical scenario types in
    ``CANONICAL``; any other seed (or ``None``) selects a canonical variant
    at random (deterministically for a fixed non-None seed).
    """

    STALE_THRESHOLD = _STALE_THRESHOLD_HOURS
    JOIN_KEY = "entity_id"
    TRACKED_COLS = ["dim_status"]

    # Four canonical scenario types
    CANONICAL: List[str] = ["clean", "dirty_keys", "stale", "schema_drift"]

    def __init__(self, n_fact_rows: int = 200, n_dim_rows: int = 150):
        self.n_fact_rows = n_fact_rows
        self.n_dim_rows = n_dim_rows

    def generate(self, seed: Optional[int] = None) -> Scenario:
        """Generate a random scenario. Canonical scenarios cycle through seeds 0-3."""
        if seed is not None and 0 <= seed < len(self.CANONICAL):
            return self._canonical(self.CANONICAL[seed], seed)
        # Only build an RNG when we actually need a random variant pick.
        variant = random.Random(seed).choice(self.CANONICAL)
        return self._canonical(variant, seed)

    def _canonical(self, variant: str, seed: Optional[int]) -> Scenario:
        """Build the named canonical scenario (deterministic for a fixed seed)."""
        rng = random.Random(seed)
        key = self.JOIN_KEY
        n_a = self.n_fact_rows
        n_b = self.n_dim_rows

        if variant == "clean":
            # Fresh, unique keys, ~100% match rate
            fact = _make_fact(rng, n_a, key, null_ratio=0.0)
            valid_keys = fact[key].dropna().tolist()
            dim = _make_dim(rng, n_b, key, null_ratio=0.0, uniqueness=1.0,
                            match_keys=valid_keys, tracked_cols=self.TRACKED_COLS)
            return Scenario(
                id=f"clean_{seed}",
                bronze_a=fact, bronze_b=dim,
                join_key=key, tracked_cols=self.TRACKED_COLS,
                is_stale_a=False, is_stale_b=False,
                time_delta_a=1.0, time_delta_b=2.0,
                new_cols_a=[], new_cols_b=[],
                description="Clean scenario: fresh, unique keys, high match rate.",
            )

        elif variant == "dirty_keys":
            # High null ratio in keys, no trimming / type-casting yet
            fact = _make_fact(rng, n_a, key, null_ratio=0.25)
            fact[key] = fact[key].apply(
                lambda k: f" {k} " if k and rng.random() < 0.3 else k  # whitespace noise
            )
            dim = _make_dim(rng, n_b, key, null_ratio=0.15, uniqueness=0.85,
                            tracked_cols=self.TRACKED_COLS)
            return Scenario(
                id=f"dirty_keys_{seed}",
                bronze_a=fact, bronze_b=dim,
                join_key=key, tracked_cols=self.TRACKED_COLS,
                is_stale_a=False, is_stale_b=False,
                time_delta_a=2.0, time_delta_b=3.0,
                new_cols_a=[], new_cols_b=[],
                description="Dirty keys: nulls + whitespace in join keys.",
            )

        elif variant == "stale":
            # One or both sources have not refreshed recently
            fact = _make_fact(rng, n_a, key, null_ratio=0.0)
            valid_keys = fact[key].dropna().tolist()
            dim = _make_dim(rng, n_b, key, null_ratio=0.0, uniqueness=1.0,
                            match_keys=valid_keys, tracked_cols=self.TRACKED_COLS)
            td_a = rng.uniform(8.0, 24.0)  # definitely stale (> STALE_THRESHOLD)
            td_b = rng.uniform(0.5, 4.0)
            return Scenario(
                id=f"stale_{seed}",
                bronze_a=fact, bronze_b=dim,
                join_key=key, tracked_cols=self.TRACKED_COLS,
                is_stale_a=td_a > self.STALE_THRESHOLD,
                is_stale_b=td_b > self.STALE_THRESHOLD,
                time_delta_a=td_a, time_delta_b=td_b,
                new_cols_a=[], new_cols_b=[],
                description=f"Stale scenario: Source A is {td_a:.1f}h old.",
            )

        else:  # schema_drift
            # New columns in A and/or B not yet registered in Silver
            extra_a = ["new_metric_a"]
            extra_b = ["new_attr_b"]
            fact = _make_fact(rng, n_a, key, null_ratio=0.0, extra_cols=extra_a)
            valid_keys = fact[key].dropna().tolist()
            dim = _make_dim(rng, n_b, key, null_ratio=0.0, uniqueness=1.0,
                            match_keys=valid_keys,
                            tracked_cols=self.TRACKED_COLS, extra_cols=extra_b)
            return Scenario(
                id=f"schema_drift_{seed}",
                bronze_a=fact, bronze_b=dim,
                join_key=key, tracked_cols=self.TRACKED_COLS,
                is_stale_a=False, is_stale_b=False,
                time_delta_a=1.0, time_delta_b=1.5,
                new_cols_a=extra_a, new_cols_b=extra_b,
                description="Schema drift: new columns in A and B.",
            )
|
scripts/inference.py
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""MEDUSA inference script β OpenEnv Hackathon submission.
|
| 2 |
+
|
| 3 |
+
Runs an LLM agent (via OpenAI-compatible API) against all three MEDUSA tasks
|
| 4 |
+
and reports per-task scores (0.0β1.0).
|
| 5 |
+
|
| 6 |
+
Required environment variables:
|
| 7 |
+
API_BASE_URL The API endpoint for the LLM (OpenAI-compatible).
|
| 8 |
+
MODEL_NAME The model identifier to use for inference.
|
| 9 |
+
HF_TOKEN Your Hugging Face / API key (used as the API key).
|
| 10 |
+
|
| 11 |
+
Usage:
|
| 12 |
+
export API_BASE_URL="https://api.openai.com/v1"
|
| 13 |
+
export MODEL_NAME="gpt-4o-mini"
|
| 14 |
+
export HF_TOKEN="hf-..."
|
| 15 |
+
python inference.py
|
| 16 |
+
|
| 17 |
+
Output:
|
| 18 |
+
Prints per-task results and a final summary table to stdout.
|
| 19 |
+
Exits with code 0 if all tasks score >= 0.35, else 1.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
from __future__ import annotations
|
| 23 |
+
|
| 24 |
+
import json
|
| 25 |
+
import os
|
| 26 |
+
import sys
|
| 27 |
+
import textwrap
|
| 28 |
+
import time
|
| 29 |
+
from typing import List, Optional
|
| 30 |
+
|
| 31 |
+
# ---------------------------------------------------------------------------
|
| 32 |
+
# Validate required environment variables before anything else
|
| 33 |
+
# ---------------------------------------------------------------------------
|
| 34 |
+
|
| 35 |
+
# Read connection settings once at import time.  All three are mandatory;
# empty values are treated the same as unset.
API_BASE_URL = os.environ.get("API_BASE_URL", "").rstrip("/")  # normalise: no trailing slash
MODEL_NAME = os.environ.get("MODEL_NAME", "")
HF_TOKEN = os.environ.get("HF_TOKEN", "")

# Names of any variables that are unset or empty.
_missing = [k for k, v in {
    "API_BASE_URL": API_BASE_URL,
    "MODEL_NAME": MODEL_NAME,
    "HF_TOKEN": HF_TOKEN,
}.items() if not v]

if _missing:
    # Fail fast with actionable instructions rather than a confusing API error later.
    print(f"ERROR: Missing required environment variables: {', '.join(_missing)}", file=sys.stderr)
    print("Set them before running:", file=sys.stderr)
    for k in _missing:
        print(f"  export {k}=<value>", file=sys.stderr)
    sys.exit(1)
|
| 51 |
+
|
| 52 |
+
# ---------------------------------------------------------------------------
|
| 53 |
+
# OpenAI client (uses API_BASE_URL + HF_TOKEN as the key)
|
| 54 |
+
# ---------------------------------------------------------------------------
|
| 55 |
+
|
| 56 |
+
from openai import OpenAI # noqa: E402
|
| 57 |
+
|
| 58 |
+
client = OpenAI(
|
| 59 |
+
base_url=API_BASE_URL,
|
| 60 |
+
api_key=HF_TOKEN,
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
# ---------------------------------------------------------------------------
|
| 64 |
+
# MEDUSA environment imports
|
| 65 |
+
# ---------------------------------------------------------------------------
|
| 66 |
+
|
| 67 |
+
from pathlib import Path
|
| 68 |
+
|
| 69 |
+
# Dynamically add the OpenEnv repo root to sys.path so absolute imports work
|
| 70 |
+
# no matter where this script is executed from.
|
| 71 |
+
repo_root = str(Path(__file__).resolve().parent.parent.parent)
|
| 72 |
+
if repo_root not in sys.path:
|
| 73 |
+
sys.path.insert(0, repo_root)
|
| 74 |
+
|
| 75 |
+
try:
|
| 76 |
+
# In-repo
|
| 77 |
+
from envs.medusa_env import MedusaEnv
|
| 78 |
+
from envs.medusa_env.models import MedusaAction, MedusaActionType
|
| 79 |
+
from envs.medusa_env.tasks import TASKS, TaskResult, score_episode
|
| 80 |
+
except ImportError:
|
| 81 |
+
# Standalone (running from inside envs/medusa_env/ installation)
|
| 82 |
+
from medusa_env import MedusaEnv # type: ignore
|
| 83 |
+
from models import MedusaAction, MedusaActionType # type: ignore
|
| 84 |
+
from tasks import TASKS, TaskResult, score_episode # type: ignore
|
| 85 |
+
|
| 86 |
+
# ---------------------------------------------------------------------------
|
| 87 |
+
# System prompt
|
| 88 |
+
# ---------------------------------------------------------------------------
|
| 89 |
+
|
| 90 |
+
SYSTEM_PROMPT = textwrap.dedent("""
|
| 91 |
+
You are a data integration agent controlling a BronzeβSilver ETL pipeline.
|
| 92 |
+
|
| 93 |
+
You observe a 16-float feature vector describing data quality signals, and
|
| 94 |
+
you must choose one action per step from the list below.
|
| 95 |
+
|
| 96 |
+
ACTIONS (respond with ONLY the action name β nothing else):
|
| 97 |
+
SYNC_CHECK β Verify source freshness before processing
|
| 98 |
+
EVOLVE_SCHEMA β Add new columns from sources into Silver schema
|
| 99 |
+
PREP_KEYS_A β Clean and normalise join keys in Source A (Fact)
|
| 100 |
+
PREP_KEYS_B β Clean and normalise join keys in Source B (Dimension)
|
| 101 |
+
DEDUPLICATE_B β Remove duplicate keys from Source B
|
| 102 |
+
EXECUTE_JOIN_INNER β Inner join A β B
|
| 103 |
+
EXECUTE_JOIN_LEFT β Left join A β B (keeps all Fact rows; orphans β quarantine)
|
| 104 |
+
EXECUTE_JOIN_ANTI β Anti-join: extract Fact rows with no Dimension match
|
| 105 |
+
APPLY_SCD_1 β Overwrite Silver records (SCD Type 1)
|
| 106 |
+
APPLY_SCD_2 β Close old records and insert new with timestamps (SCD Type 2)
|
| 107 |
+
COMMIT β Finalise pipeline and trigger audit
|
| 108 |
+
|
| 109 |
+
STRATEGY:
|
| 110 |
+
1. Always call SYNC_CHECK first to verify freshness.
|
| 111 |
+
2. If schema drift signals are non-zero (features[9] or [10] > 0), call EVOLVE_SCHEMA.
|
| 112 |
+
3. If null key ratios (features[4] or [5] > 0), call PREP_KEYS_A and/or PREP_KEYS_B.
|
| 113 |
+
4. If Dimension uniqueness (features[7]) < 1.0, call DEDUPLICATE_B.
|
| 114 |
+
5. Prefer EXECUTE_JOIN_LEFT to preserve all Fact rows.
|
| 115 |
+
6. Prefer APPLY_SCD_2 for tracked history.
|
| 116 |
+
7. Call COMMIT when pipeline is complete.
|
| 117 |
+
|
| 118 |
+
The feature vector indices:
|
| 119 |
+
[0] time_delta_a_norm [1] time_delta_b_norm
|
| 120 |
+
[2] is_stale_a [3] is_stale_b
|
| 121 |
+
[4] null_ratio_key_a [5] null_ratio_key_b
|
| 122 |
+
[6] uniqueness_a [7] uniqueness_b
|
| 123 |
+
[8] match_rate [9] new_cols_a_norm
|
| 124 |
+
[10] new_cols_b_norm [11] schema_compat
|
| 125 |
+
[12] did_prep_a [13] did_prep_b
|
| 126 |
+
[14] did_dedup_b [15] step_frac
|
| 127 |
+
""").strip()
|
| 128 |
+
|
| 129 |
+
# ---------------------------------------------------------------------------
|
| 130 |
+
# LLM action chooser
|
| 131 |
+
# ---------------------------------------------------------------------------
|
| 132 |
+
|
| 133 |
+
VALID_ACTIONS = {a.value for a in MedusaActionType}
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def choose_action(
    features: List[float],
    history: List[dict],
    step: int,
) -> str:
    """Ask the LLM to choose the next action given the current observation.

    Args:
        features: The 16-float normalised observation vector.
        history: Prior turns as dicts with "user" and "assistant" keys;
            only the last 4 are sent to keep the prompt short.
        step: 1-based step counter, included in the prompt.

    Returns:
        A valid ``MedusaActionType`` value string.  Falls back to COMMIT
        when the model's reply matches no known action.
    """
    feature_str = ", ".join(f"{v:.3f}" for v in features)
    user_msg = (
        f"Step {step}. Feature vector: [{feature_str}]\n"
        "What is the single best next action? Respond with ONLY the action name."
    )

    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    # Include the last 4 steps of history for context (keep prompt short)
    for h in history[-4:]:
        messages.append({"role": "user", "content": h["user"]})
        messages.append({"role": "assistant", "content": h["assistant"]})
    messages.append({"role": "user", "content": user_msg})

    response = client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        max_tokens=20,
        temperature=0.0,
    )
    # content may be None (empty/refused completion) — guard before .strip().
    raw = (response.choices[0].message.content or "").strip().upper().replace(" ", "_")

    # Fuzzy match, longest names first: a set has no stable iteration order,
    # so matching unsorted could return an arbitrary action when the reply
    # mentions more than one.  Longest-first also prevents a shorter name
    # from claiming a reply that spells out a longer one.
    for action in sorted(VALID_ACTIONS, key=len, reverse=True):
        if action in raw:
            return action

    # Looser fallback: compare with underscores stripped
    compact = raw.replace("_", "")
    for action in sorted(VALID_ACTIONS, key=len, reverse=True):
        if action.replace("_", "") in compact:
            return action

    # Hard fallback: commit to end gracefully
    return MedusaActionType.COMMIT.value
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
# ---------------------------------------------------------------------------
|
| 178 |
+
# Run one task
|
| 179 |
+
# ---------------------------------------------------------------------------
|
| 180 |
+
|
| 181 |
+
def run_task(task_id: str, max_steps: int = 15) -> TaskResult:
    """Run the LLM agent for one MEDUSA task. Returns the TaskResult.

    Args:
        task_id: Key into the ``TASKS`` registry.
        max_steps: Hard cap on agent steps for this episode (also passed to
            the environment as its own limit).
    """
    task = TASKS[task_id]
    print(f"\n{'='*60}")
    print(f"TASK: {task.name} [{task.difficulty.upper()}] (seed={task.seed})")
    print(f" {task.description}")
    print(f"{'='*60}")

    # Fresh environment per task; seeding reset makes the episode deterministic.
    env = MedusaEnv(n_fact_rows=200, n_dim_rows=150, max_steps=max_steps)
    obs = env.reset(seed=task.seed)

    history: List[dict] = []
    step = 0
    t0 = time.time()

    # Agent loop: ask the LLM for one action at a time until done or capped.
    while not obs.done and step < max_steps:
        step += 1
        action_str = choose_action(obs.features, history, step)
        action_type = MedusaActionType(action_str)
        action = MedusaAction(action=action_type)

        obs = env.step(action)
        reward = obs.reward or 0.0  # reward may be None before first scoring

        print(f" Step {step:2d}: {action_str:25s} reward={reward:+7.2f} "
              f"cumulative={env.state.cumulative_reward:+8.2f}")

        # Record the exchange so choose_action can replay recent context.
        # Note: features shown here are the POST-step observation.
        history.append({
            "user": (f"Step {step}. Features: [{', '.join(f'{v:.3f}' for v in obs.features)}]"
                     " What action?"),
            "assistant": action_str,
        })

    elapsed = time.time() - t0
    # NOTE(review): reaches into the private ``_tables`` attribute of the env;
    # the scorer needs the episode tables — confirm no public accessor exists.
    result = score_episode(task_id, env.state, env._tables)

    print(f"\n β Score: {result.score:.4f} Grade: {result.grade} "
          f"Passed: {result.passed} ({elapsed:.1f}s)")
    if result.notes:
        for note in result.notes:
            print(f" β {note}")
    print(f" β Breakdown: " +
          ", ".join(f"{k}={v:.2f}" for k, v in result.breakdown.items()))
    return result
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
# ---------------------------------------------------------------------------
|
| 228 |
+
# Main
|
| 229 |
+
# ---------------------------------------------------------------------------
|
| 230 |
+
|
| 231 |
+
def main() -> None:
    """Run the agent on all three benchmark tasks and report results.

    Prints a human-readable per-task log and summary table, then a
    machine-readable JSON blob, and exits 0 only if every task passed.
    """
    print("MEDUSA β Baseline Inference")
    print(f"Model: {MODEL_NAME}")
    print(f"API: {API_BASE_URL}")
    print()

    # Fixed benchmark suite, run in increasing difficulty order.
    task_ids = ["clean_pipeline", "dirty_integration", "full_medallion"]
    results: dict[str, TaskResult] = {}
    total_start = time.time()

    for task_id in task_ids:
        result = run_task(task_id)
        results[task_id] = result

    total_elapsed = time.time() - total_start

    # Summary
    print(f"\n{'='*60}")
    print("SUMMARY")
    print(f"{'='*60}")
    print(f"{'Task':<25} {'Difficulty':<8} {'Score':>6} {'Grade':>5} {'Pass?':>5}")
    print("-" * 60)
    all_passed = True
    for task_id, result in results.items():
        task = TASKS[task_id]
        print(f"{task.name:<25} {task.difficulty:<8} "
              f"{result.score:>6.4f} {result.grade:>5} {'YES' if result.passed else 'NO':>5}")
        if not result.passed:
            all_passed = False

    print("-" * 60)
    avg = sum(r.score for r in results.values()) / len(results)
    print(f"{'Average':<25} {'':8} {avg:>6.4f}")
    print(f"\nTotal time: {total_elapsed:.1f}s")

    # Machine-readable output for the evaluator
    output = {
        "model": MODEL_NAME,
        "tasks": {
            tid: {
                "score": r.score,
                "grade": r.grade,
                "passed": r.passed,
                "breakdown": r.breakdown,
            }
            for tid, r in results.items()
        },
        "average_score": avg,
        "all_passed": all_passed,
    }
    print("\n--- JSON RESULTS ---")
    print(json.dumps(output, indent=2))

    # Exit status doubles as the pass/fail signal for CI-style runners.
    sys.exit(0 if all_passed else 1)


if __name__ == "__main__":
    main()
|
server/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI server package for medusa_env."""
|
| 2 |
+
from .medusa_env import MedusaEnv
|
| 3 |
+
|
| 4 |
+
__all__ = [
|
| 5 |
+
"MedusaEnv"
|
| 6 |
+
]
|
server/app.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI server for the MEDUSA environment.
|
| 2 |
+
|
| 3 |
+
Usage:
|
| 4 |
+
# Development:
|
| 5 |
+
uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
|
| 6 |
+
|
| 7 |
+
# Via openenv CLI:
|
| 8 |
+
openenv serve medusa_env
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
# NOTE(review): an earlier comment here claimed three import contexts were
# supported (in-repo relative, installed package, bare module), but only the
# installed-package form (``medusa_env.*``) is actually imported below —
# running this module requires ``medusa_env`` to be importable.
from openenv.core.env_server.http_server import create_app
from medusa_env.server import MedusaEnv
from medusa_env.models import MedusaAction, MedusaObservation

# ASGI app wiring the MEDUSA environment into the OpenEnv HTTP server;
# served either by ``uvicorn server.app:app`` or via main() below.
app = create_app(
    MedusaEnv,
    MedusaAction,
    MedusaObservation,
    env_name="medusa_env",
)


def main() -> None:
    """Entry point for direct execution: serve the app on 0.0.0.0:8000."""
    # Local import keeps uvicorn optional for callers that only want ``app``.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)


if __name__ == "__main__":
    main()
|
server/medusa_env.py
ADDED
|
@@ -0,0 +1,514 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""MEDUSA β full environment implementation.
|
| 2 |
+
|
| 3 |
+
Replaces the Phase-1 skeleton with a complete reset/step pipeline that:
|
| 4 |
+
β’ Generates Bronze A/B data from ``ScenarioGenerator``
|
| 5 |
+
β’ Dispatches each action to the appropriate operator
|
| 6 |
+
β’ Computes per-step rewards via ``RewardEngine``
|
| 7 |
+
β’ Runs the deterministic grader on COMMIT
|
| 8 |
+
β’ Builds a 16-float normalized feature vector for the RL agent
|
| 9 |
+
β’ Maintains a governance log of every decision
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import copy
|
| 15 |
+
import time
|
| 16 |
+
import uuid
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from typing import Any, Dict, List, Optional
|
| 19 |
+
|
| 20 |
+
import pandas as pd
|
| 21 |
+
|
| 22 |
+
from openenv.core.env_server.interfaces import Environment
|
| 23 |
+
from openenv.core.env_server.types import EnvironmentMetadata
|
| 24 |
+
|
| 25 |
+
from medusa_env.grader import Grader
|
| 26 |
+
from medusa_env.models import MedusaAction, MedusaActionType, MedusaObservation, MedusaState
|
| 27 |
+
from medusa_env.operators import (
|
| 28 |
+
apply_scd,
|
| 29 |
+
deduplicate,
|
| 30 |
+
evolve_schema,
|
| 31 |
+
execute_join,
|
| 32 |
+
prep_keys,
|
| 33 |
+
sync_check,
|
| 34 |
+
)
|
| 35 |
+
from medusa_env.rewards import RewardEngine
|
| 36 |
+
from medusa_env.scenarios import Scenario, ScenarioGenerator
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# ---------------------------------------------------------------------------
|
| 40 |
+
# Internal episode tables
|
| 41 |
+
# ---------------------------------------------------------------------------
|
| 42 |
+
|
| 43 |
+
@dataclass
class _EpisodeTables:
    """In-memory tables for one episode.

    Holds every stage of the Bronze-to-Silver pipeline plus the governance
    audit trail; a fresh instance is created per episode.
    """

    bronze_a: pd.DataFrame = field(default_factory=pd.DataFrame)          # raw Fact source
    bronze_a_prepped: pd.DataFrame = field(default_factory=pd.DataFrame)  # Fact after key prep
    bronze_b: pd.DataFrame = field(default_factory=pd.DataFrame)          # raw Dimension source
    bronze_b_prepped: pd.DataFrame = field(default_factory=pd.DataFrame)  # Dimension after prep/dedup
    joined: pd.DataFrame = field(default_factory=pd.DataFrame)            # result of the A-B join
    silver: pd.DataFrame = field(default_factory=pd.DataFrame)            # final Silver entity
    quarantine: pd.DataFrame = field(default_factory=pd.DataFrame)        # rows diverted from the join
    governance_log: List[Dict[str, Any]] = field(default_factory=list)    # audit trail of decisions
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# ---------------------------------------------------------------------------
|
| 58 |
+
# Feature vector builder
|
| 59 |
+
# ---------------------------------------------------------------------------
|
| 60 |
+
|
| 61 |
+
_MAX_TIME_DELTA = 48.0 # Normalisation ceiling (hours)
|
| 62 |
+
_MAX_COLS = 10.0 # Normalisation ceiling (new columns)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _build_features(state: MedusaState) -> List[float]:
    """Build the 16-float normalised observation vector.

    All entries are clamped to [0, 1].  Index map:
      [0]  time_delta_a / 48h    [1]  time_delta_b / 48h
      [2]  is_stale_a            [3]  is_stale_b
      [4]  null_ratio_key_a      [5]  null_ratio_key_b
      [6]  uniqueness_a          [7]  uniqueness_b
      [8]  match_rate            [9]  new_cols_a / 10
      [10] new_cols_b / 10       [11] schema_compat
      [12] did_prep_a            [13] did_prep_b
      [14] did_dedup_b           [15] step_idx / max_steps
    """
    return [
        # Source ages, normalised against the 48-hour ceiling.
        min(state.time_delta_a / _MAX_TIME_DELTA, 1.0),
        min(state.time_delta_b / _MAX_TIME_DELTA, 1.0),
        float(state.is_stale_a),
        float(state.is_stale_b),
        state.null_ratio_key_a,
        state.null_ratio_key_b,
        state.uniqueness_a,
        state.uniqueness_b,
        state.match_rate,
        # Schema-drift signals, normalised against the 10-column ceiling.
        min(state.new_cols_a / _MAX_COLS, 1.0),
        min(state.new_cols_b / _MAX_COLS, 1.0),
        state.schema_compat,
        # Flags for pipeline actions already taken this episode.
        float(state.did_prep_a),
        float(state.did_prep_b),
        float(state.did_dedup_b),
        # Episode progress; max(..., 1) guards against max_steps == 0.
        min(state.step_idx / max(state.max_steps, 1), 1.0),
    ]
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# ---------------------------------------------------------------------------
|
| 88 |
+
# Main environment
|
| 89 |
+
# ---------------------------------------------------------------------------
|
| 90 |
+
|
| 91 |
+
class MedusaEnv(Environment[MedusaAction, MedusaObservation, MedusaState]):
    """MEDUSA: Medallion-Engineered Deterministic Unified Storage Agent.

    Simulates a Bronze-to-Silver data integration pipeline. The agent observes
    data quality signals and chooses ETL actions to produce a correct,
    historically consistent Silver entity.

    Args:
        scenario_seed: Fixed seed for deterministic episodes. ``None`` = random.
        max_steps: Maximum steps per episode before forced termination.
        stale_threshold_hours: Age (hours) at which a source is deemed stale.
        n_fact_rows: Size of the Fact / Source A table.
        n_dim_rows: Size of the Dimension / Source B table.
    """

    # Multiple client sessions may run against one server process.
    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(
        self,
        scenario_seed: Optional[int] = None,
        max_steps: int = 20,
        stale_threshold_hours: float = 6.0,
        n_fact_rows: int = 200,
        n_dim_rows: int = 150,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self._scenario_seed = scenario_seed
        self._max_steps = max_steps
        self._stale_threshold = stale_threshold_hours

        # Collaborators: scenario generator (episode data), per-step reward
        # engine, and the commit-time grader.
        self._generator = ScenarioGenerator(
            n_fact_rows=n_fact_rows, n_dim_rows=n_dim_rows
        )
        self._reward_engine = RewardEngine()
        self._grader = Grader()

        # Episode-scoped mutable state; fully re-created in reset().
        self._state = MedusaState()
        self._tables = _EpisodeTables()
        self._scenario: Optional[Scenario] = None

    # ------------------------------------------------------------------
    # Metadata
    # ------------------------------------------------------------------

    def get_metadata(self) -> EnvironmentMetadata:
        """Return static environment metadata (name, description, version)."""
        return EnvironmentMetadata(
            name="medusa_env",
            description=(
                "MEDUSA: simulated BronzeβSilver integration controller for "
                "multi-source joins, schema drift, and SCD merges."
            ),
            version="0.2.0",
            documentation="envs/medusa_env/README.md",
        )

    # ------------------------------------------------------------------
    # State
    # ------------------------------------------------------------------

    @property
    def state(self) -> MedusaState:
        """Current episode state (live object, mutated by step())."""
        return self._state

    # ------------------------------------------------------------------
    # Reset
    # ------------------------------------------------------------------

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> MedusaObservation:
        """Start a new episode: generate a scenario and seed initial metrics.

        Args:
            seed: Per-call seed; falls back to the constructor's
                ``scenario_seed`` when omitted.
            episode_id: Optional external identifier; a fresh UUID is minted
                when absent.

        Returns:
            The initial observation for the new episode (``done=False``).
        """
        self._reset_rubric()

        effective_seed = seed if seed is not None else self._scenario_seed
        run_id = episode_id or str(uuid.uuid4())

        # Generate scenario
        self._scenario = self._generator.generate(seed=effective_seed)
        scen = self._scenario

        # Initialise tables. The "prepped" copies start identical to the raw
        # Bronze tables and are mutated by PREP_KEYS / DEDUPLICATE actions.
        self._tables = _EpisodeTables(
            bronze_a=scen.bronze_a.copy(),
            bronze_a_prepped=scen.bronze_a.copy(),
            bronze_b=scen.bronze_b.copy(),
            bronze_b_prepped=scen.bronze_b.copy(),
        )

        # Compute initial key health metrics from raw Bronze
        # (max(..., 1) guards against empty tables).
        na_a = scen.bronze_a[scen.join_key].isna().sum()
        na_b = scen.bronze_b[scen.join_key].isna().sum()
        null_ratio_a = na_a / max(len(scen.bronze_a), 1)
        null_ratio_b = na_b / max(len(scen.bronze_b), 1)

        # Uniqueness of raw keys: distinct share among non-null keys.
        nna_a = scen.bronze_a[scen.join_key].dropna()
        nna_b = scen.bronze_b[scen.join_key].dropna()
        uniq_a = nna_a.nunique() / max(len(nna_a), 1)
        uniq_b = nna_b.nunique() / max(len(nna_b), 1)

        # Match rate on raw keys: fraction of A's keys also present in B.
        # Keys are stringified so mixed dtypes compare consistently.
        keys_a = set(nna_a.astype(str))
        keys_b = set(nna_b.astype(str))
        match_rate = len(keys_a & keys_b) / max(len(keys_a), 1)

        self._state = MedusaState(
            run_id=run_id,
            seed=effective_seed,
            scenario_id=scen.id,
            max_steps=self._max_steps,
            step_idx=0,
            stage="running",
            time_delta_a=scen.time_delta_a,
            time_delta_b=scen.time_delta_b,
            is_stale_a=scen.is_stale_a,
            is_stale_b=scen.is_stale_b,
            null_ratio_key_a=float(null_ratio_a),
            null_ratio_key_b=float(null_ratio_b),
            uniqueness_a=float(uniq_a),
            uniqueness_b=float(uniq_b),
            match_rate=float(match_rate),
            new_cols_a=len(scen.new_cols_a),
            new_cols_b=len(scen.new_cols_b),
            source_a_row_count=len(scen.bronze_a),
        )

        features = _build_features(self._state)
        obs = MedusaObservation(
            message=(
                f"MEDUSA episode started. Scenario: {scen.id}. "
                f"{scen.description} "
                f"Source A: {len(scen.bronze_a)} rows | "
                f"Source B: {len(scen.bronze_b)} rows."
            ),
            features=features,
            metrics={
                "scenario_id": scen.id,
                "null_ratio_key_a": null_ratio_a,
                "null_ratio_key_b": null_ratio_b,
                "match_rate": match_rate,
                "is_stale_a": scen.is_stale_a,
                "is_stale_b": scen.is_stale_b,
                "new_cols_a": scen.new_cols_a,
                "new_cols_b": scen.new_cols_b,
            },
            metadata={"run_id": run_id, "seed": effective_seed},
            reward=None,
            done=False,
        )
        return self._apply_transform(obs)

    # ------------------------------------------------------------------
    # Step
    # ------------------------------------------------------------------

    def step(
        self,
        action: MedusaAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> MedusaObservation:
        """Apply one ETL action, score it, and return the next observation.

        Dispatches on ``action.action``; COMMIT short-circuits into
        :meth:`_do_commit`. Operator errors are caught and reported in the
        observation rather than raised, so a bad action does not kill the
        episode.
        """
        # Guard: refuse to act once the episode has ended.
        if self._state.stage != "running":
            return self._apply_transform(MedusaObservation(
                message=f"Episode not running (stage={self._state.stage}). Call reset().",
                done=True,
                reward=0.0,
                features=_build_features(self._state),
                metadata={"run_id": self._state.run_id},
            ))

        # Snapshot state *before* applying action (for reward evaluation).
        # NOTE(review): copy.copy is shallow — assumes RewardEngine only reads
        # scalar fields of the snapshot; confirm if nested fields are compared.
        state_before = copy.copy(self._state)
        self._state.step_idx += 1

        action_type = action.action
        metrics: dict = {}
        step_message = ""

        scen = self._scenario
        assert scen is not None, "reset() must be called before step()"

        # -- Dispatch ---------------------------------------------------
        try:
            if action_type == MedusaActionType.SYNC_CHECK:
                # Read-only freshness probe; result DataFrame is discarded.
                _, metrics = sync_check(
                    self._tables.bronze_a,
                    self._tables.bronze_b,
                    scen.time_delta_a,
                    scen.time_delta_b,
                    self._stale_threshold,
                )
                self._state.did_sync_check = True
                step_message = (
                    f"SYNC_CHECK: A={scen.time_delta_a:.1f}h "
                    f"{'[STALE]' if scen.is_stale_a else '[FRESH]'} | "
                    f"B={scen.time_delta_b:.1f}h "
                    f"{'[STALE]' if scen.is_stale_b else '[FRESH]'}"
                )

            elif action_type == MedusaActionType.EVOLVE_SCHEMA:
                # Merge drifted columns from both Bronze sources into Silver.
                result_df, metrics = evolve_schema(
                    self._tables.silver,
                    self._tables.bronze_a,
                    self._tables.bronze_b,
                    scen.new_cols_a,
                    scen.new_cols_b,
                )
                if result_df is not None:
                    self._tables.silver = result_df
                self._state.did_evolve_schema = True
                step_message = f"EVOLVE_SCHEMA: added {metrics.get('new_cols_count', 0)} column(s)."

            elif action_type == MedusaActionType.PREP_KEYS_A:
                # Clean join keys (NULLs/whitespace) on the prepped copy of A.
                result_df, metrics = prep_keys(
                    self._tables.bronze_a_prepped, scen.join_key
                )
                if result_df is not None:
                    self._tables.bronze_a_prepped = result_df
                self._state.did_prep_a = True
                self._state.null_ratio_key_a = float(metrics.get("null_ratio_after", 0.0))
                step_message = (
                    f"PREP_KEYS_A: null ratio {metrics.get('null_ratio_before', 0):.2%}"
                    f"β{metrics.get('null_ratio_after', 0):.2%}."
                )

            elif action_type == MedusaActionType.PREP_KEYS_B:
                # Same key-cleaning pass for the prepped copy of B.
                result_df, metrics = prep_keys(
                    self._tables.bronze_b_prepped, scen.join_key
                )
                if result_df is not None:
                    self._tables.bronze_b_prepped = result_df
                self._state.did_prep_b = True
                self._state.null_ratio_key_b = float(metrics.get("null_ratio_after", 0.0))
                step_message = (
                    f"PREP_KEYS_B: null ratio {metrics.get('null_ratio_before', 0):.2%}"
                    f"β{metrics.get('null_ratio_after', 0):.2%}."
                )

            elif action_type == MedusaActionType.DEDUPLICATE_B:
                # Collapse duplicate keys in B to prevent join explosion.
                result_df, metrics = deduplicate(
                    self._tables.bronze_b_prepped, scen.join_key
                )
                if result_df is not None:
                    self._tables.bronze_b_prepped = result_df
                self._state.did_dedup_b = True
                self._state.uniqueness_b = float(metrics.get("uniqueness", 1.0))
                step_message = f"DEDUPLICATE_B: removed {metrics.get('dupes_removed', 0)} duplicate(s)."

            elif action_type in {
                MedusaActionType.EXECUTE_JOIN_INNER,
                MedusaActionType.EXECUTE_JOIN_LEFT,
                MedusaActionType.EXECUTE_JOIN_ANTI,
            }:
                # Map the action enum to the operator's join-type string.
                join_map = {
                    MedusaActionType.EXECUTE_JOIN_INNER: "inner",
                    MedusaActionType.EXECUTE_JOIN_LEFT: "left",
                    MedusaActionType.EXECUTE_JOIN_ANTI: "anti",
                }
                join_type_str = join_map[action_type]
                joined, quarantine, metrics = execute_join(
                    self._tables.bronze_a_prepped,
                    self._tables.bronze_b_prepped,
                    scen.join_key,
                    join_type_str,
                )
                self._tables.joined = joined
                self._tables.quarantine = quarantine
                self._state.did_join = True
                self._state.join_type = join_type_str
                self._state.join_row_count = int(metrics.get("join_rows", 0))
                self._state.explosion_detected = bool(metrics.get("explosion_detected", False))
                self._state.match_rate = float(metrics.get("match_rate", 0.0))
                self._state.quarantine_row_count = len(quarantine)
                step_message = (
                    f"EXECUTE_JOIN ({join_type_str.upper()}): "
                    f"{self._state.join_row_count} rows | "
                    f"match_rate={self._state.match_rate:.1%} | "
                    f"quarantine={self._state.quarantine_row_count} | "
                    f"{'β EXPLOSION' if self._state.explosion_detected else 'OK'}"
                )

            elif action_type in {MedusaActionType.APPLY_SCD_1, MedusaActionType.APPLY_SCD_2}:
                # Merge the joined table into Silver with SCD-1 (overwrite)
                # or SCD-2 (history-preserving) semantics.
                scd_type_int = 1 if action_type == MedusaActionType.APPLY_SCD_1 else 2
                # Fall back to the join key if the scenario tracks no columns.
                tracked_col = scen.tracked_cols[0] if scen.tracked_cols else scen.join_key
                result_df, metrics = apply_scd(
                    self._tables.silver,
                    self._tables.joined,
                    scen.join_key,
                    tracked_col,
                    scd_type_int,
                )
                if result_df is not None:
                    self._tables.silver = result_df
                self._state.did_scd = True
                self._state.scd_type = f"SCD-{scd_type_int}"
                self._state.scd_inserts = int(metrics.get("inserts", 0))
                self._state.scd_updates = int(metrics.get("updates", 0))
                self._state.silver_row_count = int(metrics.get("silver_rows", 0))
                step_message = (
                    f"APPLY_SCD-{scd_type_int}: "
                    f"{self._state.scd_inserts} inserts, "
                    f"{self._state.scd_updates} updates β "
                    f"Silver {self._state.silver_row_count} rows."
                )

            elif action_type == MedusaActionType.COMMIT:
                # Terminal action: grade and finalise; returns its own obs.
                return self._do_commit(state_before)

        except Exception as exc:  # noqa: BLE001
            # Operator failures are surfaced to the agent, not raised: the
            # episode continues and the reward engine sees the error metric.
            step_message = f"ERROR in {action_type}: {exc}"
            metrics = {"error": str(exc)}

        # -- Reward -----------------------------------------------------
        reward = self._reward_engine.evaluate(
            action_type=action_type.value,
            metrics=metrics,
            state_before=state_before,
        )
        self._state.cumulative_reward += reward

        # -- Governance log ---------------------------------------------
        self._tables.governance_log.append({
            "step": self._state.step_idx,
            "action": action_type.value,
            "reward": reward,
            "cumulative_reward": self._state.cumulative_reward,
            "metrics": metrics,
            "timestamp": time.time(),
        })

        # Check step limit: running out of budget ends the episode as failed.
        done = self._state.step_idx >= self._state.max_steps
        if done:
            self._state.stage = "failed"
            step_message += " [MAX STEPS REACHED]"

        features = _build_features(self._state)
        obs = MedusaObservation(
            message=step_message,
            features=features,
            metrics=metrics,
            metadata={
                "run_id": self._state.run_id,
                "step": self._state.step_idx,
                "cumulative_reward": self._state.cumulative_reward,
            },
            reward=reward,
            done=done,
        )
        return self._apply_transform(obs)

    # ------------------------------------------------------------------
    # Commit (terminal step)
    # ------------------------------------------------------------------

    def _do_commit(self, state_before: MedusaState) -> MedusaObservation:
        """Run grader then finalise the episode.

        Computes the base COMMIT reward, adds the grader's bonus, records the
        audit in the governance log, and marks the episode "committed".
        Always returns a terminal observation (``done=True``).
        """
        scen = self._scenario
        assert scen is not None

        # Base step reward for issuing COMMIT.
        reward = self._reward_engine.evaluate(
            action_type=MedusaActionType.COMMIT.value,
            metrics={},
            state_before=state_before,
        )

        # Grader audit of the final Silver/quarantine tables. Defaults apply
        # if the agent never joined ("left") or never ran SCD (type 1).
        grader_result = self._grader.audit(
            silver=self._tables.silver,
            quarantine=self._tables.quarantine,
            bronze_a=scen.bronze_a,
            bronze_b=scen.bronze_b,
            join_key=scen.join_key,
            join_type=self._state.join_type or "left",
            # scd_type is stored as "SCD-1"/"SCD-2"; last char is the number.
            scd_type=int(self._state.scd_type[-1]) if self._state.scd_type else 1,
            scenario=scen,
        )
        reward += grader_result.bonus_reward
        self._state.grader_passed = grader_result.passed
        self._state.grader_report = grader_result.report
        self._state.cumulative_reward += reward
        self._state.silver_row_count = len(self._tables.silver)
        self._state.quarantine_row_count = len(self._tables.quarantine)
        self._state.stage = "committed"

        self._tables.governance_log.append({
            "step": self._state.step_idx,
            "action": "COMMIT",
            "reward": reward,
            "cumulative_reward": self._state.cumulative_reward,
            "grader_passed": grader_result.passed,
            "grader_report": grader_result.report,
            "timestamp": time.time(),
        })

        features = _build_features(self._state)
        obs = MedusaObservation(
            message=(
                f"COMMIT: episode finalized. "
                f"{'Grader: PASS β' if grader_result.passed else 'Grader: FAIL β'} "
                f"Bonus: {grader_result.bonus_reward:+.1f} | "
                f"Total reward: {self._state.cumulative_reward:.1f}"
            ),
            features=features,
            metrics={
                "grader_passed": grader_result.passed,
                "grader_report": grader_result.report,
                "silver_rows": self._state.silver_row_count,
                "quarantine_rows": self._state.quarantine_row_count,
                "governance_log_entries": len(self._tables.governance_log),
            },
            metadata={
                "run_id": self._state.run_id,
                "steps": self._state.step_idx,
                "cumulative_reward": self._state.cumulative_reward,
            },
            reward=reward,
            done=True,
        )
        return self._apply_transform(obs)
|
tasks.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""MEDUSA Task Definitions.
|
| 2 |
+
|
| 3 |
+
Three formally graded tasks covering the easy β medium β hard spectrum.
|
| 4 |
+
Each task returns a deterministic score in [0.0, 1.0] after COMMIT.
|
| 5 |
+
|
| 6 |
+
Usage::
|
| 7 |
+
|
| 8 |
+
from envs.medusa_env.tasks import TASKS, score_episode
|
| 9 |
+
|
| 10 |
+
task = TASKS["clean_pipeline"] # easy
|
| 11 |
+
env = MedusaEnv(n_fact_rows=200, n_dim_rows=150)
|
| 12 |
+
obs = env.reset(seed=task.seed)
|
| 13 |
+
|
| 14 |
+
# ... agent takes actions ...
|
| 15 |
+
obs = env.step(MedusaAction(action=MedusaActionType.COMMIT))
|
| 16 |
+
|
| 17 |
+
result = score_episode(task.id, env.state, env._tables)
|
| 18 |
+
print(f"Score: {result.score:.2f} ({result.grade})")
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
from dataclasses import dataclass, field
|
| 24 |
+
from typing import TYPE_CHECKING, Dict, List, Optional
|
| 25 |
+
|
| 26 |
+
if TYPE_CHECKING:
|
| 27 |
+
from .medusa_env import _EpisodeTables
|
| 28 |
+
from .models import MedusaState
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# ---------------------------------------------------------------------------
|
| 32 |
+
# Task definition
|
| 33 |
+
# ---------------------------------------------------------------------------
|
| 34 |
+
|
| 35 |
+
@dataclass
class Task:
    """A MEDUSA task definition.

    Tasks pair a fixed scenario seed with a weighted scoring rubric; see
    ``score_episode`` for how the rubric keys are interpreted.
    """

    id: str  # Stable identifier, also the key in the TASKS catalogue
    name: str  # Human-readable display name
    difficulty: str  # "easy" | "medium" | "hard"
    seed: int  # Controls ScenarioGenerator variant
    description: str  # What the agent is expected to accomplish
    success_criteria: List[str]  # Human-readable pass conditions
    scoring_rubric: Dict[str, float]  # criterion name -> weight (weights sum to 1.0)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# ---------------------------------------------------------------------------
|
| 49 |
+
# Scoring result
|
| 50 |
+
# ---------------------------------------------------------------------------
|
| 51 |
+
|
| 52 |
+
@dataclass
class TaskResult:
    """Outcome of scoring a completed episode against a task."""

    task_id: str  # Which task was scored
    score: float  # 0.0 β 1.0
    grade: str  # "S" | "A" | "B" | "C" | "F"
    breakdown: Dict[str, float]  # per-criterion scores
    passed: bool  # True when score clears the pass threshold
    notes: List[str] = field(default_factory=list)  # Human-readable diagnostics
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _grade(score: float) -> str:
|
| 65 |
+
if score >= 0.90:
|
| 66 |
+
return "S"
|
| 67 |
+
if score >= 0.75:
|
| 68 |
+
return "A"
|
| 69 |
+
if score >= 0.55:
|
| 70 |
+
return "B"
|
| 71 |
+
if score >= 0.35:
|
| 72 |
+
return "C"
|
| 73 |
+
return "F"
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# ---------------------------------------------------------------------------
|
| 77 |
+
# Task catalogue
|
| 78 |
+
# ---------------------------------------------------------------------------
|
| 79 |
+
|
| 80 |
+
# Catalogue of graded tasks, keyed by task id. Rubric weights within each
# task sum to 1.0; ``score_episode`` awards each weight all-or-nothing.
TASKS: Dict[str, Task] = {

    # -- EASY: Clean Pipeline ------------------------------------------------
    "clean_pipeline": Task(
        id="clean_pipeline",
        name="Clean Pipeline",
        difficulty="easy",
        seed=0,
        description=(
            "Both sources are fresh. Join keys are clean and unique. "
            "The agent must verify freshness, prepare keys, join, apply SCD, "
            "and commit without triggering a row explosion."
        ),
        success_criteria=[
            "COMMIT issued (episode finalized)",
            "No Cartesian explosion detected",
            "Silver row count β€ Source A row count",
            "match_rate > 0.80 after join",
        ],
        scoring_rubric={
            "committed": 0.20,  # Agent issued COMMIT
            "no_explosion": 0.25,  # No row explosion
            "volume_ok": 0.20,  # Silver β€ Source A rows
            "high_match": 0.20,  # match_rate > 0.80
            "grader_pass": 0.15,  # All 4 grader checks pass
        },
    ),

    # -- MEDIUM: Dirty Integration -------------------------------------------
    "dirty_integration": Task(
        id="dirty_integration",
        name="Dirty Key Integration",
        difficulty="medium",
        seed=1,
        description=(
            "Source A has NULLs and whitespace in join keys. "
            "Source B has duplicate keys that can cause row explosion. "
            "The agent must PREP_KEYS and DEDUPLICATE before joining, "
            "and correctly quarantine unresolvable orphans."
        ),
        success_criteria=[
            "PREP_KEYS_A issued before EXECUTE_JOIN",
            "PREP_KEYS_B issued before EXECUTE_JOIN",
            "DEDUPLICATE_B issued before EXECUTE_JOIN",
            "No row explosion",
            "Quarantine integrity check passes",
        ],
        scoring_rubric={
            "committed": 0.10,
            "prepped_before_join": 0.20,  # Both PREP_KEYS before join
            "deduped_before_join": 0.20,  # DEDUP before join
            "no_explosion": 0.25,
            "integrity_ok": 0.15,  # Quarantine holds true orphans only
            "grader_pass": 0.10,
        },
    ),

    # -- HARD: Full Medallion Integration --------------------------------------
    "full_medallion": Task(
        id="full_medallion",
        name="Full Medallion Integration",
        difficulty="hard",
        seed=2,
        description=(
            "Source A is stale (>6h old). Source B has new schema columns "
            "not registered in Silver. The agent must: check freshness, "
            "evolve the schema, clean keys, deduplicate, execute a left join, "
            "apply SCD-2 for tracked columns, and pass all grader checks."
        ),
        success_criteria=[
            "SYNC_CHECK issued before any join",
            "EVOLVE_SCHEMA issued before COMMIT",
            "SCD-2 applied (not SCD-1) for tracked column",
            "Silver schema contains new columns from drift",
            "All 4 grader checks pass",
        ],
        scoring_rubric={
            "committed": 0.05,
            "sync_checked": 0.15,  # SYNC_CHECK before join
            "schema_evolved": 0.15,  # EVOLVE_SCHEMA called
            "used_scd2": 0.20,  # Chose SCD-2 over SCD-1
            "schema_ok": 0.20,  # Silver has all required columns
            "grader_pass": 0.25,  # All 4 grader checks pass
        },
    ),
}
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
# ---------------------------------------------------------------------------
|
| 169 |
+
# Scoring engine
|
| 170 |
+
# ---------------------------------------------------------------------------
|
| 171 |
+
|
| 172 |
+
def score_episode(
    task_id: str,
    state: "MedusaState",
    tables: "Optional[_EpisodeTables]" = None,
) -> TaskResult:
    """Score a completed MEDUSA episode against the named task.

    Each rubric criterion is awarded all-or-nothing; the final score is the
    sum of awarded weights, clipped to [0, 1].

    Args:
        task_id: One of "clean_pipeline", "dirty_integration", "full_medallion".
        state: Final ``MedusaState`` after the episode ended.
        tables: Episode tables (used for schema checks). Optional; currently
            unused — kept for interface stability.

    Returns:
        TaskResult with score in [0.0, 1.0].

    Raises:
        ValueError: If ``task_id`` is not in the TASKS catalogue.
    """
    task = TASKS.get(task_id)
    if task is None:
        raise ValueError(f"Unknown task_id={task_id!r}. Valid: {list(TASKS)}")

    # An unfinished episode (no COMMIT, no forced failure) scores zero.
    if state.stage not in ("committed", "failed"):
        return TaskResult(
            task_id=task_id, score=0.0, grade="F",
            breakdown={}, passed=False,
            notes=["Episode not finished β COMMIT was never issued."],
        )

    breakdown: Dict[str, float] = {}
    notes: List[str] = []
    rubric = task.scoring_rubric
    committed = state.stage == "committed"

    # -- Shared criteria (present in more than one rubric) -----------------
    if "committed" in rubric:
        breakdown["committed"] = rubric["committed"] if committed else 0.0

    if "no_explosion" in rubric:
        ok = not state.explosion_detected
        breakdown["no_explosion"] = rubric["no_explosion"] if ok else 0.0
        if not ok:
            notes.append("Row explosion was detected β heavy penalty applied.")

    if "grader_pass" in rubric:
        breakdown["grader_pass"] = rubric["grader_pass"] if state.grader_passed else 0.0

    # -- Task-specific criteria --------------------------------------------

    if task_id == "clean_pipeline":
        # Small tolerance (5%) above Source A's row count allows for benign
        # SCD-2 history rows without rewarding an explosion.
        volume_ok = (
            state.silver_row_count <= state.source_a_row_count * 1.05
            and state.silver_row_count > 0
        )
        breakdown["volume_ok"] = rubric["volume_ok"] if volume_ok else 0.0
        breakdown["high_match"] = rubric["high_match"] if state.match_rate >= 0.80 else 0.0
        if state.match_rate < 0.80:
            notes.append(f"match_rate={state.match_rate:.1%} β target >80%.")

    elif task_id == "dirty_integration":
        # Both PREP_KEYS before join
        prepped = state.did_prep_a and state.did_prep_b and state.did_join
        breakdown["prepped_before_join"] = rubric["prepped_before_join"] if prepped else 0.0
        # DEDUP before join
        deduped = state.did_dedup_b and state.did_join
        breakdown["deduped_before_join"] = rubric["deduped_before_join"] if deduped else 0.0
        # Quarantine integrity is covered by the grader audit, so the
        # grader verdict doubles as the integrity signal here.
        breakdown["integrity_ok"] = rubric["integrity_ok"] if state.grader_passed else 0.0
        if not prepped:
            notes.append("Agent joined without prepping keys first.")
        if not deduped:
            notes.append("Agent joined without deduplicating Dimension.")

    elif task_id == "full_medallion":
        breakdown["sync_checked"] = rubric["sync_checked"] if state.did_sync_check else 0.0
        breakdown["schema_evolved"] = rubric["schema_evolved"] if state.did_evolve_schema else 0.0
        used_scd2 = state.scd_type == "SCD-2"
        breakdown["used_scd2"] = rubric["used_scd2"] if used_scd2 else 0.0
        breakdown["schema_ok"] = rubric["schema_ok"] if state.grader_passed else 0.0
        if not state.did_sync_check:
            notes.append("SYNC_CHECK was never called β stale source not verified.")
        if not state.did_evolve_schema:
            notes.append("EVOLVE_SCHEMA never called β new columns may be missing from Silver.")
        if not used_scd2:
            notes.append(f"Used SCD-1 instead of SCD-2 (scd_type={state.scd_type!r}).")

    # -- Final score --------------------------------------------------------
    total = sum(breakdown.values())
    # Clip to [0, 1] (row explosion can make total negative from reward engine)
    score = max(0.0, min(1.0, total))
    passed = score >= 0.55

    return TaskResult(
        task_id=task_id,
        score=round(score, 4),
        grade=_grade(score),
        breakdown=breakdown,
        passed=passed,
        notes=notes,
    )
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
# ---------------------------------------------------------------------------
|
| 276 |
+
# Convenience: score all tasks
|
| 277 |
+
# ---------------------------------------------------------------------------
|
| 278 |
+
|
| 279 |
+
def score_all_tasks(
    results: Dict[str, tuple],  # task_id -> (state, tables)
) -> Dict[str, TaskResult]:
    """Score multiple completed episodes, one per task."""
    scored: Dict[str, TaskResult] = {}
    for task_id, (state, tables) in results.items():
        scored[task_id] = score_episode(task_id, state, tables)
    return scored
|
tests/test_medusa_environment.py
ADDED
|
@@ -0,0 +1,591 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for the MEDUSA environment.
|
| 2 |
+
|
| 3 |
+
Covers: models, scenario generator, operators, reward engine, grader,
|
| 4 |
+
and full end-to-end environment episodes.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import pytest
|
| 10 |
+
|
| 11 |
+
# ---------------------------------------------------------------------------
|
| 12 |
+
# Models
|
| 13 |
+
# ---------------------------------------------------------------------------
|
| 14 |
+
|
| 15 |
+
from medusa_env.models import (
|
| 16 |
+
MedusaAction,
|
| 17 |
+
MedusaActionType,
|
| 18 |
+
MedusaObservation,
|
| 19 |
+
MedusaState,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class TestMedusaModels:
|
| 24 |
+
def test_action_creation(self):
|
| 25 |
+
a = MedusaAction(action=MedusaActionType.SYNC_CHECK)
|
| 26 |
+
assert a.action == MedusaActionType.SYNC_CHECK
|
| 27 |
+
assert a.params == {}
|
| 28 |
+
|
| 29 |
+
def test_state_defaults(self):
|
| 30 |
+
s = MedusaState()
|
| 31 |
+
assert s.stage == "init"
|
| 32 |
+
assert s.step_idx == 0
|
| 33 |
+
assert s.did_sync_check is False
|
| 34 |
+
assert s.explosion_detected is False
|
| 35 |
+
assert s.grader_passed is False
|
| 36 |
+
|
| 37 |
+
def test_observation_defaults(self):
|
| 38 |
+
obs = MedusaObservation()
|
| 39 |
+
assert obs.done is False
|
| 40 |
+
assert obs.reward is None
|
| 41 |
+
assert obs.features == []
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# ---------------------------------------------------------------------------
|
| 45 |
+
# Scenario Generator
|
| 46 |
+
# ---------------------------------------------------------------------------
|
| 47 |
+
|
| 48 |
+
import pandas as pd
|
| 49 |
+
|
| 50 |
+
from medusa_env.scenarios import Scenario, ScenarioGenerator
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class TestMedusaScenarios:
|
| 54 |
+
@pytest.fixture
|
| 55 |
+
def gen(self):
|
| 56 |
+
return ScenarioGenerator(n_fact_rows=50, n_dim_rows=40)
|
| 57 |
+
|
| 58 |
+
def test_canonical_clean(self, gen):
|
| 59 |
+
scen = gen.generate(seed=0)
|
| 60 |
+
assert scen.id.startswith("clean")
|
| 61 |
+
assert isinstance(scen.bronze_a, pd.DataFrame)
|
| 62 |
+
assert len(scen.bronze_a) == 50
|
| 63 |
+
assert not scen.is_stale_a
|
| 64 |
+
assert not scen.is_stale_b
|
| 65 |
+
assert scen.new_cols_a == []
|
| 66 |
+
|
| 67 |
+
def test_canonical_dirty_keys(self, gen):
|
| 68 |
+
scen = gen.generate(seed=1)
|
| 69 |
+
assert "dirty_keys" in scen.id
|
| 70 |
+
# Dirty scenario should have actual null or whitespace keys
|
| 71 |
+
has_issues = (
|
| 72 |
+
scen.bronze_a[scen.join_key].isna().any()
|
| 73 |
+
or scen.bronze_a[scen.join_key].astype(str).str.contains(r"^\s|\s$").any()
|
| 74 |
+
)
|
| 75 |
+
assert has_issues
|
| 76 |
+
|
| 77 |
+
def test_canonical_stale(self, gen):
|
| 78 |
+
scen = gen.generate(seed=2)
|
| 79 |
+
assert "stale" in scen.id
|
| 80 |
+
assert scen.is_stale_a # Source A should be stale
|
| 81 |
+
|
| 82 |
+
def test_canonical_schema_drift(self, gen):
|
| 83 |
+
scen = gen.generate(seed=3)
|
| 84 |
+
assert "schema_drift" in scen.id
|
| 85 |
+
assert len(scen.new_cols_a) > 0
|
| 86 |
+
assert len(scen.new_cols_b) > 0
|
| 87 |
+
|
| 88 |
+
def test_random_seed_produces_scenario(self, gen):
|
| 89 |
+
scen = gen.generate(seed=999)
|
| 90 |
+
assert isinstance(scen, Scenario)
|
| 91 |
+
assert scen.join_key in scen.bronze_a.columns
|
| 92 |
+
assert scen.join_key in scen.bronze_b.columns
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# ---------------------------------------------------------------------------
|
| 96 |
+
# Operators
|
| 97 |
+
# ---------------------------------------------------------------------------
|
| 98 |
+
|
| 99 |
+
from medusa_env.operators import (
|
| 100 |
+
apply_scd,
|
| 101 |
+
deduplicate,
|
| 102 |
+
evolve_schema,
|
| 103 |
+
execute_join,
|
| 104 |
+
prep_keys,
|
| 105 |
+
sync_check,
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class TestMedusaOperators:
|
| 110 |
+
def test_sync_check_fresh(self):
|
| 111 |
+
a = pd.DataFrame({"id": [1, 2]})
|
| 112 |
+
b = pd.DataFrame({"id": [1, 2]})
|
| 113 |
+
_, m = sync_check(a, b, time_delta_a=1.0, time_delta_b=2.0)
|
| 114 |
+
assert m["is_stale_a"] is False
|
| 115 |
+
assert m["is_stale_b"] is False
|
| 116 |
+
|
| 117 |
+
def test_sync_check_stale(self):
|
| 118 |
+
a = pd.DataFrame({"id": [1]})
|
| 119 |
+
b = pd.DataFrame({"id": [1]})
|
| 120 |
+
_, m = sync_check(a, b, time_delta_a=10.0, time_delta_b=1.0)
|
| 121 |
+
assert m["is_stale_a"] is True
|
| 122 |
+
assert m["is_stale_b"] is False
|
| 123 |
+
|
| 124 |
+
def test_prep_keys_strips_whitespace(self):
|
| 125 |
+
df = pd.DataFrame({"key": [" K001 ", "K002", None]})
|
| 126 |
+
result, m = prep_keys(df, "key")
|
| 127 |
+
# Stripped key should have no leading/trailing spaces
|
| 128 |
+
non_null = result["key"].dropna().tolist()
|
| 129 |
+
assert all(v.strip() == v for v in non_null)
|
| 130 |
+
assert m["null_ratio_before"] > 0
|
| 131 |
+
|
| 132 |
+
def test_deduplicate_removes_dupes(self):
|
| 133 |
+
df = pd.DataFrame({"key": ["A", "A", "B"], "val": [1, 2, 3]})
|
| 134 |
+
result, m = deduplicate(df, "key")
|
| 135 |
+
assert m["dupes_removed"] == 1
|
| 136 |
+
assert len(result) == 2
|
| 137 |
+
|
| 138 |
+
def test_execute_join_left_basic(self):
|
| 139 |
+
fact = pd.DataFrame({"key": ["K001", "K002", "K003"], "val": [1, 2, 3]})
|
| 140 |
+
dim = pd.DataFrame({"key": ["K001", "K002"], "dim_name": ["A", "B"]})
|
| 141 |
+
joined, quarantine, m = execute_join(fact, dim, "key", "left")
|
| 142 |
+
assert m["join_rows"] == 3 # left join keeps all fact rows
|
| 143 |
+
assert m["match_rate"] == pytest.approx(2 / 3, abs=0.01)
|
| 144 |
+
assert len(quarantine) >= 1 # K003 should be quarantined
|
| 145 |
+
|
| 146 |
+
def test_execute_join_detects_explosion(self):
|
| 147 |
+
# Non-unique dim key β Cartesian explosion
|
| 148 |
+
fact = pd.DataFrame({"key": ["K001"] * 10, "val": list(range(10))})
|
| 149 |
+
dim = pd.DataFrame({"key": ["K001"] * 20, "dim_name": ["X"] * 20})
|
| 150 |
+
joined, quarantine, m = execute_join(fact, dim, "key", "inner")
|
| 151 |
+
assert m["explosion_detected"] is True
|
| 152 |
+
|
| 153 |
+
def test_execute_join_anti(self):
|
| 154 |
+
fact = pd.DataFrame({"key": ["K001", "K002", "K999"], "val": [1, 2, 3]})
|
| 155 |
+
dim = pd.DataFrame({"key": ["K001", "K002"], "name": ["A", "B"]})
|
| 156 |
+
joined, quarantine, m = execute_join(fact, dim, "key", "anti")
|
| 157 |
+
assert len(joined) == 0 # Anti-join: no rows in joined
|
| 158 |
+
assert len(quarantine) == 1 # K999 goes to quarantine
|
| 159 |
+
|
| 160 |
+
def test_apply_scd1_upsert(self):
|
| 161 |
+
silver = pd.DataFrame({"key": ["K001"], "val": [10], "status": ["old"]})
|
| 162 |
+
joined = pd.DataFrame({"key": ["K001", "K002"], "val": [99, 20], "status": ["new", "new"]})
|
| 163 |
+
result, m = apply_scd(silver, joined, "key", "status", scd_type=1)
|
| 164 |
+
assert m["scd_type"] == 1
|
| 165 |
+
assert m["inserts"] + m["updates"] > 0
|
| 166 |
+
# K001 should be updated to val=99
|
| 167 |
+
k1_row = result[result["key"] == "K001"]
|
| 168 |
+
assert not k1_row.empty
|
| 169 |
+
|
| 170 |
+
def test_apply_scd2_adds_history(self):
|
| 171 |
+
silver = pd.DataFrame()
|
| 172 |
+
joined = pd.DataFrame({"key": ["K001"], "status": ["active"]})
|
| 173 |
+
result, m = apply_scd(silver, joined, "key", "status", scd_type=2)
|
| 174 |
+
assert "valid_from" in result.columns
|
| 175 |
+
assert m["inserts"] == 1
|
| 176 |
+
|
| 177 |
+
def test_evolve_schema_adds_columns(self):
|
| 178 |
+
silver = pd.DataFrame({"key": ["K001"], "val": [1]})
|
| 179 |
+
a = pd.DataFrame({"key": ["K001"], "new_metric": [42]})
|
| 180 |
+
b = pd.DataFrame({"key": ["K001"]})
|
| 181 |
+
result, m = evolve_schema(silver, a, b, ["new_metric"], [])
|
| 182 |
+
assert "new_metric" in result.columns
|
| 183 |
+
assert m["new_cols_count"] == 1
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
# ---------------------------------------------------------------------------
|
| 187 |
+
# Reward Engine
|
| 188 |
+
# ---------------------------------------------------------------------------
|
| 189 |
+
|
| 190 |
+
from medusa_env.rewards import RewardEngine
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
class TestMedusaRewards:
|
| 194 |
+
@pytest.fixture
|
| 195 |
+
def engine(self):
|
| 196 |
+
return RewardEngine()
|
| 197 |
+
|
| 198 |
+
def _clean_state(self):
|
| 199 |
+
s = MedusaState()
|
| 200 |
+
s.did_prep_a = True
|
| 201 |
+
s.did_prep_b = True
|
| 202 |
+
s.did_sync_check = True
|
| 203 |
+
return s
|
| 204 |
+
|
| 205 |
+
def test_step_penalty_always_applied(self, engine):
|
| 206 |
+
r = engine.evaluate("SYNC_CHECK", {}, MedusaState())
|
| 207 |
+
assert r == pytest.approx(-0.2, abs=0.01)
|
| 208 |
+
|
| 209 |
+
def test_high_match_join_reward(self, engine):
|
| 210 |
+
r = engine.evaluate(
|
| 211 |
+
"EXECUTE_JOIN_LEFT",
|
| 212 |
+
{"match_rate": 0.95, "join_rows": 100, "fact_rows": 100,
|
| 213 |
+
"explosion_detected": False, "quarantine_rows": 5},
|
| 214 |
+
self._clean_state(),
|
| 215 |
+
)
|
| 216 |
+
assert r > 0.0 # +25 - 0.2 + 10 (quarantine) = +34.8
|
| 217 |
+
|
| 218 |
+
def test_row_explosion_heavy_penalty(self, engine):
|
| 219 |
+
r = engine.evaluate(
|
| 220 |
+
"EXECUTE_JOIN_INNER",
|
| 221 |
+
{"explosion_detected": True, "join_rows": 1000, "fact_rows": 100,
|
| 222 |
+
"match_rate": 1.0, "quarantine_rows": 0},
|
| 223 |
+
self._clean_state(),
|
| 224 |
+
)
|
| 225 |
+
assert r < -50.0
|
| 226 |
+
|
| 227 |
+
def test_dirty_join_penalty(self, engine):
|
| 228 |
+
# No PREP_KEYS β dirty join penalty
|
| 229 |
+
state = MedusaState()
|
| 230 |
+
state.did_prep_a = False
|
| 231 |
+
state.did_prep_b = False
|
| 232 |
+
r = engine.evaluate(
|
| 233 |
+
"EXECUTE_JOIN_LEFT",
|
| 234 |
+
{"explosion_detected": False, "join_rows": 0, "fact_rows": 50,
|
| 235 |
+
"match_rate": 0.0, "quarantine_rows": 0},
|
| 236 |
+
state,
|
| 237 |
+
)
|
| 238 |
+
assert r < -20.0
|
| 239 |
+
|
| 240 |
+
def test_scd2_extra_reward(self, engine):
|
| 241 |
+
r = engine.evaluate("APPLY_SCD_2", {}, self._clean_state())
|
| 242 |
+
# +5 for SCD-2 - 0.2 step penalty
|
| 243 |
+
assert r == pytest.approx(4.8, abs=0.01)
|
| 244 |
+
|
| 245 |
+
def test_stale_processing_penalty(self, engine):
|
| 246 |
+
state = MedusaState()
|
| 247 |
+
state.is_stale_a = True
|
| 248 |
+
state.did_sync_check = False # Never checked freshness
|
| 249 |
+
state.did_prep_a = True
|
| 250 |
+
state.did_prep_b = True
|
| 251 |
+
r = engine.evaluate(
|
| 252 |
+
"EXECUTE_JOIN_LEFT",
|
| 253 |
+
{"explosion_detected": False, "join_rows": 100, "fact_rows": 100,
|
| 254 |
+
"match_rate": 0.95, "quarantine_rows": 0},
|
| 255 |
+
state,
|
| 256 |
+
)
|
| 257 |
+
# Should include stale penalty on top of positive join reward
|
| 258 |
+
assert r < 25.0 # Stale penalty reduces it
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
# ---------------------------------------------------------------------------
|
| 262 |
+
# Grader
|
| 263 |
+
# ---------------------------------------------------------------------------
|
| 264 |
+
|
| 265 |
+
from medusa_env.grader import Grader
|
| 266 |
+
from medusa_env.scenarios import Scenario
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
class TestMedusaGrader:
|
| 270 |
+
@pytest.fixture
|
| 271 |
+
def grader(self):
|
| 272 |
+
return Grader()
|
| 273 |
+
|
| 274 |
+
def _make_scenario(self):
|
| 275 |
+
a = pd.DataFrame({"entity_id": ["K1", "K2", "K3"], "val": [1, 2, 3],
|
| 276 |
+
"fact_category": ["A", "B", "C"],
|
| 277 |
+
"fact_value": [1.0, 2.0, 3.0],
|
| 278 |
+
"created_at": pd.date_range("2024-01-01", periods=3, freq="h")})
|
| 279 |
+
b = pd.DataFrame({"entity_id": ["K1", "K2"], "dim_name": ["N1", "N2"], "dim_status": ["x", "y"]})
|
| 280 |
+
return a, b
|
| 281 |
+
|
| 282 |
+
def test_volume_check_pass(self, grader):
|
| 283 |
+
a, b = self._make_scenario()
|
| 284 |
+
silver = pd.DataFrame({"entity_id": ["K1", "K2"], "val": [1, 2]})
|
| 285 |
+
scen = ScenarioGenerator(n_fact_rows=3, n_dim_rows=2).generate(seed=0)
|
| 286 |
+
r = grader.audit(silver, pd.DataFrame(), a, b, "entity_id", "left", 1, scen)
|
| 287 |
+
assert r.volume_ok is True
|
| 288 |
+
|
| 289 |
+
def test_volume_check_fail(self, grader):
|
| 290 |
+
a, b = self._make_scenario()
|
| 291 |
+
# Silver has way more rows than source A β violation
|
| 292 |
+
silver = pd.DataFrame({"entity_id": ["K1"] * 100})
|
| 293 |
+
scen = ScenarioGenerator(n_fact_rows=3, n_dim_rows=2).generate(seed=0)
|
| 294 |
+
r = grader.audit(silver, pd.DataFrame(), a, b, "entity_id", "left", 1, scen)
|
| 295 |
+
assert r.volume_ok is False
|
| 296 |
+
|
| 297 |
+
def test_integrity_check_quarantine_true_orphans(self, grader):
|
| 298 |
+
a, b = self._make_scenario()
|
| 299 |
+
# K3 is not in B β true orphan
|
| 300 |
+
quarantine = pd.DataFrame({"entity_id": ["K3"]})
|
| 301 |
+
scen = ScenarioGenerator(n_fact_rows=3, n_dim_rows=2).generate(seed=0)
|
| 302 |
+
silver = pd.DataFrame({"entity_id": ["K1", "K2"]})
|
| 303 |
+
r = grader.audit(silver, quarantine, a, b, "entity_id", "left", 1, scen)
|
| 304 |
+
assert r.integrity_ok is True
|
| 305 |
+
|
| 306 |
+
def test_integrity_check_fail_dirty_quarantine(self, grader):
|
| 307 |
+
a, b = self._make_scenario()
|
| 308 |
+
# K1 IS in B but ends up in quarantine (agent failed to clean it)
|
| 309 |
+
quarantine = pd.DataFrame({"entity_id": ["K1"]})
|
| 310 |
+
scen = ScenarioGenerator(n_fact_rows=3, n_dim_rows=2).generate(seed=0)
|
| 311 |
+
silver = pd.DataFrame({"entity_id": ["K2"]})
|
| 312 |
+
r = grader.audit(silver, quarantine, a, b, "entity_id", "left", 1, scen)
|
| 313 |
+
assert r.integrity_ok is False
|
| 314 |
+
|
| 315 |
+
def test_all_pass_gives_bonus(self, grader):
|
| 316 |
+
gen = ScenarioGenerator(n_fact_rows=3, n_dim_rows=2)
|
| 317 |
+
scen = gen.generate(seed=0)
|
| 318 |
+
a, b = scen.bronze_a, scen.bronze_b
|
| 319 |
+
# Simulate a perfect run
|
| 320 |
+
silver = a.merge(b, on="entity_id", how="left")
|
| 321 |
+
r = grader.audit(silver, pd.DataFrame(), a, b, "entity_id", "left", 1, scen)
|
| 322 |
+
assert r.bonus_reward > 0
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
# ---------------------------------------------------------------------------
|
| 326 |
+
# Full environment integration
|
| 327 |
+
# ---------------------------------------------------------------------------
|
| 328 |
+
|
| 329 |
+
from medusa_env.server import MedusaEnv
|
| 330 |
+
from medusa_env.models import MedusaActionType
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
class TestMedusaEnvironment:
|
| 334 |
+
@pytest.fixture
|
| 335 |
+
def env(self):
|
| 336 |
+
return MedusaEnv(n_fact_rows=50, n_dim_rows=40)
|
| 337 |
+
|
| 338 |
+
def test_reset_returns_observation(self, env):
|
| 339 |
+
obs = env.reset(seed=0)
|
| 340 |
+
assert isinstance(obs, MedusaObservation)
|
| 341 |
+
assert obs.done is False
|
| 342 |
+
assert len(obs.features) == 16
|
| 343 |
+
assert obs.reward is None
|
| 344 |
+
|
| 345 |
+
def test_state_after_reset(self, env):
|
| 346 |
+
env.reset(seed=0)
|
| 347 |
+
state = env.state
|
| 348 |
+
assert state.stage == "running"
|
| 349 |
+
assert state.step_idx == 0
|
| 350 |
+
assert state.source_a_row_count == 50
|
| 351 |
+
|
| 352 |
+
def test_happy_path_episode(self, env):
|
| 353 |
+
"""Full pipeline: sync β evolve β prep both β dedup β join β scd β commit."""
|
| 354 |
+
env.reset(seed=0) # clean scenario
|
| 355 |
+
|
| 356 |
+
actions = [
|
| 357 |
+
MedusaActionType.SYNC_CHECK,
|
| 358 |
+
MedusaActionType.EVOLVE_SCHEMA,
|
| 359 |
+
MedusaActionType.PREP_KEYS_A,
|
| 360 |
+
MedusaActionType.PREP_KEYS_B,
|
| 361 |
+
MedusaActionType.DEDUPLICATE_B,
|
| 362 |
+
MedusaActionType.EXECUTE_JOIN_LEFT,
|
| 363 |
+
MedusaActionType.APPLY_SCD_2,
|
| 364 |
+
MedusaActionType.COMMIT,
|
| 365 |
+
]
|
| 366 |
+
obs = None
|
| 367 |
+
for act_type in actions:
|
| 368 |
+
obs = env.step(MedusaAction(action=act_type))
|
| 369 |
+
|
| 370 |
+
assert obs is not None
|
| 371 |
+
assert obs.done is True
|
| 372 |
+
assert env.state.stage == "committed"
|
| 373 |
+
assert env.state.grader_passed # Clean scenario should pass grader
|
| 374 |
+
|
| 375 |
+
def test_row_explosion_gives_heavy_penalty(self, env):
|
| 376 |
+
"""Joining on non-unique B keys should trigger explosion penalty."""
|
| 377 |
+
env.reset(seed=1) # dirty_keys scenario β B has duplicate keys
|
| 378 |
+
|
| 379 |
+
# Skip prep & dedup β go straight to join
|
| 380 |
+
env.step(MedusaAction(action=MedusaActionType.SYNC_CHECK))
|
| 381 |
+
|
| 382 |
+
# Force the dimension to have many duplicates so explosion fires
|
| 383 |
+
import pandas as _pd
|
| 384 |
+
|
| 385 |
+
env._tables.bronze_b_prepped = _pd.DataFrame({
|
| 386 |
+
"entity_id": ["K001"] * 30,
|
| 387 |
+
"dim_name": ["X"] * 30,
|
| 388 |
+
"dim_status": ["x"] * 30,
|
| 389 |
+
})
|
| 390 |
+
env._tables.bronze_a_prepped = _pd.DataFrame({
|
| 391 |
+
"entity_id": ["K001"] * 10,
|
| 392 |
+
"fact_value": list(range(10)),
|
| 393 |
+
"fact_category": ["A"] * 10,
|
| 394 |
+
"created_at": _pd.date_range("2024-01-01", periods=10, freq="h"),
|
| 395 |
+
})
|
| 396 |
+
|
| 397 |
+
obs = env.step(MedusaAction(action=MedusaActionType.EXECUTE_JOIN_INNER))
|
| 398 |
+
assert obs.reward is not None
|
| 399 |
+
assert obs.reward < -50.0
|
| 400 |
+
assert env.state.explosion_detected is True
|
| 401 |
+
|
| 402 |
+
def test_dirty_join_penalty(self, env):
|
| 403 |
+
"""Skipping PREP_KEYS and joining on null-heavy keys β dirty join."""
|
| 404 |
+
env.reset(seed=1) # dirty_keys scenario
|
| 405 |
+
|
| 406 |
+
# Skip PREP β join directly
|
| 407 |
+
obs = env.step(MedusaAction(action=MedusaActionType.EXECUTE_JOIN_LEFT))
|
| 408 |
+
# If all fact keys are null/non-matching β 0-row join β dirty join penalty
|
| 409 |
+
# (reward < base -0.2 if dirty join fired)
|
| 410 |
+
assert obs.reward is not None
|
| 411 |
+
|
| 412 |
+
def test_step_idx_increments(self, env):
|
| 413 |
+
env.reset(seed=0)
|
| 414 |
+
for _ in range(3):
|
| 415 |
+
env.step(MedusaAction(action=MedusaActionType.SYNC_CHECK))
|
| 416 |
+
assert env.state.step_idx == 3
|
| 417 |
+
|
| 418 |
+
def test_max_steps_terminates_episode(self):
|
| 419 |
+
env = MedusaEnv(n_fact_rows=10, n_dim_rows=10, max_steps=3)
|
| 420 |
+
env.reset(seed=0)
|
| 421 |
+
obs = None
|
| 422 |
+
for _ in range(4): # more than max_steps
|
| 423 |
+
obs = env.step(MedusaAction(action=MedusaActionType.SYNC_CHECK))
|
| 424 |
+
assert obs is not None
|
| 425 |
+
assert obs.done is True
|
| 426 |
+
|
| 427 |
+
def test_commit_without_join_grader_fails(self, env):
|
| 428 |
+
"""Committing without joining should make the grader fail."""
|
| 429 |
+
env.reset(seed=0)
|
| 430 |
+
env.step(MedusaAction(action=MedusaActionType.SYNC_CHECK))
|
| 431 |
+
obs = env.step(MedusaAction(action=MedusaActionType.COMMIT))
|
| 432 |
+
assert obs.done is True
|
| 433 |
+
# Silver will be empty β schema check should fail or volume check fail
|
| 434 |
+
assert env.state.grader_report != ""
|
| 435 |
+
|
| 436 |
+
def test_features_vector_length(self, env):
|
| 437 |
+
env.reset(seed=0)
|
| 438 |
+
obs = env.step(MedusaAction(action=MedusaActionType.SYNC_CHECK))
|
| 439 |
+
assert len(obs.features) == 16
|
| 440 |
+
assert all(0.0 <= f <= 1.0 for f in obs.features)
|
| 441 |
+
|
| 442 |
+
def test_governance_log_populated(self, env):
|
| 443 |
+
env.reset(seed=0)
|
| 444 |
+
env.step(MedusaAction(action=MedusaActionType.SYNC_CHECK))
|
| 445 |
+
env.step(MedusaAction(action=MedusaActionType.PREP_KEYS_A))
|
| 446 |
+
log = env._tables.governance_log
|
| 447 |
+
assert len(log) == 2
|
| 448 |
+
assert log[0]["action"] == "SYNC_CHECK"
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
# ---------------------------------------------------------------------------
|
| 452 |
+
# Task Scorer
|
| 453 |
+
# ---------------------------------------------------------------------------
|
| 454 |
+
|
| 455 |
+
from medusa_env.tasks import TASKS, score_episode
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
class TestMedusaTasks:
|
| 459 |
+
"""Tests for the 3 formal task definitions and 0.0β1.0 scorer."""
|
| 460 |
+
|
| 461 |
+
def test_three_tasks_defined(self):
|
| 462 |
+
assert "clean_pipeline" in TASKS
|
| 463 |
+
assert "dirty_integration" in TASKS
|
| 464 |
+
assert "full_medallion" in TASKS
|
| 465 |
+
|
| 466 |
+
def test_task_difficulties(self):
|
| 467 |
+
assert TASKS["clean_pipeline"].difficulty == "easy"
|
| 468 |
+
assert TASKS["dirty_integration"].difficulty == "medium"
|
| 469 |
+
assert TASKS["full_medallion"].difficulty == "hard"
|
| 470 |
+
|
| 471 |
+
def test_task_seeds_match_scenarios(self):
|
| 472 |
+
assert TASKS["clean_pipeline"].seed == 0
|
| 473 |
+
assert TASKS["dirty_integration"].seed == 1
|
| 474 |
+
assert TASKS["full_medallion"].seed == 2
|
| 475 |
+
|
| 476 |
+
def _run_happy_path(self, seed: int) -> MedusaState:
|
| 477 |
+
"""Run the optimal action sequence for the given seed and return final state."""
|
| 478 |
+
env = MedusaEnv(n_fact_rows=50, n_dim_rows=40)
|
| 479 |
+
env.reset(seed=seed)
|
| 480 |
+
for act in [
|
| 481 |
+
MedusaActionType.SYNC_CHECK,
|
| 482 |
+
MedusaActionType.EVOLVE_SCHEMA,
|
| 483 |
+
MedusaActionType.PREP_KEYS_A,
|
| 484 |
+
MedusaActionType.PREP_KEYS_B,
|
| 485 |
+
MedusaActionType.DEDUPLICATE_B,
|
| 486 |
+
MedusaActionType.EXECUTE_JOIN_LEFT,
|
| 487 |
+
MedusaActionType.APPLY_SCD_2,
|
| 488 |
+
MedusaActionType.COMMIT,
|
| 489 |
+
]:
|
| 490 |
+
env.step(MedusaAction(action=act))
|
| 491 |
+
return env.state
|
| 492 |
+
|
| 493 |
+
# ββ clean_pipeline (easy) βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 494 |
+
|
| 495 |
+
def test_clean_pipeline_score_is_in_range(self):
|
| 496 |
+
state = self._run_happy_path(seed=0)
|
| 497 |
+
result = score_episode("clean_pipeline", state)
|
| 498 |
+
assert 0.0 <= result.score <= 1.0
|
| 499 |
+
|
| 500 |
+
def test_clean_pipeline_happy_path_passes(self):
|
| 501 |
+
state = self._run_happy_path(seed=0)
|
| 502 |
+
result = score_episode("clean_pipeline", state)
|
| 503 |
+
assert result.passed is True
|
| 504 |
+
assert result.grade in ("S", "A", "B")
|
| 505 |
+
|
| 506 |
+
def test_clean_pipeline_uncommitted_scores_zero(self):
|
| 507 |
+
state = MedusaState(stage="running")
|
| 508 |
+
result = score_episode("clean_pipeline", state)
|
| 509 |
+
assert result.score == 0.0
|
| 510 |
+
assert result.grade == "F"
|
| 511 |
+
|
| 512 |
+
def test_clean_pipeline_explosion_detected_lowers_score(self):
|
| 513 |
+
state = MedusaState(
|
| 514 |
+
stage="committed",
|
| 515 |
+
explosion_detected=True,
|
| 516 |
+
silver_row_count=0,
|
| 517 |
+
source_a_row_count=50,
|
| 518 |
+
match_rate=0.0,
|
| 519 |
+
grader_passed=False,
|
| 520 |
+
)
|
| 521 |
+
result = score_episode("clean_pipeline", state)
|
| 522 |
+
assert result.breakdown["no_explosion"] == 0.0
|
| 523 |
+
|
| 524 |
+
# ββ dirty_integration (medium) βββββββββοΏ½οΏ½οΏ½ββββββββββββββββββββββββββββββββ
|
| 525 |
+
|
| 526 |
+
def test_dirty_integration_score_is_in_range(self):
|
| 527 |
+
state = self._run_happy_path(seed=1)
|
| 528 |
+
result = score_episode("dirty_integration", state)
|
| 529 |
+
assert 0.0 <= result.score <= 1.0
|
| 530 |
+
|
| 531 |
+
def test_dirty_integration_without_prep_penalized(self):
|
| 532 |
+
state = MedusaState(
|
| 533 |
+
stage="committed",
|
| 534 |
+
did_prep_a=False,
|
| 535 |
+
did_prep_b=False,
|
| 536 |
+
did_dedup_b=False,
|
| 537 |
+
did_join=True,
|
| 538 |
+
explosion_detected=False,
|
| 539 |
+
grader_passed=False,
|
| 540 |
+
)
|
| 541 |
+
result = score_episode("dirty_integration", state)
|
| 542 |
+
assert result.breakdown["prepped_before_join"] == 0.0
|
| 543 |
+
assert result.breakdown["deduped_before_join"] == 0.0
|
| 544 |
+
|
| 545 |
+
def test_dirty_integration_with_all_prereqs_scores_higher(self):
|
| 546 |
+
state_no_prep = MedusaState(
|
| 547 |
+
stage="committed", did_prep_a=False, did_prep_b=False,
|
| 548 |
+
did_dedup_b=False, did_join=True, explosion_detected=False, grader_passed=False,
|
| 549 |
+
)
|
| 550 |
+
state_prepped = MedusaState(
|
| 551 |
+
stage="committed", did_prep_a=True, did_prep_b=True,
|
| 552 |
+
did_dedup_b=True, did_join=True, explosion_detected=False, grader_passed=True,
|
| 553 |
+
)
|
| 554 |
+
no_prep = score_episode("dirty_integration", state_no_prep)
|
| 555 |
+
prepped = score_episode("dirty_integration", state_prepped)
|
| 556 |
+
assert prepped.score > no_prep.score
|
| 557 |
+
|
| 558 |
+
# ββ full_medallion (hard) βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 559 |
+
|
| 560 |
+
def test_full_medallion_score_is_in_range(self):
|
| 561 |
+
state = self._run_happy_path(seed=2)
|
| 562 |
+
result = score_episode("full_medallion", state)
|
| 563 |
+
assert 0.0 <= result.score <= 1.0
|
| 564 |
+
|
| 565 |
+
def test_full_medallion_without_sync_penalized(self):
|
| 566 |
+
state = MedusaState(
|
| 567 |
+
stage="committed",
|
| 568 |
+
did_sync_check=False,
|
| 569 |
+
did_evolve_schema=True,
|
| 570 |
+
scd_type="SCD-2",
|
| 571 |
+
grader_passed=True,
|
| 572 |
+
)
|
| 573 |
+
result = score_episode("full_medallion", state)
|
| 574 |
+
assert result.breakdown["sync_checked"] == 0.0
|
| 575 |
+
|
| 576 |
+
def test_full_medallion_scd1_penalized(self):
|
| 577 |
+
state_scd1 = MedusaState(
|
| 578 |
+
stage="committed", did_sync_check=True,
|
| 579 |
+
did_evolve_schema=True, scd_type="SCD-1", grader_passed=False,
|
| 580 |
+
)
|
| 581 |
+
state_scd2 = MedusaState(
|
| 582 |
+
stage="committed", did_sync_check=True,
|
| 583 |
+
did_evolve_schema=True, scd_type="SCD-2", grader_passed=True,
|
| 584 |
+
)
|
| 585 |
+
r1 = score_episode("full_medallion", state_scd1)
|
| 586 |
+
r2 = score_episode("full_medallion", state_scd2)
|
| 587 |
+
assert r2.score > r1.score
|
| 588 |
+
|
| 589 |
+
def test_unknown_task_raises(self):
|
| 590 |
+
with pytest.raises(ValueError, match="Unknown task_id"):
|
| 591 |
+
score_episode("nonexistent_task", MedusaState(stage="committed"))
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|