rampluto committed on
Commit
fd09b74
·
verified ·
1 Parent(s): ea782b7

Upload folder using huggingface_hub

Dockerfile ADDED
@@ -0,0 +1,43 @@
+ # HF Space root-level Dockerfile — targets port 7860 (HF default).
+ # This file lives at envs/medusa_env/Dockerfile and is the file
+ # HF Spaces uses when deploying a Docker Space from this directory.
+
+ FROM python:3.12-slim
+
+ WORKDIR /app
+
+ # Install uv for fast dependency resolution
+ RUN pip install uv --no-cache-dir
+
+ # Copy environment code
+ COPY . /app/env
+
+ WORKDIR /app/env
+
+ # Install all dependencies including openenv-core + pandas + numpy
+ RUN uv pip install --system --no-cache \
+     "openenv-core[core]>=0.2.2" \
+     fastapi \
+     "uvicorn[standard]" \
+     pydantic \
+     pandas \
+     numpy \
+     websockets
+
+ # Install the medusa package itself (so medusa_env.* imports resolve)
+ RUN uv pip install --system --no-cache -e .
+
+ # HF Spaces requires port 7860
+ ENV PORT=7860
+ EXPOSE 7860
+
+ # PYTHONPATH so imports resolve correctly when running from /app/env
+ ENV PYTHONPATH="/app/env:$PYTHONPATH"
+ ENV ENABLE_WEB_INTERFACE=true
+
+ # Health check on HF port
+ HEALTHCHECK --interval=30s --timeout=5s --start-period=15s --retries=3 \
+     CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:7860/health')" || exit 1
+
+ # Run on port 7860 — HF Space requirement
+ CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2026 Ram Janam Yadav
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,12 +1,219 @@
---
- title: Medusa Env
- emoji: 🏢
- colorFrom: pink
colorTo: blue
sdk: docker
pinned: false
- license: mit
- short_description: 'Reinforcement Learning developed using OpenEnv '
---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
---
+ title: MEDUSA Environment
+ emoji: 🦑
+ colorFrom: purple
colorTo: blue
sdk: docker
pinned: false
+ tags:
+ - openenv
+ - reinforcement-learning
+ - data-engineering
+ app_port: 7860
+ base_path: /web
---

+ # MEDUSA
+
+ **Medallion-Engineered Deterministic Unified Storage Agent**
+
+ An OpenEnv reinforcement learning environment that trains agents to act as *Relational Controllers* — orchestrating multi-source Bronze→Silver data integration pipelines inside a Medallion Architecture.
+
+ ---
+
+ ## Problem
+
+ Modern data platforms fail not because they can't clean a single table, but because they can't reliably integrate **multiple shifting sources**. The Bronze→Silver transition is a minefield of:
+
+ - **Stale data** — processing yesterday's snapshot wastes compute and produces wrong results
+ - **Schema drift** — new columns appear in sources that Silver doesn't know about yet
+ - **Dirty join keys** — NULLs and whitespace cause 0-row joins and silent data loss
+ - **Cartesian explosions** — joining on non-unique Dimension keys multiplies rows catastrophically
+ - **Orphaned records** — unmatched Fact rows must be quarantined, not silently dropped
+
+ MEDUSA trains an agent to detect and handle all of these autonomously.
+
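The dirty-key failure mode is easy to reproduce with plain pandas. A minimal sketch with made-up two-row data (not the environment's actual generator):

```python
import pandas as pd

# Fact keys carry stray whitespace; the dimension is clean.
fact = pd.DataFrame({"cust_id": [" C1 ", "C2"], "amount": [10.0, 20.0]})
dim = pd.DataFrame({"cust_id": ["C1", "C2"], "region": ["EU", "US"]})

# Joining on the raw keys silently drops the whitespace row.
dirty = fact.merge(dim, on="cust_id", how="inner")
print(len(dirty))  # 1 — a row was lost without any error

# Stripping keys first (conceptually what PREP_KEYS does) recovers it.
fact["cust_id"] = fact["cust_id"].str.strip()
clean = fact.merge(dim, on="cust_id", how="inner")
print(len(clean))  # 2
```

The loss is silent in both directions: an inner join shrinks, a left join emits NaN-filled rows, and neither raises.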
36
+ ---
37
+
38
+ ## Environment Overview
39
+
40
+ ```
41
+ Bronze A (Fact) ──┐
42
+ β”œβ”€β”€β–Ί [Agent] ──► Silver + /quarantine
43
+ Bronze B (Dim) β”€β”€β”˜
44
+ ```
45
+
46
+ The agent observes data quality signals and selects ETL actions step-by-step. At the end it issues `COMMIT`, triggering a deterministic grader audit.
47
+
48
+ ---
49
+
50
+ ## The MDP
51
+
52
+ ### Observation Space
53
+
54
+ A **16-element normalised float vector** `[0, 1]`:
55
+
56
+ | Index | Feature | Description |
57
+ |-------|---------|-------------|
58
+ | 0–1 | `time_delta_a/b_norm` | Source freshness (hours / 48h ceiling) |
59
+ | 2–3 | `is_stale_a/b` | Binary staleness flag |
60
+ | 4–5 | `null_ratio_key_a/b` | Fraction of null join keys |
61
+ | 6–7 | `uniqueness_a/b` | Key uniqueness ratio (1.0 = fully unique) |
62
+ | 8 | `match_rate` | % of Fact keys found in Dimension |
63
+ | 9–10 | `new_cols_a/b_norm` | Schema drift columns pending |
64
+ | 11 | `schema_compat` | Key type compatibility score |
65
+ | 12–14 | `did_prep_a/b`, `did_dedup_b` | Prerequisite action flags |
66
+ | 15 | `step_frac` | Episode progress (step / max_steps) |
67
+
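When logging or debugging, it helps to label the raw vector. A convenience sketch (the names follow the table above; the environment itself returns only the bare list):

```python
# Names for the 16 observation features, in index order.
FEATURE_NAMES = [
    "time_delta_a_norm", "time_delta_b_norm",
    "is_stale_a", "is_stale_b",
    "null_ratio_key_a", "null_ratio_key_b",
    "uniqueness_a", "uniqueness_b",
    "match_rate",
    "new_cols_a_norm", "new_cols_b_norm",
    "schema_compat",
    "did_prep_a", "did_prep_b", "did_dedup_b",
    "step_frac",
]

def decode(features: list[float]) -> dict[str, float]:
    """Label a raw 16-float observation vector for logging/debugging."""
    assert len(features) == len(FEATURE_NAMES)
    return dict(zip(FEATURE_NAMES, features))
```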
+ ### Action Space
+
+ 11 discrete actions:
+
+ | Action | Description |
+ |--------|-------------|
+ | `SYNC_CHECK` | Verify freshness of both sources |
+ | `EVOLVE_SCHEMA` | Add new columns from A/B into Silver schema |
+ | `PREP_KEYS_A` | Cast, strip, null-fill join key in Source A |
+ | `PREP_KEYS_B` | Cast, strip, null-fill join key in Source B |
+ | `DEDUPLICATE_B` | Ensure Dimension (B) is unique on the join key |
+ | `EXECUTE_JOIN_INNER` | Inner join A ⋈ B |
+ | `EXECUTE_JOIN_LEFT` | Left join A ⋈ B (orphans → quarantine) |
+ | `EXECUTE_JOIN_ANTI` | Anti-join: extract rows in A with no match in B |
+ | `APPLY_SCD_1` | Overwrite Silver records (SCD Type 1) |
+ | `APPLY_SCD_2` | Close old records, insert new with timestamps (SCD Type 2) |
+ | `COMMIT` | Finalise pipeline; triggers grader audit |
+
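The SCD Type 2 close-and-insert pattern behind `APPLY_SCD_2` can be sketched in plain pandas. A hypothetical example with made-up column names, not the environment's actual operator:

```python
import pandas as pd

silver = pd.DataFrame({
    "cust_id": ["C1"],
    "region": ["EU"],
    "valid_from": [pd.Timestamp("2024-01-01")],
    "valid_to": [pd.NaT],
    "is_current": [True],
})
now = pd.Timestamp("2024-06-01")

# SCD Type 2: close the current version, then append the new one.
silver.loc[silver["is_current"], "valid_to"] = now
silver["is_current"] = False
new_version = pd.DataFrame({
    "cust_id": ["C1"], "region": ["US"],
    "valid_from": [now], "valid_to": [pd.NaT], "is_current": [True],
})
silver = pd.concat([silver, new_version], ignore_index=True)

# Two versions, exactly one current, histories do not overlap.
assert len(silver) == 2 and int(silver["is_current"].sum()) == 1
```

SCD Type 1 (`APPLY_SCD_1`) would instead overwrite `region` in place, keeping one row and losing the history.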
+ ### Reward Model
+
+ | Event | Reward | Trigger |
+ |-------|--------|---------|
+ | High-Match Join | **+25.0** | `match_rate > 90%` after join |
+ | Quarantine Precision | **+10.0** | Orphaned rows correctly isolated |
+ | Correct SCD-2 | **+5.0** | SCD-2 applied on a tracked column |
+ | Grader All-Pass Bonus | **+15.0** | All 4 post-commit checks pass |
+ | Row Explosion | **−100.0** | Join output > 105% of Fact row count |
+ | Join on Dirty Keys | **−30.0** | Join without PREP_KEYS → 0-row result |
+ | Stale Processing | **−15.0** | Action taken while a source is stale and SYNC_CHECK was never called |
+ | Step Penalty | **−0.2** | Applied every step (efficiency incentive) |
+
+ ---
+
+ ## Post-Commit Grader
+
+ After `COMMIT` the deterministic grader runs 4 checks:
+
+ | Check | Pass Condition |
+ |-------|---------------|
+ | **Volume** | `Silver rows ≤ Source A rows` (for left joins) |
+ | **Integrity** | Quarantine holds only true orphans (not keys that could have joined if cleaned) |
+ | **Schema** | Silver contains the union of all required columns from A and B |
+ | **History** | SCD-2 `valid_from`/`valid_to` timestamps are non-overlapping |
+
+ All 4 pass → **+15.0** bonus. Each failure costs **−5.0**.
+
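The partial-credit arithmetic between those extremes (matching the constants defined in `grader.py` later in this commit: +3 per passing check, −5 per failing check) can be sketched as:

```python
def grader_bonus(passed: int, total: int = 4) -> float:
    """Terminal bonus for the post-commit audit: +15.0 when every check
    passes, -20.0 when every check fails, otherwise +3.0 per passing
    check and -5.0 per failing check."""
    failed = total - passed
    if failed == 0:
        return 15.0
    if passed == 0:
        return -20.0
    return passed * 3.0 - failed * 5.0

print(grader_bonus(3))  # 4.0  (3*3 - 1*5)
```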
+ ---
+
+ ## Episode Scenarios
+
+ Four canonical scenarios (selectable by seed):
+
+ | Seed | Scenario | Challenge |
+ |------|----------|-----------|
+ | 0 | `clean` | Fresh, unique keys, ~100% match rate. Baseline. |
+ | 1 | `dirty_keys` | NULLs + whitespace in join keys. Must PREP first. |
+ | 2 | `stale` | Source A is 8–24h old. Must SYNC_CHECK first. |
+ | 3 | `schema_drift` | New columns in A and B not yet in Silver. Must EVOLVE first. |
+
+ Random seeds produce blended variants.
+
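When scripting evaluations across scenarios, the canonical seeds can be kept as a small lookup (a convenience sketch; the names come from the table above):

```python
# Seed → scenario for the four canonical MEDUSA episodes.
CANONICAL_SCENARIOS = {0: "clean", 1: "dirty_keys", 2: "stale", 3: "schema_drift"}

def scenario_name(seed: int) -> str:
    """Named scenario for canonical seeds; any other seed blends variants."""
    return CANONICAL_SCENARIOS.get(seed, "blended")
```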
+ ---
+
+ ## Setup
+
+ ```bash
+ # Clone / navigate to the repo
+ cd /path/to/OpenEnv
+
+ # Create a venv and install all deps (including pandas, numpy)
+ uv sync
+
+ # Activate
+ source .venv/bin/activate
+ ```
+
+ ---
+
+ ## Running
+
+ ### Start the FastAPI server
+
+ ```bash
+ uvicorn envs.medusa_env.server.app:app --reload --host 0.0.0.0 --port 8000
+ ```
+
+ API docs are available at `http://localhost:8000/docs`.
+
+ ### Run tests
+
+ ```bash
+ python -m pytest tests/envs/test_medusa_environment.py -v
+ # 39 passed in ~4s
+ ```
+
+ ### Run a manual episode (Python)
+
+ ```python
+ from envs.medusa_env import MedusaEnv, MedusaAction
+ from envs.medusa_env.models import MedusaActionType
+
+ env = MedusaEnv(n_fact_rows=200, n_dim_rows=150)
+ obs = env.reset(seed=0)  # seed 0 = clean scenario
+ print(obs.message)
+
+ for action_type in [
+     MedusaActionType.SYNC_CHECK,
+     MedusaActionType.EVOLVE_SCHEMA,
+     MedusaActionType.PREP_KEYS_A,
+     MedusaActionType.PREP_KEYS_B,
+     MedusaActionType.DEDUPLICATE_B,
+     MedusaActionType.EXECUTE_JOIN_LEFT,
+     MedusaActionType.APPLY_SCD_2,
+     MedusaActionType.COMMIT,
+ ]:
+     obs = env.step(MedusaAction(action=action_type))
+     print(f"{action_type.value:25s} reward={obs.reward:+.1f} done={obs.done}")
+
+ print(f"\nGrader: {env.state.grader_report}")
+ ```
+
+ ---
+
+ ## Architecture
+
+ ```
+ envs/medusa_env/
+ ├── __init__.py       # Package exports
+ ├── medusa_env.py     # MedusaEnv — reset / step / commit loop
+ ├── models.py         # MedusaAction, MedusaObservation, MedusaState (Pydantic)
+ ├── scenarios.py      # ScenarioGenerator — procedural Bronze A/B DataFrames
+ ├── operators.py      # Stateless ETL functions (sync_check, prep_keys, execute_join, apply_scd …)
+ ├── rewards.py        # RewardEngine — per-step reward computation
+ ├── grader.py         # Grader — post-commit deterministic audit
+ ├── openenv.yaml      # OpenEnv environment manifest
+ └── server/
+     └── app.py        # FastAPI app via create_app()
+
+ tests/envs/
+ └── test_medusa_environment.py   # 39 tests across 6 test classes
+ ```
+
+ **Stack:** Python 3.10+ · Pandas · Pydantic v2 · FastAPI · OpenEnv
+
+ ---
+
+ ## Technical Notes
+
+ - **No external data required.** All Bronze tables are generated procedurally per episode.
+ - **No Spark or Delta Lake required.** All logic uses Pandas — identical semantics, zero cluster setup.
+ - The grader is fully deterministic: the same Silver + quarantine tables always produce the same audit result.
+ - The governance log (accessible at `env._tables.governance_log`) records every agent decision with its reward and operator metrics.
__init__.py ADDED
@@ -0,0 +1,34 @@
+ """MEDUSA (Medallion-Engineered Deterministic Unified Storage Agent) environment.
+
+ Full Bronze→Silver integration controller with:
+ - Multi-source join orchestration (inner / left / anti)
+ - Schema drift handling (EVOLVE_SCHEMA)
+ - Key preparation and deduplication
+ - SCD-1 and SCD-2 merge logic
+ - Per-step RL reward engine
+ - Deterministic post-commit grader
+ """
+
+ from .client import medusa_env
+ from .grader import Grader, GraderResult
+ from .models import MedusaAction, MedusaActionType, MedusaObservation, MedusaState
+ from .rewards import RewardEngine
+ from .scenarios import Scenario, ScenarioGenerator
+ from .tasks import TASKS, Task, TaskResult, score_episode
+
+ __all__ = [
+     "medusa_env",
+     "MedusaAction",
+     "MedusaActionType",
+     "MedusaObservation",
+     "MedusaState",
+     "Scenario",
+     "ScenarioGenerator",
+     "RewardEngine",
+     "Grader",
+     "GraderResult",
+     "TASKS",
+     "Task",
+     "TaskResult",
+     "score_episode",
+ ]
client.py ADDED
@@ -0,0 +1,127 @@
+ """MEDUSA Environment Client.
+
+ Connects to a running MEDUSA server via WebSocket for persistent sessions.
+
+ Example:
+     >>> # Connect to a running server
+     >>> with medusa_env(base_url="http://localhost:8000") as client:
+     ...     result = client.reset(seed=0)
+     ...     print(result.observation.message)
+     ...
+     ...     from envs.medusa_env.models import MedusaActionType
+     ...     result = client.step(MedusaAction(action=MedusaActionType.SYNC_CHECK))
+     ...     print(f"Reward: {result.reward}")
+
+ Example with Docker:
+     >>> client = medusa_env.from_docker_image("medusa_env:latest")
+     >>> try:
+     ...     result = client.reset()
+     ...     result = client.step(MedusaAction(action=MedusaActionType.COMMIT))
+     ... finally:
+     ...     client.close()
+ """
+
+ from typing import Any, Dict
+
+ # Support both in-repo and standalone imports
+ try:
+     from openenv.core.client_types import StepResult
+     from openenv.core.env_client import EnvClient
+
+     from .models import MedusaAction, MedusaObservation, MedusaState
+ except ImportError:
+     from models import MedusaAction, MedusaObservation, MedusaState
+
+     from openenv.core.client_types import StepResult
+     from openenv.core.env_client import EnvClient
+
+
+ class medusa_env(EnvClient[MedusaAction, MedusaObservation, MedusaState]):
+     """Client for the MEDUSA Bronze→Silver integration environment.
+
+     Maintains a persistent WebSocket connection to the MEDUSA server.
+     Each client instance has its own dedicated environment session.
+
+     The agent observes a 16-float data quality feature vector and chooses
+     from 11 discrete ETL actions to build a correct Silver entity from
+     two Bronze sources (Fact + Dimension).
+
+     Example:
+         >>> with medusa_env(base_url="http://localhost:8000") as env:
+         ...     result = env.reset(seed=0)  # clean scenario
+         ...     result = env.step(MedusaAction(action=MedusaActionType.SYNC_CHECK))
+         ...     result = env.step(MedusaAction(action=MedusaActionType.PREP_KEYS_A))
+         ...     result = env.step(MedusaAction(action=MedusaActionType.PREP_KEYS_B))
+         ...     result = env.step(MedusaAction(action=MedusaActionType.DEDUPLICATE_B))
+         ...     result = env.step(MedusaAction(action=MedusaActionType.EXECUTE_JOIN_LEFT))
+         ...     result = env.step(MedusaAction(action=MedusaActionType.APPLY_SCD_2))
+         ...     result = env.step(MedusaAction(action=MedusaActionType.COMMIT))
+         ...     print(result.reward)
+     """
+
+     def _step_payload(self, action: MedusaAction) -> Dict[str, Any]:
+         """Convert a MedusaAction to the JSON payload for the step request."""
+         return {
+             "action": action.action.value,
+             "params": action.params,
+         }
+
+     def _parse_result(self, payload: Dict[str, Any]) -> StepResult[MedusaObservation]:
+         """Parse the server response into a StepResult[MedusaObservation]."""
+         obs_data = payload.get("observation", {})
+         observation = MedusaObservation(
+             message=obs_data.get("message", ""),
+             features=obs_data.get("features", []),
+             metrics=obs_data.get("metrics", {}),
+             metadata=obs_data.get("metadata", {}),
+             reward=payload.get("reward"),
+             done=payload.get("done", False),
+         )
+         return StepResult(
+             observation=observation,
+             reward=payload.get("reward"),
+             done=payload.get("done", False),
+         )
+
+     def _parse_state(self, payload: Dict[str, Any]) -> MedusaState:
+         """Parse the server response into a MedusaState."""
+         return MedusaState(
+             run_id=payload.get("run_id"),
+             seed=payload.get("seed"),
+             scenario_id=payload.get("scenario_id"),
+             step_idx=payload.get("step_idx", 0),
+             stage=payload.get("stage", "init"),
+             # Freshness
+             time_delta_a=payload.get("time_delta_a", 0.0),
+             time_delta_b=payload.get("time_delta_b", 0.0),
+             is_stale_a=payload.get("is_stale_a", False),
+             is_stale_b=payload.get("is_stale_b", False),
+             did_sync_check=payload.get("did_sync_check", False),
+             # Key health
+             null_ratio_key_a=payload.get("null_ratio_key_a", 0.0),
+             null_ratio_key_b=payload.get("null_ratio_key_b", 0.0),
+             uniqueness_a=payload.get("uniqueness_a", 1.0),
+             uniqueness_b=payload.get("uniqueness_b", 1.0),
+             did_prep_a=payload.get("did_prep_a", False),
+             did_prep_b=payload.get("did_prep_b", False),
+             did_dedup_b=payload.get("did_dedup_b", False),
+             # Join
+             match_rate=payload.get("match_rate", 0.0),
+             did_join=payload.get("did_join", False),
+             join_type=payload.get("join_type"),
+             join_row_count=payload.get("join_row_count", 0),
+             explosion_detected=payload.get("explosion_detected", False),
+             # SCD
+             did_scd=payload.get("did_scd", False),
+             scd_type=payload.get("scd_type"),
+             scd_inserts=payload.get("scd_inserts", 0),
+             scd_updates=payload.get("scd_updates", 0),
+             # Silver / Quarantine
+             silver_row_count=payload.get("silver_row_count", 0),
+             quarantine_row_count=payload.get("quarantine_row_count", 0),
+             source_a_row_count=payload.get("source_a_row_count", 0),
+             # Grader
+             grader_passed=payload.get("grader_passed", False),
+             grader_report=payload.get("grader_report", ""),
+             cumulative_reward=payload.get("cumulative_reward", 0.0),
+         )
grader.py ADDED
@@ -0,0 +1,179 @@
+ """MEDUSA deterministic post-commit grader.
+
+ Runs a four-check audit after the agent issues COMMIT and returns a
+ ``GraderResult`` that feeds a bonus/penalty into the terminal reward.
+ """
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass, field
+ from typing import TYPE_CHECKING, List
+
+ import pandas as pd
+
+ if TYPE_CHECKING:
+     from .scenarios import Scenario
+
+
+ # ---------------------------------------------------------------------------
+ # GraderResult
+ # ---------------------------------------------------------------------------
+
+ @dataclass
+ class GraderResult:
+     """Outcome of the post-commit audit."""
+
+     passed: bool = False
+     volume_ok: bool = False     # Silver rows ≤ Source A rows (no duplicates from join)
+     integrity_ok: bool = False  # Quarantine holds only true orphans
+     schema_ok: bool = False     # Silver has union of required columns
+     history_ok: bool = False    # SCD-2 timestamps non-overlapping
+     failures: List[str] = field(default_factory=list)
+     bonus_reward: float = 0.0
+     report: str = ""
+
+
+ # Reward tuning
+ _BONUS_ALL_PASS = +15.0
+ _PENALTY_ALL_FAIL = -20.0
+ _BONUS_PER_CHECK = +3.0
+ _PENALTY_PER_FAIL = -5.0
+
+
+ # ---------------------------------------------------------------------------
+ # Grader
+ # ---------------------------------------------------------------------------
+
+ class Grader:
+     """Post-commit deterministic audit following MEDUSA spec §4."""
+
+     def audit(
+         self,
+         silver: pd.DataFrame,
+         quarantine: pd.DataFrame,
+         bronze_a: pd.DataFrame,
+         bronze_b: pd.DataFrame,
+         join_key: str,
+         join_type: str,
+         scd_type: int,
+         scenario: "Scenario",
+     ) -> GraderResult:
+         """Run all four grader checks and compute the bonus reward.
+
+         Args:
+             silver: The final Silver DataFrame after the SCD merge.
+             quarantine: Rows from A that did not match B.
+             bronze_a: Original fact source (pre-cleaning).
+             bronze_b: Original dimension source (pre-cleaning).
+             join_key: Column used for the join.
+             join_type: "inner" | "left" | "anti"
+             scd_type: 1 or 2
+             scenario: The current episode's scenario (has tracked_cols etc.)
+
+         Returns:
+             GraderResult with individual check statuses and bonus_reward.
+         """
+         result = GraderResult()
+
+         # ── 1. Volume Check ──────────────────────────────────────────────
+         # For left joins, Silver should not exceed the Source A row count.
+         if join_type == "left":
+             source_a_rows = len(bronze_a.dropna(subset=[join_key]))
+             if "is_current" in silver.columns:
+                 # Only count current records when SCD-2 history is present.
+                 silver_rows = int((silver["is_current"] == True).sum())  # noqa: E712
+             else:
+                 silver_rows = len(silver)
+             result.volume_ok = silver_rows <= source_a_rows * 1.05  # 5% tolerance
+             if not result.volume_ok:
+                 result.failures.append(
+                     f"VOLUME_FAIL: Silver {silver_rows} rows > Source A {source_a_rows} rows"
+                 )
+         else:
+             result.volume_ok = True  # Not applicable for inner/anti joins
+
+         # ── 2. Integrity Check ───────────────────────────────────────────
+         # Quarantine rows should be true orphans (no match in B even after cleaning).
+         if not quarantine.empty and join_key in quarantine.columns:
+             dim_keys = set(bronze_b[join_key].dropna().astype(str).str.strip())
+             quarantine_keys = set(quarantine[join_key].dropna().astype(str).str.strip())
+             # Orphan = quarantine key truly not in dim
+             could_join = quarantine_keys & dim_keys
+             if could_join:
+                 result.integrity_ok = False
+                 result.failures.append(
+                     f"INTEGRITY_FAIL: {len(could_join)} quarantine row(s) could have "
+                     f"been joined if keys were cleaned."
+                 )
+             else:
+                 result.integrity_ok = True
+         else:
+             result.integrity_ok = True  # Empty quarantine is fine
+
+         # ── 3. Schema Check ──────────────────────────────────────────────
+         # Silver must contain all required columns from A and B.
+         required_from_a = [c for c in bronze_a.columns if c != join_key]
+         required_from_b = [c for c in bronze_b.columns if c != join_key]
+         required = set(required_from_a + required_from_b + scenario.new_cols_a + scenario.new_cols_b)
+         silver_cols = set(silver.columns)
+         missing = required - silver_cols
+         if missing:
+             result.schema_ok = False
+             result.failures.append(f"SCHEMA_FAIL: Missing columns in Silver: {sorted(missing)}")
+         else:
+             result.schema_ok = True
+
+         # ── 4. History Check (SCD-2 only) ────────────────────────────────
+         if scd_type == 2 and "valid_from" in silver.columns and "valid_to" in silver.columns:
+             overlap_found = False
+             for _key_val, group in silver.groupby(join_key):
+                 if len(group) < 2:
+                     continue
+                 closed = group[group["valid_to"].notna()].sort_values("valid_from")
+                 for i in range(len(closed) - 1):
+                     vt_i = closed.iloc[i]["valid_to"]
+                     vf_next = closed.iloc[i + 1]["valid_from"]
+                     if pd.notna(vt_i) and pd.notna(vf_next) and vt_i > vf_next:
+                         overlap_found = True
+                         break
+                 if overlap_found:
+                     break
+             if overlap_found:
+                 result.history_ok = False
+                 result.failures.append("HISTORY_FAIL: SCD-2 timestamps overlap for some keys.")
+             else:
+                 result.history_ok = True
+         else:
+             result.history_ok = True  # Not applicable for SCD-1
+
+         # ── Compute bonus ────────────────────────────────────────────────
+         checks = [result.volume_ok, result.integrity_ok, result.schema_ok, result.history_ok]
+         passed_count = sum(checks)
+         failed_count = len(checks) - passed_count
+
+         result.passed = all(checks)
+
+         if result.passed:
+             result.bonus_reward = _BONUS_ALL_PASS
+         elif failed_count == len(checks):
+             result.bonus_reward = _PENALTY_ALL_FAIL
+         else:
+             # _PENALTY_PER_FAIL is already negative, so it is added per failed check.
+             result.bonus_reward = passed_count * _BONUS_PER_CHECK + failed_count * _PENALTY_PER_FAIL
+
+         result.report = _build_report(result)
+         return result
+
+
+ # ---------------------------------------------------------------------------
+ # Internal helpers
+ # ---------------------------------------------------------------------------
+
+ def _build_report(result: GraderResult) -> str:
+     lines = ["=== MEDUSA Grader Audit ==="]
+     lines.append(f"  Volume OK:    {'✓' if result.volume_ok else '✗'}")
+     lines.append(f"  Integrity OK: {'✓' if result.integrity_ok else '✗'}")
+     lines.append(f"  Schema OK:    {'✓' if result.schema_ok else '✗'}")
+     lines.append(f"  History OK:   {'✓' if result.history_ok else '✗'}")
+     lines.append(f"  Bonus Reward: {result.bonus_reward:+.1f}")
+     if result.failures:
+         lines.append("  Failures:")
+         for f in result.failures:
+             lines.append(f"    - {f}")
+     lines.append(f"  {'PASS ✓' if result.passed else 'FAIL ✗'}")
+     return "\n".join(lines)
models.py ADDED
@@ -0,0 +1,115 @@
+ from __future__ import annotations
+
+ from enum import Enum
+ from typing import Any, Dict, List, Optional
+
+ from pydantic import Field
+
+ from openenv.core.env_server.types import Action, Observation, State
+
+
+ class MedusaActionType(str, Enum):
+     """Discrete action set for the MEDUSA controller."""
+
+     SYNC_CHECK = "SYNC_CHECK"
+     EVOLVE_SCHEMA = "EVOLVE_SCHEMA"
+     PREP_KEYS_A = "PREP_KEYS_A"
+     PREP_KEYS_B = "PREP_KEYS_B"
+     DEDUPLICATE_B = "DEDUPLICATE_B"
+     EXECUTE_JOIN_INNER = "EXECUTE_JOIN_INNER"
+     EXECUTE_JOIN_LEFT = "EXECUTE_JOIN_LEFT"
+     EXECUTE_JOIN_ANTI = "EXECUTE_JOIN_ANTI"
+     APPLY_SCD_1 = "APPLY_SCD_1"
+     APPLY_SCD_2 = "APPLY_SCD_2"
+     COMMIT = "COMMIT"
+
+
+ class MedusaAction(Action):
+     """One controller action (enum + optional params for future use)."""
+
+     action: MedusaActionType
+     params: Dict[str, Any] = Field(default_factory=dict)
+
+
+ class MedusaState(State):
+     """Full pipeline controller state.
+
+     Tracks every book-keeping flag needed by the reward engine and grader.
+     """
+
+     run_id: Optional[str] = None
+     seed: Optional[int] = None
+     scenario_id: Optional[str] = None
+     max_steps: int = 20
+
+     step_idx: int = 0
+     stage: str = "init"  # init | running | committed | failed
+
+     # --- Freshness ---
+     time_delta_a: float = 0.0  # Hours since Source A last updated
+     time_delta_b: float = 0.0
+     is_stale_a: bool = False
+     is_stale_b: bool = False
+     did_sync_check: bool = False
+
+     # --- Schema ---
+     did_evolve_schema: bool = False
+     new_cols_a: int = 0  # Number of new columns in A not yet in Silver
+     new_cols_b: int = 0
+     schema_compat: float = 1.0  # 0-1 key-type compatibility score
+
+     # --- Key Health ---
+     null_ratio_key_a: float = 0.0
+     null_ratio_key_b: float = 0.0
+     uniqueness_a: float = 1.0  # 1.0 = fully unique
+     uniqueness_b: float = 1.0
+     did_prep_a: bool = False
+     did_prep_b: bool = False
+     did_dedup_b: bool = False
+
+     # --- Referential Integrity ---
+     match_rate: float = 0.0  # % of Key_A values found in Key_B
+
+     # --- Join Result ---
+     did_join: bool = False
+     join_type: Optional[str] = None
+     join_row_count: int = 0
+     explosion_detected: bool = False
+
+     # --- SCD ---
+     did_scd: bool = False
+     scd_type: Optional[str] = None
+     scd_inserts: int = 0
+     scd_updates: int = 0
+
+     # --- Silver / Quarantine ---
+     silver_row_count: int = 0
+     quarantine_row_count: int = 0
+     source_a_row_count: int = 0
+
+     # --- Grader ---
+     grader_passed: bool = False
+     grader_report: str = ""
+
+     # --- Governance ---
+     cumulative_reward: float = 0.0
+
+
+ class MedusaObservation(Observation):
+     """Observation returned to the agent after every step.
+
+     ``features`` is a 16-element normalised float vector suitable as
+     direct RL input::
+
+         [time_delta_a_norm, time_delta_b_norm, is_stale_a, is_stale_b,
+          null_ratio_key_a, null_ratio_key_b, uniqueness_a, uniqueness_b,
+          match_rate, new_cols_a_norm, new_cols_b_norm, schema_compat,
+          did_prep_a, did_prep_b, did_dedup_b, step_frac]
+     """
+
+     message: str = ""
+     features: List[float] = Field(default_factory=list)
+     metrics: Dict[str, Any] = Field(default_factory=dict)
+     metadata: Dict[str, Any] = Field(default_factory=dict)
+     reward: Optional[float] = None
+     done: bool = False
openenv.yaml ADDED
@@ -0,0 +1,6 @@
+ spec_version: 1
+ name: medusa_env
+ type: space
+ runtime: fastapi
+ app: server.app:app
+ port: 8000
openenv_medusa.egg-info/PKG-INFO ADDED
@@ -0,0 +1,16 @@
+ Metadata-Version: 2.4
+ Name: openenv-medusa
+ Version: 0.2.0
+ Summary: MEDUSA: Medallion-Engineered Deterministic Unified Storage Agent — Bronze→Silver RL environment for OpenEnv
+ Requires-Python: >=3.10
+ License-File: LICENSE
+ Requires-Dist: openenv-core[core]>=0.2.2
+ Requires-Dist: fastapi>=0.115.0
+ Requires-Dist: pydantic>=2.0.0
+ Requires-Dist: uvicorn>=0.24.0
+ Requires-Dist: pandas>=2.0.0
+ Requires-Dist: numpy>=1.24.0
+ Provides-Extra: dev
+ Requires-Dist: pytest>=8.0.0; extra == "dev"
+ Requires-Dist: pytest-asyncio>=0.23.0; extra == "dev"
+ Dynamic: license-file
openenv_medusa.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,25 @@
+ LICENSE
+ README.md
+ pyproject.toml
+ ./__init__.py
+ ./client.py
+ ./grader.py
+ ./models.py
+ ./openenv.yaml
+ ./operators.py
+ ./rewards.py
+ ./scenarios.py
+ ./tasks.py
+ ./server/__init__.py
+ ./server/app.py
+ ./server/medusa_env.py
+ openenv_medusa.egg-info/PKG-INFO
+ openenv_medusa.egg-info/SOURCES.txt
+ openenv_medusa.egg-info/dependency_links.txt
+ openenv_medusa.egg-info/entry_points.txt
+ openenv_medusa.egg-info/requires.txt
+ openenv_medusa.egg-info/top_level.txt
+ server/__init__.py
+ server/app.py
+ server/medusa_env.py
+ tests/test_medusa_environment.py
openenv_medusa.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
+
openenv_medusa.egg-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
+ [console_scripts]
+ server = server.app:main
openenv_medusa.egg-info/requires.txt ADDED
@@ -0,0 +1,10 @@
+ openenv-core[core]>=0.2.2
+ fastapi>=0.115.0
+ pydantic>=2.0.0
+ uvicorn>=0.24.0
+ pandas>=2.0.0
+ numpy>=1.24.0
+
+ [dev]
+ pytest>=8.0.0
+ pytest-asyncio>=0.23.0
openenv_medusa.egg-info/top_level.txt ADDED
@@ -0,0 +1,2 @@
+ medusa_env
+ server
operators.py ADDED
@@ -0,0 +1,315 @@
+ """MEDUSA ETL operators.
+
+ Each operator is a stateless function that takes DataFrame(s) and returns a
+ (result_df_or_None, metrics_dict) tuple. The environment calls these from
+ ``step()`` and passes the metrics to the reward engine.
+ """
+
+ from __future__ import annotations
+
+ import datetime
+ from typing import Any, Dict, Optional, Tuple
+
+ import pandas as pd
+
+
+ # ---------------------------------------------------------------------------
+ # Type aliases
+ # ---------------------------------------------------------------------------
+
+ Metrics = Dict[str, Any]
+ OpResult = Tuple[Optional[pd.DataFrame], Metrics]
+
+
+ # ---------------------------------------------------------------------------
+ # Operator: sync_check
+ # ---------------------------------------------------------------------------
+
+ def sync_check(
+     bronze_a: pd.DataFrame,
+     bronze_b: pd.DataFrame,
+     time_delta_a: float,
+     time_delta_b: float,
+     stale_threshold_hours: float = 6.0,
+ ) -> OpResult:
+     """Inspect the freshness of both sources.
+
+     Returns metrics about staleness without modifying any data.
+     """
+     is_stale_a = time_delta_a > stale_threshold_hours
+     is_stale_b = time_delta_b > stale_threshold_hours
+     metrics: Metrics = {
+         "time_delta_a": time_delta_a,
+         "time_delta_b": time_delta_b,
+         "is_stale_a": is_stale_a,
+         "is_stale_b": is_stale_b,
+         "rows_a": len(bronze_a),
+         "rows_b": len(bronze_b),
+     }
+     return None, metrics
+
+
+ # ---------------------------------------------------------------------------
+ # Operator: evolve_schema
+ # ---------------------------------------------------------------------------
+
+ def evolve_schema(
+     silver: pd.DataFrame,
+     bronze_a: pd.DataFrame,
+     bronze_b: pd.DataFrame,
+     new_cols_a: list[str],
+     new_cols_b: list[str],
+ ) -> OpResult:
+     """Add new columns (from schema drift) to the Silver DataFrame.
+
+     Fills missing historical rows with NA.
+     """
+     added: list[str] = []
+     result = silver.copy()
+
+     for col in new_cols_a + new_cols_b:
+         if col not in result.columns:
+             result[col] = pd.NA
+             added.append(col)
+
+     metrics: Metrics = {
+         "cols_added": added,
+         "new_cols_count": len(added),
+         "silver_col_count": len(result.columns),
+     }
+     return result, metrics
+
+
+ # ---------------------------------------------------------------------------
+ # Operator: prep_keys
+ # ---------------------------------------------------------------------------
+
+ def prep_keys(df: pd.DataFrame, key_col: str) -> OpResult:
+     """Cast, strip whitespace, and null-normalise the join key column.
+
+     Returns a cleaned copy of ``df`` with metrics about how many rows were
+     affected.
+     """
+     result = df.copy()
+     original_nulls = int(result[key_col].isna().sum())
+     original_len = len(result)
+
+     # Strip whitespace (treat blank strings as nulls)
+     result[key_col] = result[key_col].astype(str).str.strip()
+     result[key_col] = result[key_col].replace({"None": pd.NA, "nan": pd.NA, "": pd.NA})
+
+     # Cast to string (uniform type for the join)
+     result[key_col] = result[key_col].astype("string")
+
+     after_nulls = int(result[key_col].isna().sum())
+     null_ratio_before = original_nulls / max(original_len, 1)
+     null_ratio_after = after_nulls / max(original_len, 1)
+
+     metrics: Metrics = {
+         "null_ratio_before": null_ratio_before,
+         "null_ratio_after": null_ratio_after,
+         "rows_trimmed": original_len - after_nulls,
+         "null_rows_dropped": 0,  # We do NOT drop nulls; the grader catches orphans
+     }
+     return result, metrics
+
+
+ # ---------------------------------------------------------------------------
+ # Operator: deduplicate
+ # ---------------------------------------------------------------------------
+
+ def deduplicate(df: pd.DataFrame, key_col: str) -> OpResult:
+     """Ensure the Dimension (Source B) is unique on ``key_col``.
+
+     Keeps the last occurrence so the most-recent record wins.
+     """
+     original_len = len(df)
+     result = df.drop_duplicates(subset=[key_col], keep="last").reset_index(drop=True)
+     dupes_removed = original_len - len(result)
+
+     non_null = result[key_col].notna().sum()
+     uniqueness = non_null / max(len(result), 1)
+
+     metrics: Metrics = {
+         "dupes_removed": dupes_removed,
+         "uniqueness": float(uniqueness),
+         "rows_after": len(result),
+     }
+     return result, metrics
+
+
+ # ---------------------------------------------------------------------------
+ # Operator: execute_join
+ # ---------------------------------------------------------------------------
+
+ _EXPLOSION_MULTIPLIER = 1.05  # > 5% extra rows triggers the explosion alert
+
+
+ def execute_join(
+     fact: pd.DataFrame,
+     dim: pd.DataFrame,
+     key_col: str,
+     join_type: str,  # "inner" | "left" | "anti"
+ ) -> Tuple[pd.DataFrame, pd.DataFrame, Metrics]:
+     """Join Fact (A) with Dimension (B).
+
+     Returns ``(joined_df, quarantine_df, metrics)``.
+     ``quarantine_df`` contains rows from A that did not match B (orphans).
+     """
+     # Drop null-keyed rows from both sides before joining
+     fact_clean = fact.dropna(subset=[key_col])
+     dim_clean = dim.dropna(subset=[key_col])
+
+     # Compute the match rate before joining
+     fact_keys = set(fact_clean[key_col].astype(str))
+     dim_keys = set(dim_clean[key_col].astype(str))
+     overlap = fact_keys & dim_keys
+     match_rate = len(overlap) / max(len(fact_keys), 1)
+
+     if join_type == "anti":
+         # Anti-join: rows in A NOT in B go to quarantine
+         mask = ~fact_clean[key_col].astype(str).isin(dim_keys)
+         joined = pd.DataFrame(columns=list(fact_clean.columns) + [
+             c for c in dim_clean.columns if c != key_col
+         ])
+         quarantine = fact_clean[mask].copy()
+     elif join_type == "inner":
+         merged = fact_clean.merge(dim_clean, on=key_col, how="inner",
+                                   suffixes=("_a", "_b"))
+         quarantine = fact_clean[~fact_clean[key_col].astype(str).isin(dim_keys)].copy()
+         joined = merged
+     else:  # left
+         merged = fact_clean.merge(dim_clean, on=key_col, how="left",
+                                   suffixes=("_a", "_b"))
+         # Quarantine = rows where the dim columns are NaN (no match)
+         dim_cols = [c for c in dim_clean.columns if c != key_col]
+         if dim_cols:
+             no_match_mask = merged[dim_cols[0]].isna()
+         else:
+             no_match_mask = pd.Series(False, index=merged.index)
+         quarantine = merged[no_match_mask][[key_col]].copy()
+         joined = merged
+
+     # Explosion detection
+     explosion = len(joined) > len(fact_clean) * _EXPLOSION_MULTIPLIER
+
+     metrics: Metrics = {
+         "join_type": join_type,
+         "fact_rows": len(fact_clean),
+         "dim_rows": len(dim_clean),
+         "join_rows": len(joined),
+         "quarantine_rows": len(quarantine),
+         "match_rate": match_rate,
+         "explosion_detected": explosion,
+     }
+     return joined, quarantine, metrics
+
+
+ # ---------------------------------------------------------------------------
+ # Operator: apply_scd
+ # ---------------------------------------------------------------------------
+
+ def apply_scd(
+     silver: pd.DataFrame,
+     joined: pd.DataFrame,
+     key_col: str,
+     tracked_col: str,
+     scd_type: int,  # 1 or 2
+ ) -> OpResult:
+     """Merge the ``joined`` result into Silver using SCD-1 or SCD-2.
+
+     SCD-1: overwrite existing records.
+     SCD-2: close old records (valid_to = now) and insert new ones with
+     a fresh valid_from and valid_to = None (open record).
+     """
+     now = datetime.datetime.now(datetime.timezone.utc)  # timezone.utc for 3.10 compat
+     inserts = 0
+     updates = 0
+
+     if joined.empty:
+         metrics: Metrics = {
+             "scd_type": scd_type,
+             "inserts": 0,
+             "updates": 0,
+             "silver_rows": len(silver),
+         }
+         return silver, metrics
+
+     if silver.empty:
+         # First load — treat everything as inserts
+         result = joined.copy()
+         if scd_type == 2:
+             result["valid_from"] = now
+             result["valid_to"] = pd.NaT
+             result["is_current"] = True
+         inserts = len(result)
+         metrics = {
+             "scd_type": scd_type,
+             "inserts": inserts,
+             "updates": 0,
+             "silver_rows": len(result),
+         }
+         return result, metrics
+
+     if scd_type == 1:
+         # Upsert: overwrite matching records
+         exists_mask = silver[key_col].isin(joined[key_col])
+         new_keys_mask = ~joined[key_col].isin(silver[key_col])
+
+         result = silver[~exists_mask].copy()
+         result = pd.concat([result, joined], ignore_index=True)
+
+         updates = int(exists_mask.sum())
+         inserts = int(new_keys_mask.sum())
+
+     else:  # SCD-2
+         # Ensure Silver has the timestamp columns
+         if "valid_from" not in silver.columns:
+             silver = silver.copy()
+             silver["valid_from"] = now - datetime.timedelta(days=30)
+             silver["valid_to"] = pd.NaT
+             silver["is_current"] = True
+
+         silver_result = silver.copy()
+         new_rows: list[pd.DataFrame] = []
+
+         for _, new_row in joined.iterrows():
+             key_val = new_row[key_col]
+             current_mask = (silver_result[key_col] == key_val) & (silver_result["is_current"] == True)  # noqa: E712
+             current_rows = silver_result[current_mask]
+
+             if current_rows.empty:
+                 # New record
+                 row_df = pd.DataFrame([new_row])
+                 row_df["valid_from"] = now
+                 row_df["valid_to"] = pd.NaT
+                 row_df["is_current"] = True
+                 new_rows.append(row_df)
+                 inserts += 1
+             else:
+                 # Check whether the tracked column changed
+                 old_val = current_rows.iloc[0].get(tracked_col)
+                 new_val = new_row.get(tracked_col)
+                 if old_val != new_val:
+                     # Close the old record
+                     silver_result.loc[current_mask, "valid_to"] = now
+                     silver_result.loc[current_mask, "is_current"] = False
+                     # Insert the new record
+                     row_df = pd.DataFrame([new_row])
+                     row_df["valid_from"] = now
+                     row_df["valid_to"] = pd.NaT
+                     row_df["is_current"] = True
+                     new_rows.append(row_df)
+                     updates += 1
+
+         if new_rows:
+             silver_result = pd.concat([silver_result] + new_rows, ignore_index=True)
+         result = silver_result
+
+     metrics = {
+         "scd_type": scd_type,
+         "inserts": inserts,
+         "updates": updates,
+         "silver_rows": len(result),
+     }
+     return result, metrics
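The left-join quarantine path in `execute_join` above can be sketched with plain pandas. This is a minimal toy example, not the environment's actual tables; the `entity_id` values and the `dim_name` column are made up for illustration:

```python
import pandas as pd

# Toy Fact (A) and Dimension (B) tables; "K003" has no Dimension match.
fact = pd.DataFrame({"entity_id": ["K001", "K002", "K003"],
                     "fact_value": [10.0, 20.0, 30.0]})
dim = pd.DataFrame({"entity_id": ["K001", "K002"],
                    "dim_name": ["Name_K001", "Name_K002"]})

# A left join keeps every Fact row; unmatched rows get NaN dimension columns.
merged = fact.merge(dim, on="entity_id", how="left")

# Orphans (rows with no Dimension match) are routed to quarantine.
quarantine = merged[merged["dim_name"].isna()][["entity_id"]]

print(len(merged), len(quarantine))  # 3 1
```

The same NaN-probe on the first dimension column is what the operator uses to build its `quarantine_rows` metric.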
pyproject.toml ADDED
@@ -0,0 +1,37 @@
+ [build-system]
+ requires = ["setuptools>=45", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "openenv-medusa"
+ version = "0.2.0"
+ description = "MEDUSA: Medallion-Engineered Deterministic Unified Storage Agent — Bronze→Silver RL environment for OpenEnv"
+ requires-python = ">=3.10"
+ dependencies = [
+     # Core OpenEnv dependencies
+     "openenv-core[core]>=0.2.2",
+     "fastapi>=0.115.0",
+     "pydantic>=2.0.0",
+     "uvicorn>=0.24.0",
+     # Data pipeline dependencies
+     "pandas>=2.0.0",
+     "numpy>=1.24.0",
+ ]
+
+ [project.optional-dependencies]
+ dev = [
+     "pytest>=8.0.0",
+     "pytest-asyncio>=0.23.0",
+ ]
+
+ [project.scripts]
+ # Enables: uv run server (from the medusa_env directory)
+ server = "server.app:main"
+
+ [tool.setuptools]
+ include-package-data = true
+ packages = ["medusa_env", "medusa_env.server", "server"]
+ package-dir = { "medusa_env" = ".", "server" = "server" }
+
+ [tool.setuptools.package-data]
+ medusa_env = ["**/*.yaml", "**/*.yml"]
rewards.py ADDED
@@ -0,0 +1,107 @@
+ """MEDUSA reward engine.
+
+ Reward model as defined in the MEDUSA blueprint. All reward logic lives in a
+ single ``RewardEngine`` class so it can be unit-tested in isolation from the
+ environment.
+ """
+
+ from __future__ import annotations
+
+ from typing import Any, Dict
+
+
+ # ---------------------------------------------------------------------------
+ # Reward table (blueprint §3)
+ # ---------------------------------------------------------------------------
+
+ REWARD_TABLE: Dict[str, float] = {
+     "high_match_join": +25.0,       # match_rate > 0.90
+     "correct_scd2": +5.0,           # SCD-2 used on a tracked column
+     "quarantine_precision": +10.0,  # Orphaned rows correctly moved to quarantine
+     "row_explosion": -100.0,        # Cartesian product detected
+     "dirty_join": -30.0,            # Join attempted without PREP_KEYS → 0-row result
+     "stale_processing": -15.0,      # Action taken while a source is stale (not synced first)
+     "step_penalty": -0.2,           # Per-step efficiency penalty
+ }
+
+ HIGH_MATCH_THRESHOLD = 0.90
+
+
+ # ---------------------------------------------------------------------------
+ # RewardEngine
+ # ---------------------------------------------------------------------------
+
+ class RewardEngine:
+     """Compute the per-step reward from the action context and operator metrics."""
+
+     def evaluate(
+         self,
+         action_type: str,
+         metrics: Dict[str, Any],
+         state_before: Any,  # MedusaState snapshot before the step
+     ) -> float:
+         """Return the scalar reward for a single step.
+
+         Args:
+             action_type: The ``MedusaActionType`` value string (e.g. "SYNC_CHECK").
+             metrics: Dictionary returned by the corresponding operator.
+             state_before: State object *before* this step was applied.
+
+         Returns:
+             Scalar float reward.
+         """
+         reward = REWARD_TABLE["step_penalty"]  # always applied
+
+         if action_type == "SYNC_CHECK":
+             # No positive/negative signal from sync_check itself
+             pass
+
+         elif action_type in ("PREP_KEYS_A", "PREP_KEYS_B"):
+             # Neutral — prep is just a prerequisite
+             pass
+
+         elif action_type == "DEDUPLICATE_B":
+             pass
+
+         elif action_type == "EVOLVE_SCHEMA":
+             pass
+
+         elif action_type in ("EXECUTE_JOIN_INNER", "EXECUTE_JOIN_LEFT", "EXECUTE_JOIN_ANTI"):
+             explosion = metrics.get("explosion_detected", False)
+             if explosion:
+                 reward += REWARD_TABLE["row_explosion"]
+             else:
+                 join_rows = metrics.get("join_rows", 0)
+                 fact_rows = metrics.get("fact_rows", 1)
+                 # "Dirty join" = a join executed without PREP_KEYS that produced
+                 # 0 rows even though the source was non-empty
+                 if join_rows == 0 and fact_rows > 0:
+                     if not state_before.did_prep_a or not state_before.did_prep_b:
+                         reward += REWARD_TABLE["dirty_join"]
+                 else:
+                     match_rate = metrics.get("match_rate", 0.0)
+                     if match_rate >= HIGH_MATCH_THRESHOLD:
+                         reward += REWARD_TABLE["high_match_join"]
+
+             # Quarantine precision: reward if orphans were quarantined
+             quarantine_rows = metrics.get("quarantine_rows", 0)
+             if quarantine_rows > 0 and action_type == "EXECUTE_JOIN_LEFT":
+                 reward += REWARD_TABLE["quarantine_precision"]
+
+             # Stale processing: ran the join while a source was stale (never synced)
+             if (state_before.is_stale_a or state_before.is_stale_b) and not state_before.did_sync_check:
+                 reward += REWARD_TABLE["stale_processing"]
+
+         elif action_type in ("APPLY_SCD_1", "APPLY_SCD_2"):
+             if action_type == "APPLY_SCD_2":
+                 # Tracked columns always need history here, so SCD-2 earns the bonus
+                 reward += REWARD_TABLE["correct_scd2"]
+
+             if (state_before.is_stale_a or state_before.is_stale_b) and not state_before.did_sync_check:
+                 reward += REWARD_TABLE["stale_processing"]
+
+         elif action_type == "COMMIT":
+             # Base commit — the grader adds its bonus/penalty separately
+             pass
+
+         return reward
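The join-branch arithmetic above can be checked with a standalone sketch that replicates a subset of `REWARD_TABLE`. The helper `join_reward` is hypothetical, introduced only to illustrate how the per-step penalty combines with the join bonuses and penalties; it is not part of the module:

```python
# Subset of the MEDUSA reward table relevant to join outcomes.
REWARD_TABLE = {
    "high_match_join": 25.0,
    "row_explosion": -100.0,
    "step_penalty": -0.2,
}

HIGH_MATCH_THRESHOLD = 0.90


def join_reward(match_rate: float, explosion: bool) -> float:
    """Hypothetical helper: reward for one join step (sketch, not the real API)."""
    reward = REWARD_TABLE["step_penalty"]  # the per-step penalty is always applied
    if explosion:
        reward += REWARD_TABLE["row_explosion"]
    elif match_rate >= HIGH_MATCH_THRESHOLD:
        reward += REWARD_TABLE["high_match_join"]
    return reward


print(join_reward(0.95, False))  # 24.8
print(join_reward(0.50, True))   # -100.2
```

Note that the step penalty means even a perfect join never yields the full +25.0, which nudges the agent toward shorter episodes.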
scenarios.py ADDED
@@ -0,0 +1,215 @@
+ """MEDUSA scenario generator.
+
+ Produces randomised Bronze A (Fact) and Bronze B (Dimension) DataFrames to
+ drive each training episode. Four canonical scenarios cover the failure
+ modes described in the MEDUSA blueprint.
+ """
+
+ from __future__ import annotations
+
+ import random
+ from dataclasses import dataclass, field
+ from typing import List, Optional
+
+ import numpy as np
+ import pandas as pd
+
+
+ # ---------------------------------------------------------------------------
+ # Scenario dataclass
+ # ---------------------------------------------------------------------------
+
+ @dataclass
+ class Scenario:
+     """One episode's worth of Bronze source data plus configuration."""
+
+     id: str
+     bronze_a: pd.DataFrame   # Fact table (source of truth for volume)
+     bronze_b: pd.DataFrame   # Dimension table (must be unique on key)
+     join_key: str            # Column name used to join A and B
+     tracked_cols: List[str]  # Columns in B that require SCD-2 history
+     is_stale_a: bool         # Whether Source A is past the freshness threshold
+     is_stale_b: bool
+     time_delta_a: float      # Hours since Source A was last refreshed
+     time_delta_b: float
+     new_cols_a: List[str]    # Extra columns in A not in Silver yet
+     new_cols_b: List[str]    # Extra columns in B not in Silver yet
+     description: str = ""
+
+
+ # ---------------------------------------------------------------------------
+ # Internal helpers
+ # ---------------------------------------------------------------------------
+
+ _STALE_THRESHOLD_HOURS = 6.0
+
+
+ def _make_fact(
+     rng: random.Random,
+     n_rows: int,
+     key_col: str,
+     null_ratio: float = 0.0,
+     extra_cols: Optional[List[str]] = None,
+ ) -> pd.DataFrame:
+     """Create a synthetic Fact (Bronze A) DataFrame."""
+     keys = [f"K{i:04d}" for i in rng.sample(range(1, n_rows * 2), n_rows)]
+
+     # Inject nulls into the key
+     null_mask = rng.sample(range(n_rows), int(n_rows * null_ratio))
+     for idx in null_mask:
+         keys[idx] = None  # type: ignore[call-overload]
+
+     data = {
+         key_col: keys,
+         "fact_value": [rng.uniform(0, 1000) for _ in range(n_rows)],
+         "fact_category": [rng.choice(["A", "B", "C"]) for _ in range(n_rows)],
+         "created_at": pd.date_range("2024-01-01", periods=n_rows, freq="h"),
+     }
+     for col in (extra_cols or []):
+         data[col] = [rng.uniform(0, 100) for _ in range(n_rows)]
+
+     return pd.DataFrame(data)
+
+
+ def _make_dim(
+     rng: random.Random,
+     n_rows: int,
+     key_col: str,
+     null_ratio: float = 0.0,
+     uniqueness: float = 1.0,                 # < 1.0 means some keys are duplicated
+     match_keys: Optional[List[str]] = None,  # If given, use these as the key pool
+     extra_cols: Optional[List[str]] = None,
+     tracked_cols: Optional[List[str]] = None,
+ ) -> pd.DataFrame:
+     """Create a synthetic Dimension (Bronze B) DataFrame."""
+     if match_keys:
+         # Choose from the overlap pool to control referential integrity
+         available = list(match_keys)
+         keys = [rng.choice(available) for _ in range(n_rows)]
+     else:
+         keys = [f"K{i:04d}" for i in rng.sample(range(1, n_rows * 3), n_rows)]
+
+     # Inject duplicates (lower uniqueness)
+     if uniqueness < 1.0:
+         n_dupes = int(n_rows * (1 - uniqueness))
+         for i in rng.sample(range(n_rows), n_dupes):
+             keys[i] = keys[rng.randint(0, i - 1)] if i > 0 else keys[0]
+
+     # Inject nulls
+     null_mask = rng.sample(range(n_rows), int(n_rows * null_ratio))
+     for idx in null_mask:
+         keys[idx] = None  # type: ignore[call-overload]
+
+     data: dict = {key_col: keys, "dim_name": [f"Name_{k}" for k in keys]}
+     for col in (tracked_cols or []):
+         data[col] = [rng.choice(["x", "y", "z"]) for _ in range(n_rows)]
+     for col in (extra_cols or []):
+         data[col] = [rng.uniform(0, 100) for _ in range(n_rows)]
+
+     return pd.DataFrame(data)
+
+
+ # ---------------------------------------------------------------------------
+ # Scenario Generator
+ # ---------------------------------------------------------------------------
+
+ class ScenarioGenerator:
+     """Generates Bronze A/B DataFrames for MEDUSA episodes."""
+
+     STALE_THRESHOLD = _STALE_THRESHOLD_HOURS
+     JOIN_KEY = "entity_id"
+     TRACKED_COLS = ["dim_status"]
+
+     # Four canonical scenario types
+     CANONICAL: List[str] = ["clean", "dirty_keys", "stale", "schema_drift"]
+
+     def __init__(self, n_fact_rows: int = 200, n_dim_rows: int = 150):
+         self.n_fact_rows = n_fact_rows
+         self.n_dim_rows = n_dim_rows
+
+     def generate(self, seed: Optional[int] = None) -> Scenario:
+         """Generate a random scenario. Seeds 0-3 map onto the canonical scenarios."""
+         rng = random.Random(seed)
+         if seed is not None and 0 <= seed < len(self.CANONICAL):
+             return self._canonical(self.CANONICAL[seed], seed)
+         variant = rng.choice(self.CANONICAL)
+         return self._canonical(variant, seed)
+
+     def _canonical(self, variant: str, seed: Optional[int]) -> Scenario:
+         rng = random.Random(seed)
+         np_rng = np.random.default_rng(seed)
+         key = self.JOIN_KEY
+         n_a = self.n_fact_rows
+         n_b = self.n_dim_rows
+
+         if variant == "clean":
+             # Fresh, unique keys, ~100% match rate
+             fact = _make_fact(rng, n_a, key, null_ratio=0.0)
+             valid_keys = fact[key].dropna().tolist()
+             dim = _make_dim(rng, n_b, key, null_ratio=0.0, uniqueness=1.0,
+                             match_keys=valid_keys, tracked_cols=self.TRACKED_COLS)
+             return Scenario(
+                 id=f"clean_{seed}",
+                 bronze_a=fact, bronze_b=dim,
+                 join_key=key, tracked_cols=self.TRACKED_COLS,
+                 is_stale_a=False, is_stale_b=False,
+                 time_delta_a=1.0, time_delta_b=2.0,
+                 new_cols_a=[], new_cols_b=[],
+                 description="Clean scenario: fresh, unique keys, high match rate.",
+             )
+
+         elif variant == "dirty_keys":
+             # High null ratio in keys, no trimming / type-casting yet
+             fact = _make_fact(rng, n_a, key, null_ratio=0.25)
+             fact[key] = fact[key].apply(
+                 lambda k: f" {k} " if k and rng.random() < 0.3 else k  # whitespace noise
+             )
+             dim = _make_dim(rng, n_b, key, null_ratio=0.15, uniqueness=0.85,
+                             tracked_cols=self.TRACKED_COLS)
+             return Scenario(
+                 id=f"dirty_keys_{seed}",
+                 bronze_a=fact, bronze_b=dim,
+                 join_key=key, tracked_cols=self.TRACKED_COLS,
+                 is_stale_a=False, is_stale_b=False,
+                 time_delta_a=2.0, time_delta_b=3.0,
+                 new_cols_a=[], new_cols_b=[],
+                 description="Dirty keys: nulls + whitespace in join keys.",
+             )
+
+         elif variant == "stale":
+             # One or both sources have not refreshed recently
+             fact = _make_fact(rng, n_a, key, null_ratio=0.0)
+             valid_keys = fact[key].dropna().tolist()
+             dim = _make_dim(rng, n_b, key, null_ratio=0.0, uniqueness=1.0,
+                             match_keys=valid_keys, tracked_cols=self.TRACKED_COLS)
+             td_a = rng.uniform(8.0, 24.0)  # definitely stale
+             td_b = rng.uniform(0.5, 4.0)
+             return Scenario(
+                 id=f"stale_{seed}",
+                 bronze_a=fact, bronze_b=dim,
+                 join_key=key, tracked_cols=self.TRACKED_COLS,
+                 is_stale_a=td_a > self.STALE_THRESHOLD,
+                 is_stale_b=td_b > self.STALE_THRESHOLD,
+                 time_delta_a=td_a, time_delta_b=td_b,
+                 new_cols_a=[], new_cols_b=[],
+                 description=f"Stale scenario: Source A is {td_a:.1f}h old.",
+             )
+
+         else:  # schema_drift
+             # New columns in A and/or B not yet registered in Silver
+             extra_a = ["new_metric_a"]
+             extra_b = ["new_attr_b"]
+             fact = _make_fact(rng, n_a, key, null_ratio=0.0, extra_cols=extra_a)
+             valid_keys = fact[key].dropna().tolist()
+             dim = _make_dim(rng, n_b, key, null_ratio=0.0, uniqueness=1.0,
+                             match_keys=valid_keys,
+                             tracked_cols=self.TRACKED_COLS, extra_cols=extra_b)
+             return Scenario(
+                 id=f"schema_drift_{seed}",
+                 bronze_a=fact, bronze_b=dim,
+                 join_key=key, tracked_cols=self.TRACKED_COLS,
+                 is_stale_a=False, is_stale_b=False,
+                 time_delta_a=1.0, time_delta_b=1.5,
+                 new_cols_a=extra_a, new_cols_b=extra_b,
+                 description="Schema drift: new columns in A and B.",
+             )
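The seed-to-variant mapping in `ScenarioGenerator.generate` can be illustrated standalone. This is a sketch mirroring the logic above; `pick_variant` is a hypothetical stand-in for the method, shown only to demonstrate that seeds 0-3 deterministically select the canonical scenarios while other seeds still choose reproducibly:

```python
import random

# Same canonical order as ScenarioGenerator.CANONICAL.
CANONICAL = ["clean", "dirty_keys", "stale", "schema_drift"]


def pick_variant(seed):
    """Hypothetical stand-in for ScenarioGenerator.generate's variant choice."""
    rng = random.Random(seed)
    # Seeds 0-3 index the canonical list directly; other seeds pick randomly
    # but reproducibly, since random.Random(seed) is deterministic.
    if seed is not None and 0 <= seed < len(CANONICAL):
        return CANONICAL[seed]
    return rng.choice(CANONICAL)


print(pick_variant(2))                      # stale
print(pick_variant(7) == pick_variant(7))   # True
```

Deterministic seeding is what lets the tasks in `tasks.py` pin each difficulty level to a fixed scenario.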
scripts/inference.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MEDUSA inference script β€” OpenEnv Hackathon submission.
2
+
3
+ Runs an LLM agent (via OpenAI-compatible API) against all three MEDUSA tasks
4
+ and reports per-task scores (0.0–1.0).
5
+
6
+ Required environment variables:
7
+ API_BASE_URL The API endpoint for the LLM (OpenAI-compatible).
8
+ MODEL_NAME The model identifier to use for inference.
9
+ HF_TOKEN Your Hugging Face / API key (used as the API key).
10
+
11
+ Usage:
12
+ export API_BASE_URL="https://api.openai.com/v1"
13
+ export MODEL_NAME="gpt-4o-mini"
14
+ export HF_TOKEN="hf-..."
15
+ python inference.py
16
+
17
+ Output:
18
+ Prints per-task results and a final summary table to stdout.
19
+ Exits with code 0 if all tasks score >= 0.35, else 1.
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import json
25
+ import os
26
+ import sys
27
+ import textwrap
28
+ import time
29
+ from typing import List, Optional
30
+
31
+ # ---------------------------------------------------------------------------
32
+ # Validate required environment variables before anything else
33
+ # ---------------------------------------------------------------------------
34
+
35
+ API_BASE_URL = os.environ.get("API_BASE_URL", "").rstrip("/")
36
+ MODEL_NAME = os.environ.get("MODEL_NAME", "")
37
+ HF_TOKEN = os.environ.get("HF_TOKEN", "")
38
+
39
+ _missing = [k for k, v in {
40
+ "API_BASE_URL": API_BASE_URL,
41
+ "MODEL_NAME": MODEL_NAME,
42
+ "HF_TOKEN": HF_TOKEN,
43
+ }.items() if not v]
44
+
45
+ if _missing:
46
+ print(f"ERROR: Missing required environment variables: {', '.join(_missing)}", file=sys.stderr)
47
+ print("Set them before running:", file=sys.stderr)
48
+ for k in _missing:
49
+ print(f" export {k}=<value>", file=sys.stderr)
50
+ sys.exit(1)
51
+
52
+ # ---------------------------------------------------------------------------
53
+ # OpenAI client (uses API_BASE_URL + HF_TOKEN as the key)
54
+ # ---------------------------------------------------------------------------
55
+
56
+ from openai import OpenAI # noqa: E402
57
+
58
+ client = OpenAI(
59
+ base_url=API_BASE_URL,
60
+ api_key=HF_TOKEN,
61
+ )
62
+
63
+ # ---------------------------------------------------------------------------
64
+ # MEDUSA environment imports
65
+ # ---------------------------------------------------------------------------
66
+
67
+ from pathlib import Path
68
+
69
+ # Dynamically add the OpenEnv repo root to sys.path so absolute imports work
70
+ # no matter where this script is executed from.
71
+ repo_root = str(Path(__file__).resolve().parent.parent.parent)
72
+ if repo_root not in sys.path:
73
+ sys.path.insert(0, repo_root)
74
+
75
+ try:
76
+ # In-repo
77
+ from envs.medusa_env import MedusaEnv
78
+ from envs.medusa_env.models import MedusaAction, MedusaActionType
79
+ from envs.medusa_env.tasks import TASKS, TaskResult, score_episode
80
+ except ImportError:
81
+ # Standalone (running from inside envs/medusa_env/ installation)
82
+ from medusa_env import MedusaEnv # type: ignore
83
+ from models import MedusaAction, MedusaActionType # type: ignore
84
+ from tasks import TASKS, TaskResult, score_episode # type: ignore
85
+
86
+ # ---------------------------------------------------------------------------
87
+ # System prompt
88
+ # ---------------------------------------------------------------------------
89
+
90
+ SYSTEM_PROMPT = textwrap.dedent("""
91
+ You are a data integration agent controlling a Bronze→Silver ETL pipeline.
92
+
93
+ You observe a 16-float feature vector describing data quality signals, and
94
+ you must choose one action per step from the list below.
95
+
96
+ ACTIONS (respond with ONLY the action name β€” nothing else):
97
+ SYNC_CHECK β€” Verify source freshness before processing
98
+ EVOLVE_SCHEMA β€” Add new columns from sources into Silver schema
99
+ PREP_KEYS_A β€” Clean and normalise join keys in Source A (Fact)
100
+ PREP_KEYS_B β€” Clean and normalise join keys in Source B (Dimension)
101
+ DEDUPLICATE_B β€” Remove duplicate keys from Source B
102
+ EXECUTE_JOIN_INNER β€” Inner join A β‹ˆ B
103
+ EXECUTE_JOIN_LEFT β€” Left join A β‹ˆ B (keeps all Fact rows; orphans β†’ quarantine)
104
+ EXECUTE_JOIN_ANTI β€” Anti-join: extract Fact rows with no Dimension match
105
+ APPLY_SCD_1 β€” Overwrite Silver records (SCD Type 1)
106
+ APPLY_SCD_2 β€” Close old records and insert new with timestamps (SCD Type 2)
107
+ COMMIT β€” Finalise pipeline and trigger audit
108
+
109
+ STRATEGY:
110
+ 1. Always call SYNC_CHECK first to verify freshness.
111
+ 2. If schema drift signals are non-zero (features[9] or [10] > 0), call EVOLVE_SCHEMA.
112
+ 3. If null key ratios (features[4] or [5] > 0), call PREP_KEYS_A and/or PREP_KEYS_B.
113
+ 4. If Dimension uniqueness (features[7]) < 1.0, call DEDUPLICATE_B.
114
+ 5. Prefer EXECUTE_JOIN_LEFT to preserve all Fact rows.
115
+ 6. Prefer APPLY_SCD_2 for tracked history.
116
+ 7. Call COMMIT when pipeline is complete.
117
+
118
+ The feature vector indices:
119
+ [0] time_delta_a_norm [1] time_delta_b_norm
120
+ [2] is_stale_a [3] is_stale_b
121
+ [4] null_ratio_key_a [5] null_ratio_key_b
122
+ [6] uniqueness_a [7] uniqueness_b
123
+ [8] match_rate [9] new_cols_a_norm
124
+ [10] new_cols_b_norm [11] schema_compat
125
+ [12] did_prep_a [13] did_prep_b
126
+ [14] did_dedup_b [15] step_frac
127
+ """).strip()
128
+
129
+ # ---------------------------------------------------------------------------
130
+ # LLM action chooser
131
+ # ---------------------------------------------------------------------------
132
+
133
+ VALID_ACTIONS = {a.value for a in MedusaActionType}
134
+
135
+
136
+ def choose_action(
137
+ features: List[float],
138
+ history: List[dict],
139
+ step: int,
140
+ ) -> str:
141
+ """Ask the LLM to choose the next action given the current observation."""
142
+ feature_str = ", ".join(f"{v:.3f}" for v in features)
143
+ user_msg = (
144
+ f"Step {step}. Feature vector: [{feature_str}]\n"
145
+ "What is the single best next action? Respond with ONLY the action name."
146
+ )
147
+
148
+ messages = [{"role": "system", "content": SYSTEM_PROMPT}]
149
+ # Include the last 4 steps of history for context (keep prompt short)
150
+ for h in history[-4:]:
151
+ messages.append({"role": "user", "content": h["user"]})
152
+ messages.append({"role": "assistant", "content": h["assistant"]})
153
+ messages.append({"role": "user", "content": user_msg})
154
+
155
+ response = client.chat.completions.create(
156
+ model=MODEL_NAME,
157
+ messages=messages,
158
+ max_tokens=20,
159
+ temperature=0.0,
160
+ )
161
+ raw = (response.choices[0].message.content or "").strip().upper().replace(" ", "_")
162
+
163
+ # Fuzzy match: accept if the response contains a valid action name
164
+ for action in VALID_ACTIONS:
165
+ if action in raw:
166
+ return action
167
+
168
+ # Fallback: extract the longest matching token
169
+ for action in sorted(VALID_ACTIONS, key=len, reverse=True):
170
+ if action.replace("_", "") in raw.replace("_", ""):
171
+ return action
172
+
173
+ # Hard fallback: commit to end gracefully
174
+ return MedusaActionType.COMMIT.value
175
+
176
+
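The two-stage fuzzy matching above is easy to exercise in isolation. A self-contained sketch that mirrors the parsing logic with a reduced, hard-coded action set (no SDK or environment imports needed):

```python
# Reduced action set for illustration; the real set comes from MedusaActionType.
VALID_ACTIONS = {"SYNC_CHECK", "EXECUTE_JOIN_LEFT", "APPLY_SCD_2", "COMMIT"}

def parse_action(text: str) -> str:
    """Normalise an LLM reply and resolve it to a valid action name."""
    raw = text.strip().upper().replace(" ", "_")
    # Stage 1: direct substring match against known action names.
    for action in VALID_ACTIONS:
        if action in raw:
            return action
    # Stage 2: underscore-insensitive match, longest names first.
    for action in sorted(VALID_ACTIONS, key=len, reverse=True):
        if action.replace("_", "") in raw.replace("_", ""):
            return action
    # Hard fallback: commit to end the episode gracefully.
    return "COMMIT"
```

Replies like `"execute join left"` or `"I choose SYNC_CHECK"` resolve cleanly; unrecognised text falls through to COMMIT.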
177
+ # ---------------------------------------------------------------------------
178
+ # Run one task
179
+ # ---------------------------------------------------------------------------
180
+
181
+ def run_task(task_id: str, max_steps: int = 15) -> TaskResult:
182
+ """Run the LLM agent for one MEDUSA task. Returns the TaskResult."""
183
+ task = TASKS[task_id]
184
+ print(f"\n{'='*60}")
185
+ print(f"TASK: {task.name} [{task.difficulty.upper()}] (seed={task.seed})")
186
+ print(f" {task.description}")
187
+ print(f"{'='*60}")
188
+
189
+ env = MedusaEnv(n_fact_rows=200, n_dim_rows=150, max_steps=max_steps)
190
+ obs = env.reset(seed=task.seed)
191
+
192
+ history: List[dict] = []
193
+ step = 0
194
+ t0 = time.time()
195
+
196
+ while not obs.done and step < max_steps:
197
+ step += 1
198
+ action_str = choose_action(obs.features, history, step)
199
+ action_type = MedusaActionType(action_str)
200
+ action = MedusaAction(action=action_type)
201
+
202
+ obs = env.step(action)
203
+ reward = obs.reward or 0.0
204
+
205
+ print(f" Step {step:2d}: {action_str:25s} reward={reward:+7.2f} "
206
+ f"cumulative={env.state.cumulative_reward:+8.2f}")
207
+
208
+ history.append({
209
+ "user": (f"Step {step}. Features: [{', '.join(f'{v:.3f}' for v in obs.features)}]"
210
+ " What action?"),
211
+ "assistant": action_str,
212
+ })
213
+
214
+ elapsed = time.time() - t0
215
+ result = score_episode(task_id, env.state, env._tables)
216
+
217
+ print(f"\n → Score: {result.score:.4f} Grade: {result.grade} "
218
+ f"Passed: {result.passed} ({elapsed:.1f}s)")
219
+ if result.notes:
220
+ for note in result.notes:
221
+ print(f" ⚠ {note}")
222
+ print(f" → Breakdown: " +
223
+ ", ".join(f"{k}={v:.2f}" for k, v in result.breakdown.items()))
224
+ return result
225
+
226
+
227
+ # ---------------------------------------------------------------------------
228
+ # Main
229
+ # ---------------------------------------------------------------------------
230
+
231
+ def main() -> None:
232
+ print("MEDUSA — Baseline Inference")
233
+ print(f"Model: {MODEL_NAME}")
234
+ print(f"API: {API_BASE_URL}")
235
+ print()
236
+
237
+ task_ids = ["clean_pipeline", "dirty_integration", "full_medallion"]
238
+ results: dict[str, TaskResult] = {}
239
+ total_start = time.time()
240
+
241
+ for task_id in task_ids:
242
+ result = run_task(task_id)
243
+ results[task_id] = result
244
+
245
+ total_elapsed = time.time() - total_start
246
+
247
+ # Summary
248
+ print(f"\n{'='*60}")
249
+ print("SUMMARY")
250
+ print(f"{'='*60}")
251
+ print(f"{'Task':<25} {'Difficulty':<8} {'Score':>6} {'Grade':>5} {'Pass?':>5}")
252
+ print("-" * 60)
253
+ all_passed = True
254
+ for task_id, result in results.items():
255
+ task = TASKS[task_id]
256
+ print(f"{task.name:<25} {task.difficulty:<8} "
257
+ f"{result.score:>6.4f} {result.grade:>5} {'YES' if result.passed else 'NO':>5}")
258
+ if not result.passed:
259
+ all_passed = False
260
+
261
+ print("-" * 60)
262
+ avg = sum(r.score for r in results.values()) / len(results)
263
+ print(f"{'Average':<25} {'':8} {avg:>6.4f}")
264
+ print(f"\nTotal time: {total_elapsed:.1f}s")
265
+
266
+ # Machine-readable output for the evaluator
267
+ output = {
268
+ "model": MODEL_NAME,
269
+ "tasks": {
270
+ tid: {
271
+ "score": r.score,
272
+ "grade": r.grade,
273
+ "passed": r.passed,
274
+ "breakdown": r.breakdown,
275
+ }
276
+ for tid, r in results.items()
277
+ },
278
+ "average_score": avg,
279
+ "all_passed": all_passed,
280
+ }
281
+ print("\n--- JSON RESULTS ---")
282
+ print(json.dumps(output, indent=2))
283
+
284
+ sys.exit(0 if all_passed else 1)
285
+
286
+
287
+ if __name__ == "__main__":
288
+ main()
server/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ """FastAPI server package for medusa_env."""
2
+ from .medusa_env import MedusaEnv
3
+
4
+ __all__ = [
5
+ "MedusaEnv"
6
+ ]
server/app.py ADDED
@@ -0,0 +1,37 @@
1
+ """FastAPI server for the MEDUSA environment.
2
+
3
+ Usage:
4
+ # Development:
5
+ uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
6
+
7
+ # Via openenv CLI:
8
+ openenv serve medusa_env
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ # Imports resolve via the installed ``medusa_env`` package. The Dockerfile
14
+ # adds /app/env to PYTHONPATH, so the same absolute imports also work when
15
+ # running directly from the environment directory or via the openenv CLI.
16
+ # (A single import style avoids fragile relative/bare-module fallbacks.)
17
+ from openenv.core.env_server.http_server import create_app
18
+ from medusa_env.server import MedusaEnv
19
+ from medusa_env.models import MedusaAction, MedusaObservation
20
+
21
+ app = create_app(
22
+ MedusaEnv,
23
+ MedusaAction,
24
+ MedusaObservation,
25
+ env_name="medusa_env",
26
+ )
27
+
28
+
29
+ def main() -> None:
30
+ """Entry point for direct execution."""
31
+ import uvicorn
32
+
33
+ uvicorn.run(app, host="0.0.0.0", port=8000)
34
+
35
+
36
+ if __name__ == "__main__":
37
+ main()
server/medusa_env.py ADDED
@@ -0,0 +1,514 @@
1
+ """MEDUSA — full environment implementation.
2
+
3
+ Replaces the Phase-1 skeleton with a complete reset/step pipeline that:
4
+ • Generates Bronze A/B data from ``ScenarioGenerator``
5
+ • Dispatches each action to the appropriate operator
6
+ • Computes per-step rewards via ``RewardEngine``
7
+ • Runs the deterministic grader on COMMIT
8
+ • Builds a 16-float normalized feature vector for the RL agent
9
+ • Maintains a governance log of every decision
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import copy
15
+ import time
16
+ import uuid
17
+ from dataclasses import dataclass, field
18
+ from typing import Any, Dict, List, Optional
19
+
20
+ import pandas as pd
21
+
22
+ from openenv.core.env_server.interfaces import Environment
23
+ from openenv.core.env_server.types import EnvironmentMetadata
24
+
25
+ from medusa_env.grader import Grader
26
+ from medusa_env.models import MedusaAction, MedusaActionType, MedusaObservation, MedusaState
27
+ from medusa_env.operators import (
28
+ apply_scd,
29
+ deduplicate,
30
+ evolve_schema,
31
+ execute_join,
32
+ prep_keys,
33
+ sync_check,
34
+ )
35
+ from medusa_env.rewards import RewardEngine
36
+ from medusa_env.scenarios import Scenario, ScenarioGenerator
37
+
38
+
39
+ # ---------------------------------------------------------------------------
40
+ # Internal episode tables
41
+ # ---------------------------------------------------------------------------
42
+
43
+ @dataclass
44
+ class _EpisodeTables:
45
+ """In-memory tables for one episode."""
46
+
47
+ bronze_a: pd.DataFrame = field(default_factory=pd.DataFrame)
48
+ bronze_a_prepped: pd.DataFrame = field(default_factory=pd.DataFrame)
49
+ bronze_b: pd.DataFrame = field(default_factory=pd.DataFrame)
50
+ bronze_b_prepped: pd.DataFrame = field(default_factory=pd.DataFrame)
51
+ joined: pd.DataFrame = field(default_factory=pd.DataFrame)
52
+ silver: pd.DataFrame = field(default_factory=pd.DataFrame)
53
+ quarantine: pd.DataFrame = field(default_factory=pd.DataFrame)
54
+ governance_log: List[Dict[str, Any]] = field(default_factory=list)
55
+
56
+
57
+ # ---------------------------------------------------------------------------
58
+ # Feature vector builder
59
+ # ---------------------------------------------------------------------------
60
+
61
+ _MAX_TIME_DELTA = 48.0 # Normalisation ceiling (hours)
62
+ _MAX_COLS = 10.0 # Normalisation ceiling (new columns)
63
+
64
+
65
+ def _build_features(state: MedusaState) -> List[float]:
66
+ """Build the 16-float normalised observation vector."""
67
+ return [
68
+ min(state.time_delta_a / _MAX_TIME_DELTA, 1.0),
69
+ min(state.time_delta_b / _MAX_TIME_DELTA, 1.0),
70
+ float(state.is_stale_a),
71
+ float(state.is_stale_b),
72
+ state.null_ratio_key_a,
73
+ state.null_ratio_key_b,
74
+ state.uniqueness_a,
75
+ state.uniqueness_b,
76
+ state.match_rate,
77
+ min(state.new_cols_a / _MAX_COLS, 1.0),
78
+ min(state.new_cols_b / _MAX_COLS, 1.0),
79
+ state.schema_compat,
80
+ float(state.did_prep_a),
81
+ float(state.did_prep_b),
82
+ float(state.did_dedup_b),
83
+ min(state.step_idx / max(state.max_steps, 1), 1.0),
84
+ ]
85
+
86
+
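Every unbounded raw signal enters the vector through the same clamp, `min(value / ceiling, 1.0)`. A quick sketch of that normalisation, reusing the two ceilings defined above:

```python
MAX_TIME_DELTA = 48.0  # hours, matches _MAX_TIME_DELTA above
MAX_COLS = 10.0        # new columns, matches _MAX_COLS above

def clamp_norm(value: float, ceiling: float) -> float:
    """Scale a non-negative raw signal by its ceiling and clamp to [0, 1]."""
    return min(value / ceiling, 1.0)
```

A 24-hour staleness maps to 0.5; anything at or beyond the 48-hour ceiling saturates at 1.0, so the agent sees a bounded input regardless of how stale a source gets.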
87
+ # ---------------------------------------------------------------------------
88
+ # Main environment
89
+ # ---------------------------------------------------------------------------
90
+
91
+ class MedusaEnv(Environment[MedusaAction, MedusaObservation, MedusaState]):
92
+ """MEDUSA: Medallion-Engineered Deterministic Unified Storage Agent.
93
+
94
+ Simulates a Bronze→Silver data integration pipeline. The agent observes
95
+ data quality signals and chooses ETL actions to produce a correct,
96
+ historically consistent Silver entity.
97
+
98
+ Args:
99
+ scenario_seed: Fixed seed for deterministic episodes. ``None`` = random.
100
+ max_steps: Maximum steps per episode before forced termination.
101
+ stale_threshold_hours: Age (hours) at which a source is deemed stale.
102
+ n_fact_rows: Size of the Fact / Source A table.
103
+ n_dim_rows: Size of the Dimension / Source B table.
104
+ """
105
+
106
+ SUPPORTS_CONCURRENT_SESSIONS = True
107
+
108
+ def __init__(
109
+ self,
110
+ scenario_seed: Optional[int] = None,
111
+ max_steps: int = 20,
112
+ stale_threshold_hours: float = 6.0,
113
+ n_fact_rows: int = 200,
114
+ n_dim_rows: int = 150,
115
+ **kwargs: Any,
116
+ ):
117
+ super().__init__(**kwargs)
118
+ self._scenario_seed = scenario_seed
119
+ self._max_steps = max_steps
120
+ self._stale_threshold = stale_threshold_hours
121
+
122
+ self._generator = ScenarioGenerator(
123
+ n_fact_rows=n_fact_rows, n_dim_rows=n_dim_rows
124
+ )
125
+ self._reward_engine = RewardEngine()
126
+ self._grader = Grader()
127
+
128
+ self._state = MedusaState()
129
+ self._tables = _EpisodeTables()
130
+ self._scenario: Optional[Scenario] = None
131
+
132
+ # ------------------------------------------------------------------
133
+ # Metadata
134
+ # ------------------------------------------------------------------
135
+
136
+ def get_metadata(self) -> EnvironmentMetadata:
137
+ return EnvironmentMetadata(
138
+ name="medusa_env",
139
+ description=(
140
+ "MEDUSA: simulated Bronze→Silver integration controller for "
141
+ "multi-source joins, schema drift, and SCD merges."
142
+ ),
143
+ version="0.2.0",
144
+ documentation="envs/medusa_env/README.md",
145
+ )
146
+
147
+ # ------------------------------------------------------------------
148
+ # State
149
+ # ------------------------------------------------------------------
150
+
151
+ @property
152
+ def state(self) -> MedusaState:
153
+ return self._state
154
+
155
+ # ------------------------------------------------------------------
156
+ # Reset
157
+ # ------------------------------------------------------------------
158
+
159
+ def reset(
160
+ self,
161
+ seed: Optional[int] = None,
162
+ episode_id: Optional[str] = None,
163
+ **kwargs: Any,
164
+ ) -> MedusaObservation:
165
+ self._reset_rubric()
166
+
167
+ effective_seed = seed if seed is not None else self._scenario_seed
168
+ run_id = episode_id or str(uuid.uuid4())
169
+
170
+ # Generate scenario
171
+ self._scenario = self._generator.generate(seed=effective_seed)
172
+ scen = self._scenario
173
+
174
+ # Initialise tables
175
+ self._tables = _EpisodeTables(
176
+ bronze_a=scen.bronze_a.copy(),
177
+ bronze_a_prepped=scen.bronze_a.copy(),
178
+ bronze_b=scen.bronze_b.copy(),
179
+ bronze_b_prepped=scen.bronze_b.copy(),
180
+ )
181
+
182
+ # Compute initial key health metrics from raw Bronze
183
+ na_a = scen.bronze_a[scen.join_key].isna().sum()
184
+ na_b = scen.bronze_b[scen.join_key].isna().sum()
185
+ null_ratio_a = na_a / max(len(scen.bronze_a), 1)
186
+ null_ratio_b = na_b / max(len(scen.bronze_b), 1)
187
+
188
+ # Uniqueness of raw keys
189
+ nna_a = scen.bronze_a[scen.join_key].dropna()
190
+ nna_b = scen.bronze_b[scen.join_key].dropna()
191
+ uniq_a = nna_a.nunique() / max(len(nna_a), 1)
192
+ uniq_b = nna_b.nunique() / max(len(nna_b), 1)
193
+
194
+ # Match rate on raw keys
195
+ keys_a = set(nna_a.astype(str))
196
+ keys_b = set(nna_b.astype(str))
197
+ match_rate = len(keys_a & keys_b) / max(len(keys_a), 1)
198
+
199
+ self._state = MedusaState(
200
+ run_id=run_id,
201
+ seed=effective_seed,
202
+ scenario_id=scen.id,
203
+ max_steps=self._max_steps,
204
+ step_idx=0,
205
+ stage="running",
206
+ time_delta_a=scen.time_delta_a,
207
+ time_delta_b=scen.time_delta_b,
208
+ is_stale_a=scen.is_stale_a,
209
+ is_stale_b=scen.is_stale_b,
210
+ null_ratio_key_a=float(null_ratio_a),
211
+ null_ratio_key_b=float(null_ratio_b),
212
+ uniqueness_a=float(uniq_a),
213
+ uniqueness_b=float(uniq_b),
214
+ match_rate=float(match_rate),
215
+ new_cols_a=len(scen.new_cols_a),
216
+ new_cols_b=len(scen.new_cols_b),
217
+ source_a_row_count=len(scen.bronze_a),
218
+ )
219
+
220
+ features = _build_features(self._state)
221
+ obs = MedusaObservation(
222
+ message=(
223
+ f"MEDUSA episode started. Scenario: {scen.id}. "
224
+ f"{scen.description} "
225
+ f"Source A: {len(scen.bronze_a)} rows | "
226
+ f"Source B: {len(scen.bronze_b)} rows."
227
+ ),
228
+ features=features,
229
+ metrics={
230
+ "scenario_id": scen.id,
231
+ "null_ratio_key_a": null_ratio_a,
232
+ "null_ratio_key_b": null_ratio_b,
233
+ "match_rate": match_rate,
234
+ "is_stale_a": scen.is_stale_a,
235
+ "is_stale_b": scen.is_stale_b,
236
+ "new_cols_a": scen.new_cols_a,
237
+ "new_cols_b": scen.new_cols_b,
238
+ },
239
+ metadata={"run_id": run_id, "seed": effective_seed},
240
+ reward=None,
241
+ done=False,
242
+ )
243
+ return self._apply_transform(obs)
244
+
245
+ # ------------------------------------------------------------------
246
+ # Step
247
+ # ------------------------------------------------------------------
248
+
249
+ def step(
250
+ self,
251
+ action: MedusaAction,
252
+ timeout_s: Optional[float] = None,
253
+ **kwargs: Any,
254
+ ) -> MedusaObservation:
255
+ if self._state.stage != "running":
256
+ return self._apply_transform(MedusaObservation(
257
+ message=f"Episode not running (stage={self._state.stage}). Call reset().",
258
+ done=True,
259
+ reward=0.0,
260
+ features=_build_features(self._state),
261
+ metadata={"run_id": self._state.run_id},
262
+ ))
263
+
264
+ # Snapshot state *before* applying action (for reward evaluation)
265
+ state_before = copy.copy(self._state)
266
+ self._state.step_idx += 1
267
+
268
+ action_type = action.action
269
+ metrics: dict = {}
270
+ step_message = ""
271
+
272
+ scen = self._scenario
273
+ assert scen is not None, "reset() must be called before step()"
274
+
275
+ # ── Dispatch ──────────────────────────────────────────────────
276
+ try:
277
+ if action_type == MedusaActionType.SYNC_CHECK:
278
+ _, metrics = sync_check(
279
+ self._tables.bronze_a,
280
+ self._tables.bronze_b,
281
+ scen.time_delta_a,
282
+ scen.time_delta_b,
283
+ self._stale_threshold,
284
+ )
285
+ self._state.did_sync_check = True
286
+ step_message = (
287
+ f"SYNC_CHECK: A={scen.time_delta_a:.1f}h "
288
+ f"{'[STALE]' if scen.is_stale_a else '[FRESH]'} | "
289
+ f"B={scen.time_delta_b:.1f}h "
290
+ f"{'[STALE]' if scen.is_stale_b else '[FRESH]'}"
291
+ )
292
+
293
+ elif action_type == MedusaActionType.EVOLVE_SCHEMA:
294
+ result_df, metrics = evolve_schema(
295
+ self._tables.silver,
296
+ self._tables.bronze_a,
297
+ self._tables.bronze_b,
298
+ scen.new_cols_a,
299
+ scen.new_cols_b,
300
+ )
301
+ if result_df is not None:
302
+ self._tables.silver = result_df
303
+ self._state.did_evolve_schema = True
304
+ step_message = f"EVOLVE_SCHEMA: added {metrics.get('new_cols_count', 0)} column(s)."
305
+
306
+ elif action_type == MedusaActionType.PREP_KEYS_A:
307
+ result_df, metrics = prep_keys(
308
+ self._tables.bronze_a_prepped, scen.join_key
309
+ )
310
+ if result_df is not None:
311
+ self._tables.bronze_a_prepped = result_df
312
+ self._state.did_prep_a = True
313
+ self._state.null_ratio_key_a = float(metrics.get("null_ratio_after", 0.0))
314
+ step_message = (
315
+ f"PREP_KEYS_A: null ratio {metrics.get('null_ratio_before', 0):.2%}"
316
+ f"→{metrics.get('null_ratio_after', 0):.2%}."
317
+ )
318
+
319
+ elif action_type == MedusaActionType.PREP_KEYS_B:
320
+ result_df, metrics = prep_keys(
321
+ self._tables.bronze_b_prepped, scen.join_key
322
+ )
323
+ if result_df is not None:
324
+ self._tables.bronze_b_prepped = result_df
325
+ self._state.did_prep_b = True
326
+ self._state.null_ratio_key_b = float(metrics.get("null_ratio_after", 0.0))
327
+ step_message = (
328
+ f"PREP_KEYS_B: null ratio {metrics.get('null_ratio_before', 0):.2%}"
329
+ f"→{metrics.get('null_ratio_after', 0):.2%}."
330
+ )
331
+
332
+ elif action_type == MedusaActionType.DEDUPLICATE_B:
333
+ result_df, metrics = deduplicate(
334
+ self._tables.bronze_b_prepped, scen.join_key
335
+ )
336
+ if result_df is not None:
337
+ self._tables.bronze_b_prepped = result_df
338
+ self._state.did_dedup_b = True
339
+ self._state.uniqueness_b = float(metrics.get("uniqueness", 1.0))
340
+ step_message = f"DEDUPLICATE_B: removed {metrics.get('dupes_removed', 0)} duplicate(s)."
341
+
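Uniqueness here is the fraction of non-null key values that are distinct, the same ratio `reset()` computes from the raw Bronze keys. A minimal stdlib sketch of that metric (illustrative; the operator itself works on a pandas Series):

```python
def key_uniqueness(keys: list) -> float:
    """Fraction of non-null key values that are distinct (1.0 = no duplicates).

    Mirrors the reset() computation: an all-null or empty column yields 0.0
    because the denominator is clamped to at least 1.
    """
    non_null = [k for k in keys if k is not None]
    return len(set(non_null)) / max(len(non_null), 1)
```

A column like `["a", "a", "b", None]` scores 2/3, signalling that deduplication is needed before a safe join.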
342
+ elif action_type in {
343
+ MedusaActionType.EXECUTE_JOIN_INNER,
344
+ MedusaActionType.EXECUTE_JOIN_LEFT,
345
+ MedusaActionType.EXECUTE_JOIN_ANTI,
346
+ }:
347
+ join_map = {
348
+ MedusaActionType.EXECUTE_JOIN_INNER: "inner",
349
+ MedusaActionType.EXECUTE_JOIN_LEFT: "left",
350
+ MedusaActionType.EXECUTE_JOIN_ANTI: "anti",
351
+ }
352
+ join_type_str = join_map[action_type]
353
+ joined, quarantine, metrics = execute_join(
354
+ self._tables.bronze_a_prepped,
355
+ self._tables.bronze_b_prepped,
356
+ scen.join_key,
357
+ join_type_str,
358
+ )
359
+ self._tables.joined = joined
360
+ self._tables.quarantine = quarantine
361
+ self._state.did_join = True
362
+ self._state.join_type = join_type_str
363
+ self._state.join_row_count = int(metrics.get("join_rows", 0))
364
+ self._state.explosion_detected = bool(metrics.get("explosion_detected", False))
365
+ self._state.match_rate = float(metrics.get("match_rate", 0.0))
366
+ self._state.quarantine_row_count = len(quarantine)
367
+ step_message = (
368
+ f"EXECUTE_JOIN ({join_type_str.upper()}): "
369
+ f"{self._state.join_row_count} rows | "
370
+ f"match_rate={self._state.match_rate:.1%} | "
371
+ f"quarantine={self._state.quarantine_row_count} | "
372
+ f"{'⚠ EXPLOSION' if self._state.explosion_detected else 'OK'}"
373
+ )
374
+
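At the key level, the three join modes differ only in which Fact keys survive and which are flagged as orphans. A stdlib sketch of those semantics over plain key sets (illustrative; `execute_join` itself operates on DataFrames):

```python
def join_keys(keys_a: set, keys_b: set, how: str):
    """Return (kept_keys, quarantined_keys) for each join mode at key level."""
    matched = keys_a & keys_b
    orphans = keys_a - keys_b  # Fact keys with no Dimension match
    if how == "inner":
        return matched, set()   # orphans silently dropped
    if how == "left":
        return keys_a, orphans  # all Fact keys kept; orphans quarantined
    if how == "anti":
        return orphans, set()   # only the unmatched Fact keys
    raise ValueError(f"unknown join type: {how}")
```

This is why the system prompt prefers LEFT: it is the only mode that both preserves every Fact key and surfaces the orphans for quarantine.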
375
+ elif action_type in {MedusaActionType.APPLY_SCD_1, MedusaActionType.APPLY_SCD_2}:
376
+ scd_type_int = 1 if action_type == MedusaActionType.APPLY_SCD_1 else 2
377
+ tracked_col = scen.tracked_cols[0] if scen.tracked_cols else scen.join_key
378
+ result_df, metrics = apply_scd(
379
+ self._tables.silver,
380
+ self._tables.joined,
381
+ scen.join_key,
382
+ tracked_col,
383
+ scd_type_int,
384
+ )
385
+ if result_df is not None:
386
+ self._tables.silver = result_df
387
+ self._state.did_scd = True
388
+ self._state.scd_type = f"SCD-{scd_type_int}"
389
+ self._state.scd_inserts = int(metrics.get("inserts", 0))
390
+ self._state.scd_updates = int(metrics.get("updates", 0))
391
+ self._state.silver_row_count = int(metrics.get("silver_rows", 0))
392
+ step_message = (
393
+ f"APPLY_SCD-{scd_type_int}: "
394
+ f"{self._state.scd_inserts} inserts, "
395
+ f"{self._state.scd_updates} updates → "
396
+ f"Silver {self._state.silver_row_count} rows."
397
+ )
398
+
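SCD Type 2 never overwrites history: the open record is closed with a timestamp and a new version is appended. A minimal close-and-insert sketch over plain dicts (illustrative; `apply_scd` works on the Silver DataFrame):

```python
def scd2_upsert(history: list, key: str, new_value: str, ts: int) -> list:
    """Close the open record for `key` if its value changed, then append a new version."""
    current = next(
        (r for r in history if r["key"] == key and r["valid_to"] is None), None
    )
    if current is not None:
        if current["value"] == new_value:
            return history  # no change: keep the open record as-is
        current["valid_to"] = ts  # close the old version
    history.append(
        {"key": key, "value": new_value, "valid_from": ts, "valid_to": None}
    )
    return history
```

After two distinct values for the same key, the history holds both versions, with only the latest left open (`valid_to is None`); SCD Type 1 would have kept just one row.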
399
+ elif action_type == MedusaActionType.COMMIT:
400
+ return self._do_commit(state_before)
401
+
402
+ except Exception as exc: # noqa: BLE001
403
+ step_message = f"ERROR in {action_type}: {exc}"
404
+ metrics = {"error": str(exc)}
405
+
406
+ # ── Reward ────────────────────────────────────────────────────
407
+ reward = self._reward_engine.evaluate(
408
+ action_type=action_type.value,
409
+ metrics=metrics,
410
+ state_before=state_before,
411
+ )
412
+ self._state.cumulative_reward += reward
413
+
414
+ # ── Governance log ────────────────────────────────────────────
415
+ self._tables.governance_log.append({
416
+ "step": self._state.step_idx,
417
+ "action": action_type.value,
418
+ "reward": reward,
419
+ "cumulative_reward": self._state.cumulative_reward,
420
+ "metrics": metrics,
421
+ "timestamp": time.time(),
422
+ })
423
+
424
+ # Check step limit
425
+ done = self._state.step_idx >= self._state.max_steps
426
+ if done:
427
+ self._state.stage = "failed"
428
+ step_message += " [MAX STEPS REACHED]"
429
+
430
+ features = _build_features(self._state)
431
+ obs = MedusaObservation(
432
+ message=step_message,
433
+ features=features,
434
+ metrics=metrics,
435
+ metadata={
436
+ "run_id": self._state.run_id,
437
+ "step": self._state.step_idx,
438
+ "cumulative_reward": self._state.cumulative_reward,
439
+ },
440
+ reward=reward,
441
+ done=done,
442
+ )
443
+ return self._apply_transform(obs)
444
+
445
+ # ------------------------------------------------------------------
446
+ # Commit (terminal step)
447
+ # ------------------------------------------------------------------
448
+
449
+ def _do_commit(self, state_before: MedusaState) -> MedusaObservation:
450
+ """Run grader then finalise the episode."""
451
+ scen = self._scenario
452
+ assert scen is not None
453
+
454
+ # Base step reward
455
+ reward = self._reward_engine.evaluate(
456
+ action_type=MedusaActionType.COMMIT.value,
457
+ metrics={},
458
+ state_before=state_before,
459
+ )
460
+
461
+ # Grader audit
462
+ grader_result = self._grader.audit(
463
+ silver=self._tables.silver,
464
+ quarantine=self._tables.quarantine,
465
+ bronze_a=scen.bronze_a,
466
+ bronze_b=scen.bronze_b,
467
+ join_key=scen.join_key,
468
+ join_type=self._state.join_type or "left",
469
+ scd_type=int(self._state.scd_type[-1]) if self._state.scd_type else 1,
470
+ scenario=scen,
471
+ )
472
+ reward += grader_result.bonus_reward
473
+ self._state.grader_passed = grader_result.passed
474
+ self._state.grader_report = grader_result.report
475
+ self._state.cumulative_reward += reward
476
+ self._state.silver_row_count = len(self._tables.silver)
477
+ self._state.quarantine_row_count = len(self._tables.quarantine)
478
+ self._state.stage = "committed"
479
+
480
+ self._tables.governance_log.append({
481
+ "step": self._state.step_idx,
482
+ "action": "COMMIT",
483
+ "reward": reward,
484
+ "cumulative_reward": self._state.cumulative_reward,
485
+ "grader_passed": grader_result.passed,
486
+ "grader_report": grader_result.report,
487
+ "timestamp": time.time(),
488
+ })
489
+
490
+ features = _build_features(self._state)
491
+ obs = MedusaObservation(
492
+ message=(
493
+ f"COMMIT: episode finalized. "
494
+ f"{'Grader: PASS ✓' if grader_result.passed else 'Grader: FAIL ✗'} "
495
+ f"Bonus: {grader_result.bonus_reward:+.1f} | "
496
+ f"Total reward: {self._state.cumulative_reward:.1f}"
497
+ ),
498
+ features=features,
499
+ metrics={
500
+ "grader_passed": grader_result.passed,
501
+ "grader_report": grader_result.report,
502
+ "silver_rows": self._state.silver_row_count,
503
+ "quarantine_rows": self._state.quarantine_row_count,
504
+ "governance_log_entries": len(self._tables.governance_log),
505
+ },
506
+ metadata={
507
+ "run_id": self._state.run_id,
508
+ "steps": self._state.step_idx,
509
+ "cumulative_reward": self._state.cumulative_reward,
510
+ },
511
+ reward=reward,
512
+ done=True,
513
+ )
514
+ return self._apply_transform(obs)
tasks.py ADDED
@@ -0,0 +1,286 @@
1
+ """MEDUSA Task Definitions.
2
+
3
+ Three formally graded tasks covering the easy → medium → hard spectrum.
4
+ Each task returns a deterministic score in [0.0, 1.0] after COMMIT.
5
+
6
+ Usage::
7
+
8
+ from envs.medusa_env.tasks import TASKS, score_episode
9
+
10
+ task = TASKS["clean_pipeline"] # easy
11
+ env = MedusaEnv(n_fact_rows=200, n_dim_rows=150)
12
+ obs = env.reset(seed=task.seed)
13
+
14
+ # ... agent takes actions ...
15
+ obs = env.step(MedusaAction(action=MedusaActionType.COMMIT))
16
+
17
+ result = score_episode(task.id, env.state, env._tables)
18
+ print(f"Score: {result.score:.2f} ({result.grade})")
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ from dataclasses import dataclass, field
24
+ from typing import TYPE_CHECKING, Dict, List, Optional
25
+
26
+ if TYPE_CHECKING:
27
+ from .medusa_env import _EpisodeTables
28
+ from .models import MedusaState
29
+
30
+
31
+ # ---------------------------------------------------------------------------
32
+ # Task definition
33
+ # ---------------------------------------------------------------------------
34
+
35
+ @dataclass
36
+ class Task:
37
+ """A MEDUSA task definition."""
38
+
39
+ id: str
40
+ name: str
41
+ difficulty: str # "easy" | "medium" | "hard"
42
+ seed: int # Controls ScenarioGenerator variant
43
+ description: str
44
+ success_criteria: List[str]
45
+ scoring_rubric: Dict[str, float]
46
+
47
+
48
+ # ---------------------------------------------------------------------------
49
+ # Scoring result
50
+ # ---------------------------------------------------------------------------
51
+
52
+ @dataclass
53
+ class TaskResult:
54
+ """Outcome of scoring a completed episode against a task."""
55
+
56
+ task_id: str
57
+ score: float # 0.0 – 1.0
58
+ grade: str # "S" | "A" | "B" | "C" | "F"
59
+ breakdown: Dict[str, float] # per-criterion scores
60
+ passed: bool
61
+ notes: List[str] = field(default_factory=list)
62
+
63
+
64
+ def _grade(score: float) -> str:
65
+ if score >= 0.90:
66
+ return "S"
67
+ if score >= 0.75:
68
+ return "A"
69
+ if score >= 0.55:
70
+ return "B"
71
+ if score >= 0.35:
72
+ return "C"
73
+ return "F"
74
+
75
+
76
+ # ---------------------------------------------------------------------------
77
+ # Task catalogue
78
+ # ---------------------------------------------------------------------------
79
+
80
+ TASKS: Dict[str, Task] = {
81
+
82
+ # ── EASY: Clean Pipeline ────────────────────────────────────────────────
83
+ "clean_pipeline": Task(
84
+ id="clean_pipeline",
85
+ name="Clean Pipeline",
86
+ difficulty="easy",
87
+ seed=0,
88
+ description=(
89
+ "Both sources are fresh. Join keys are clean and unique. "
90
+ "The agent must verify freshness, prepare keys, join, apply SCD, "
91
+ "and commit without triggering a row explosion."
92
+ ),
93
+ success_criteria=[
94
+ "COMMIT issued (episode finalized)",
95
+ "No Cartesian explosion detected",
96
+ "Silver row count ≤ Source A row count",
97
+ "match_rate > 0.80 after join",
98
+ ],
99
+ scoring_rubric={
100
+ "committed": 0.20, # Agent issued COMMIT
101
+ "no_explosion": 0.25, # No row explosion
102
+ "volume_ok": 0.20, # Silver ≀ Source A rows
103
+ "high_match": 0.20, # match_rate > 0.80
104
+ "grader_pass": 0.15, # All 4 grader checks pass
105
+ },
106
+ ),
107
+
108
+ # ── MEDIUM: Dirty Integration ───────────────────────────────────────────
109
+ "dirty_integration": Task(
110
+ id="dirty_integration",
111
+ name="Dirty Key Integration",
112
+ difficulty="medium",
113
+ seed=1,
114
+ description=(
115
+ "Source A has NULLs and whitespace in join keys. "
116
+ "Source B has duplicate keys that can cause row explosion. "
117
+ "The agent must PREP_KEYS and DEDUPLICATE before joining, "
118
+ "and correctly quarantine unresolvable orphans."
119
+ ),
120
+ success_criteria=[
121
+ "PREP_KEYS_A issued before EXECUTE_JOIN",
122
+ "PREP_KEYS_B issued before EXECUTE_JOIN",
123
+ "DEDUPLICATE_B issued before EXECUTE_JOIN",
124
+ "No row explosion",
125
+ "Quarantine integrity check passes",
126
+ ],
127
+ scoring_rubric={
128
+ "committed": 0.10,
129
+ "prepped_before_join": 0.20, # Both PREP_KEYS before join
130
+ "deduped_before_join": 0.20, # DEDUP before join
131
+ "no_explosion": 0.25,
132
+ "integrity_ok": 0.15, # Quarantine holds true orphans only
133
+ "grader_pass": 0.10,
134
+ },
135
+ ),
136
+
137
+ # ── HARD: Full Medallion Integration ────────────────────────────────────
138
+ "full_medallion": Task(
139
+ id="full_medallion",
140
+ name="Full Medallion Integration",
141
+ difficulty="hard",
142
+ seed=2,
143
+ description=(
144
+ "Source A is stale (>6h old). Source B has new schema columns "
145
+ "not registered in Silver. The agent must: check freshness, "
146
+ "evolve the schema, clean keys, deduplicate, execute a left join, "
147
+ "apply SCD-2 for tracked columns, and pass all grader checks."
148
+ ),
149
+ success_criteria=[
150
+ "SYNC_CHECK issued before any join",
151
+ "EVOLVE_SCHEMA issued before COMMIT",
152
+ "SCD-2 applied (not SCD-1) for tracked column",
153
+ "Silver schema contains new columns from drift",
154
+ "All 4 grader checks pass",
155
+ ],
156
+ scoring_rubric={
157
+ "committed": 0.05,
158
+ "sync_checked": 0.15, # SYNC_CHECK before join
159
+ "schema_evolved": 0.15, # EVOLVE_SCHEMA called
160
+ "used_scd2": 0.20, # Chose SCD-2 over SCD-1
161
+ "schema_ok": 0.20, # Silver has all required columns
162
+ "grader_pass": 0.25, # All 4 grader checks pass
163
+ },
164
+ ),
165
+ }
166
+
167
+
168
+ # ---------------------------------------------------------------------------
169
+ # Scoring engine
170
+ # ---------------------------------------------------------------------------
171
+
172
+ def score_episode(
173
+ task_id: str,
174
+ state: "MedusaState",
175
+ tables: "Optional[_EpisodeTables]" = None,
176
+ ) -> TaskResult:
177
+ """Score a completed MEDUSA episode against the named task.
178
+
179
+ Args:
180
+ task_id: One of "clean_pipeline", "dirty_integration", "full_medallion".
181
+ state: Final ``MedusaState`` after the episode ended.
182
+ tables: Episode tables (used for schema checks). Optional.
183
+
184
+ Returns:
185
+ TaskResult with score in [0.0, 1.0].
186
+ """
187
+ task = TASKS.get(task_id)
188
+ if task is None:
189
+ raise ValueError(f"Unknown task_id={task_id!r}. Valid: {list(TASKS)}")
190
+
191
+ if state.stage not in ("committed", "failed"):
192
+ return TaskResult(
193
+ task_id=task_id, score=0.0, grade="F",
194
+ breakdown={}, passed=False,
195
+ notes=["Episode not finished — COMMIT was never issued."],
196
+ )
197
+
198
+ breakdown: Dict[str, float] = {}
199
+ notes: List[str] = []
200
+ rubric = task.scoring_rubric
201
+ committed = state.stage == "committed"
202
+
203
+ # ── Shared criteria ──────────────────────────────────────────────────
204
+ if "committed" in rubric:
205
+ breakdown["committed"] = rubric["committed"] if committed else 0.0
206
+
207
+ if "no_explosion" in rubric:
208
+ ok = not state.explosion_detected
209
+ breakdown["no_explosion"] = rubric["no_explosion"] if ok else 0.0
210
+ if not ok:
211
+ notes.append("Row explosion was detected — heavy penalty applied.")
212
+
213
+ if "grader_pass" in rubric:
214
+ breakdown["grader_pass"] = rubric["grader_pass"] if state.grader_passed else 0.0
215
+
216
+ # ── Task-specific criteria ────────────────────────────────────────────
217
+
218
+ if task_id == "clean_pipeline":
219
+ volume_ok = (
220
+ state.silver_row_count <= state.source_a_row_count * 1.05
221
+ and state.silver_row_count > 0
222
+ )
223
+ breakdown["volume_ok"] = rubric["volume_ok"] if volume_ok else 0.0
224
+ breakdown["high_match"] = rubric["high_match"] if state.match_rate >= 0.80 else 0.0
225
+ if state.match_rate < 0.80:
226
+ notes.append(f"match_rate={state.match_rate:.1%} — target >80%.")
227
+
228
+ elif task_id == "dirty_integration":
229
+ # Both PREP_KEYS before join
230
+ prepped = state.did_prep_a and state.did_prep_b and state.did_join
231
+ breakdown["prepped_before_join"] = rubric["prepped_before_join"] if prepped else 0.0
232
+ # DEDUP before join
233
+ deduped = state.did_dedup_b and state.did_join
234
+ breakdown["deduped_before_join"] = rubric["deduped_before_join"] if deduped else 0.0
235
+ # Integrity check: grader_passed is the proxy (quarantine holds true orphans only)
236
+ breakdown["integrity_ok"] = rubric["integrity_ok"] if state.grader_passed else 0.0
241
+ if not prepped:
242
+ notes.append("Agent joined without prepping keys first.")
243
+ if not deduped:
244
+ notes.append("Agent joined without deduplicating Dimension.")
245
+
246
+ elif task_id == "full_medallion":
247
+ breakdown["sync_checked"] = rubric["sync_checked"] if state.did_sync_check else 0.0
248
+ breakdown["schema_evolved"] = rubric["schema_evolved"] if state.did_evolve_schema else 0.0
249
+ used_scd2 = state.scd_type == "SCD-2"
250
+ breakdown["used_scd2"] = rubric["used_scd2"] if used_scd2 else 0.0
251
+ breakdown["schema_ok"] = rubric["schema_ok"] if state.grader_passed else 0.0
252
+ if not state.did_sync_check:
253
+ notes.append("SYNC_CHECK was never called — stale source not verified.")
254
+ if not state.did_evolve_schema:
255
+ notes.append("EVOLVE_SCHEMA never called — new columns may be missing from Silver.")
256
+ if not used_scd2:
257
+ notes.append(f"Used SCD-1 instead of SCD-2 (scd_type={state.scd_type!r}).")
258
+
259
+ # ── Final score ───────────────────────────────────────────────────────
260
+ total = sum(breakdown.values())
261
+ # Clip to [0, 1] (row explosion can make total negative from reward engine)
262
+ score = max(0.0, min(1.0, total))
263
+ passed = score >= 0.55
264
+
265
+ return TaskResult(
266
+ task_id=task_id,
267
+ score=round(score, 4),
268
+ grade=_grade(score),
269
+ breakdown=breakdown,
270
+ passed=passed,
271
+ notes=notes,
272
+ )
273
+
274
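A minimal sketch of the final-score aggregation used in score_episode: sum the weighted breakdown, clip to [0, 1], pass at score >= 0.55. The helper name `finalize` is illustrative, not part of the package.

```python
def finalize(breakdown: dict) -> tuple:
    # Sum weighted components, clip to [0, 1] (reward-engine penalties can
    # otherwise push the total negative), pass at the 0.55 threshold.
    score = max(0.0, min(1.0, sum(breakdown.values())))
    return round(score, 4), score >= 0.55

print(finalize({"committed": 0.10, "no_explosion": 0.25, "grader_pass": 0.10}))
# (0.45, False) -- below the 0.55 pass threshold
```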
+
275
+ # ---------------------------------------------------------------------------
276
+ # Convenience: score all tasks
277
+ # ---------------------------------------------------------------------------
278
+
279
+ def score_all_tasks(
280
+ results: Dict[str, tuple],  # task_id → (state, tables)
281
+ ) -> Dict[str, TaskResult]:
282
+ """Score multiple completed episodes, one per task."""
283
+ return {
284
+ task_id: score_episode(task_id, state, tables)
285
+ for task_id, (state, tables) in results.items()
286
+ }
tests/test_medusa_environment.py ADDED
@@ -0,0 +1,591 @@
1
+ """Tests for the MEDUSA environment.
2
+
3
+ Covers: models, scenario generator, operators, reward engine, grader,
4
+ and full end-to-end environment episodes.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import pytest
10
+
11
+ # ---------------------------------------------------------------------------
12
+ # Models
13
+ # ---------------------------------------------------------------------------
14
+
15
+ from medusa_env.models import (
16
+ MedusaAction,
17
+ MedusaActionType,
18
+ MedusaObservation,
19
+ MedusaState,
20
+ )
21
+
22
+
23
+ class TestMedusaModels:
24
+ def test_action_creation(self):
25
+ a = MedusaAction(action=MedusaActionType.SYNC_CHECK)
26
+ assert a.action == MedusaActionType.SYNC_CHECK
27
+ assert a.params == {}
28
+
29
+ def test_state_defaults(self):
30
+ s = MedusaState()
31
+ assert s.stage == "init"
32
+ assert s.step_idx == 0
33
+ assert s.did_sync_check is False
34
+ assert s.explosion_detected is False
35
+ assert s.grader_passed is False
36
+
37
+ def test_observation_defaults(self):
38
+ obs = MedusaObservation()
39
+ assert obs.done is False
40
+ assert obs.reward is None
41
+ assert obs.features == []
42
+
43
+
44
+ # ---------------------------------------------------------------------------
45
+ # Scenario Generator
46
+ # ---------------------------------------------------------------------------
47
+
48
+ import pandas as pd
49
+
50
+ from medusa_env.scenarios import Scenario, ScenarioGenerator
51
+
52
+
53
+ class TestMedusaScenarios:
54
+ @pytest.fixture
55
+ def gen(self):
56
+ return ScenarioGenerator(n_fact_rows=50, n_dim_rows=40)
57
+
58
+ def test_canonical_clean(self, gen):
59
+ scen = gen.generate(seed=0)
60
+ assert scen.id.startswith("clean")
61
+ assert isinstance(scen.bronze_a, pd.DataFrame)
62
+ assert len(scen.bronze_a) == 50
63
+ assert not scen.is_stale_a
64
+ assert not scen.is_stale_b
65
+ assert scen.new_cols_a == []
66
+
67
+ def test_canonical_dirty_keys(self, gen):
68
+ scen = gen.generate(seed=1)
69
+ assert "dirty_keys" in scen.id
70
+ # Dirty scenario should have actual null or whitespace keys
71
+ has_issues = (
72
+ scen.bronze_a[scen.join_key].isna().any()
73
+ or scen.bronze_a[scen.join_key].astype(str).str.contains(r"^\s|\s$").any()
74
+ )
75
+ assert has_issues
76
+
77
+ def test_canonical_stale(self, gen):
78
+ scen = gen.generate(seed=2)
79
+ assert "stale" in scen.id
80
+ assert scen.is_stale_a # Source A should be stale
81
+
82
+ def test_canonical_schema_drift(self, gen):
83
+ scen = gen.generate(seed=3)
84
+ assert "schema_drift" in scen.id
85
+ assert len(scen.new_cols_a) > 0
86
+ assert len(scen.new_cols_b) > 0
87
+
88
+ def test_random_seed_produces_scenario(self, gen):
89
+ scen = gen.generate(seed=999)
90
+ assert isinstance(scen, Scenario)
91
+ assert scen.join_key in scen.bronze_a.columns
92
+ assert scen.join_key in scen.bronze_b.columns
93
+
94
+
95
+ # ---------------------------------------------------------------------------
96
+ # Operators
97
+ # ---------------------------------------------------------------------------
98
+
99
+ from medusa_env.operators import (
100
+ apply_scd,
101
+ deduplicate,
102
+ evolve_schema,
103
+ execute_join,
104
+ prep_keys,
105
+ sync_check,
106
+ )
107
+
108
+
109
+ class TestMedusaOperators:
110
+ def test_sync_check_fresh(self):
111
+ a = pd.DataFrame({"id": [1, 2]})
112
+ b = pd.DataFrame({"id": [1, 2]})
113
+ _, m = sync_check(a, b, time_delta_a=1.0, time_delta_b=2.0)
114
+ assert m["is_stale_a"] is False
115
+ assert m["is_stale_b"] is False
116
+
117
+ def test_sync_check_stale(self):
118
+ a = pd.DataFrame({"id": [1]})
119
+ b = pd.DataFrame({"id": [1]})
120
+ _, m = sync_check(a, b, time_delta_a=10.0, time_delta_b=1.0)
121
+ assert m["is_stale_a"] is True
122
+ assert m["is_stale_b"] is False
123
+
124
+ def test_prep_keys_strips_whitespace(self):
125
+ df = pd.DataFrame({"key": [" K001 ", "K002", None]})
126
+ result, m = prep_keys(df, "key")
127
+ # Stripped key should have no leading/trailing spaces
128
+ non_null = result["key"].dropna().tolist()
129
+ assert all(v.strip() == v for v in non_null)
130
+ assert m["null_ratio_before"] > 0
131
+
132
+ def test_deduplicate_removes_dupes(self):
133
+ df = pd.DataFrame({"key": ["A", "A", "B"], "val": [1, 2, 3]})
134
+ result, m = deduplicate(df, "key")
135
+ assert m["dupes_removed"] == 1
136
+ assert len(result) == 2
137
+
138
+ def test_execute_join_left_basic(self):
139
+ fact = pd.DataFrame({"key": ["K001", "K002", "K003"], "val": [1, 2, 3]})
140
+ dim = pd.DataFrame({"key": ["K001", "K002"], "dim_name": ["A", "B"]})
141
+ joined, quarantine, m = execute_join(fact, dim, "key", "left")
142
+ assert m["join_rows"] == 3 # left join keeps all fact rows
143
+ assert m["match_rate"] == pytest.approx(2 / 3, abs=0.01)
144
+ assert len(quarantine) >= 1 # K003 should be quarantined
145
+
146
+ def test_execute_join_detects_explosion(self):
147
+ # Non-unique dim key → Cartesian explosion
148
+ fact = pd.DataFrame({"key": ["K001"] * 10, "val": list(range(10))})
149
+ dim = pd.DataFrame({"key": ["K001"] * 20, "dim_name": ["X"] * 20})
150
+ joined, quarantine, m = execute_join(fact, dim, "key", "inner")
151
+ assert m["explosion_detected"] is True
152
+
153
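The explosion case above relies only on standard pandas join semantics; a standalone sketch (no medusa_env imports) of why duplicated keys on both sides multiply rows:

```python
import pandas as pd

# 10 fact rows and 20 dimension rows sharing one key: an inner join
# yields the full cross product, 10 * 20 = 200 rows.
fact = pd.DataFrame({"key": ["K001"] * 10, "val": list(range(10))})
dim = pd.DataFrame({"key": ["K001"] * 20, "dim_name": ["X"] * 20})
joined = fact.merge(dim, on="key", how="inner")
print(len(joined))  # 200
```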
+ def test_execute_join_anti(self):
154
+ fact = pd.DataFrame({"key": ["K001", "K002", "K999"], "val": [1, 2, 3]})
155
+ dim = pd.DataFrame({"key": ["K001", "K002"], "name": ["A", "B"]})
156
+ joined, quarantine, m = execute_join(fact, dim, "key", "anti")
157
+ assert len(joined) == 0 # Anti-join: no rows in joined
158
+ assert len(quarantine) == 1 # K999 goes to quarantine
159
+
160
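The anti-join behaviour tested above can be reproduced with plain pandas via the merge indicator; this sketch is illustrative and independent of the operators module:

```python
import pandas as pd

fact = pd.DataFrame({"key": ["K001", "K002", "K999"], "val": [1, 2, 3]})
dim = pd.DataFrame({"key": ["K001", "K002"], "name": ["A", "B"]})

# indicator=True adds a "_merge" column; "left_only" rows are the orphans
# an anti-join would quarantine.
merged = fact.merge(dim, on="key", how="left", indicator=True)
orphans = merged[merged["_merge"] == "left_only"].drop(columns="_merge")
print(orphans["key"].tolist())  # ['K999']
```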
+ def test_apply_scd1_upsert(self):
161
+ silver = pd.DataFrame({"key": ["K001"], "val": [10], "status": ["old"]})
162
+ joined = pd.DataFrame({"key": ["K001", "K002"], "val": [99, 20], "status": ["new", "new"]})
163
+ result, m = apply_scd(silver, joined, "key", "status", scd_type=1)
164
+ assert m["scd_type"] == 1
165
+ assert m["inserts"] + m["updates"] > 0
166
+ # K001 should be updated to val=99
167
+ k1_row = result[result["key"] == "K001"]
168
+ assert not k1_row.empty
169
+
170
+ def test_apply_scd2_adds_history(self):
171
+ silver = pd.DataFrame()
172
+ joined = pd.DataFrame({"key": ["K001"], "status": ["active"]})
173
+ result, m = apply_scd(silver, joined, "key", "status", scd_type=2)
174
+ assert "valid_from" in result.columns
175
+ assert m["inserts"] == 1
176
+
177
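A hypothetical SCD-2 insert mirroring the columns the test above expects; `valid_from`/`valid_to` handling here is an assumption for illustration, not the package's apply_scd implementation:

```python
import pandas as pd

def scd2_insert(silver, incoming, now="2024-01-01"):
    # SCD-2 keeps history: new versions are appended with validity bounds
    # rather than overwriting the prior row (which would be SCD-1).
    incoming = incoming.copy()
    incoming["valid_from"] = now
    incoming["valid_to"] = pd.NaT  # open-ended: this is the current version
    if silver is None or silver.empty:
        return incoming
    return pd.concat([silver, incoming], ignore_index=True)

out = scd2_insert(pd.DataFrame(), pd.DataFrame({"key": ["K001"], "status": ["active"]}))
print(len(out), "valid_from" in out.columns)  # 1 True
```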
+ def test_evolve_schema_adds_columns(self):
178
+ silver = pd.DataFrame({"key": ["K001"], "val": [1]})
179
+ a = pd.DataFrame({"key": ["K001"], "new_metric": [42]})
180
+ b = pd.DataFrame({"key": ["K001"]})
181
+ result, m = evolve_schema(silver, a, b, ["new_metric"], [])
182
+ assert "new_metric" in result.columns
183
+ assert m["new_cols_count"] == 1
184
+
185
+
186
+ # ---------------------------------------------------------------------------
187
+ # Reward Engine
188
+ # ---------------------------------------------------------------------------
189
+
190
+ from medusa_env.rewards import RewardEngine
191
+
192
+
193
+ class TestMedusaRewards:
194
+ @pytest.fixture
195
+ def engine(self):
196
+ return RewardEngine()
197
+
198
+ def _clean_state(self):
199
+ s = MedusaState()
200
+ s.did_prep_a = True
201
+ s.did_prep_b = True
202
+ s.did_sync_check = True
203
+ return s
204
+
205
+ def test_step_penalty_always_applied(self, engine):
206
+ r = engine.evaluate("SYNC_CHECK", {}, MedusaState())
207
+ assert r == pytest.approx(-0.2, abs=0.01)
208
+
209
+ def test_high_match_join_reward(self, engine):
210
+ r = engine.evaluate(
211
+ "EXECUTE_JOIN_LEFT",
212
+ {"match_rate": 0.95, "join_rows": 100, "fact_rows": 100,
213
+ "explosion_detected": False, "quarantine_rows": 5},
214
+ self._clean_state(),
215
+ )
216
+ assert r > 0.0 # +25 - 0.2 + 10 (quarantine) = +34.8
217
+
218
+ def test_row_explosion_heavy_penalty(self, engine):
219
+ r = engine.evaluate(
220
+ "EXECUTE_JOIN_INNER",
221
+ {"explosion_detected": True, "join_rows": 1000, "fact_rows": 100,
222
+ "match_rate": 1.0, "quarantine_rows": 0},
223
+ self._clean_state(),
224
+ )
225
+ assert r < -50.0
226
+
227
+ def test_dirty_join_penalty(self, engine):
228
+ # No PREP_KEYS β†’ dirty join penalty
229
+ state = MedusaState()
230
+ state.did_prep_a = False
231
+ state.did_prep_b = False
232
+ r = engine.evaluate(
233
+ "EXECUTE_JOIN_LEFT",
234
+ {"explosion_detected": False, "join_rows": 0, "fact_rows": 50,
235
+ "match_rate": 0.0, "quarantine_rows": 0},
236
+ state,
237
+ )
238
+ assert r < -20.0
239
+
240
+ def test_scd2_extra_reward(self, engine):
241
+ r = engine.evaluate("APPLY_SCD_2", {}, self._clean_state())
242
+ # +5 for SCD-2 - 0.2 step penalty
243
+ assert r == pytest.approx(4.8, abs=0.01)
244
+
245
+ def test_stale_processing_penalty(self, engine):
246
+ state = MedusaState()
247
+ state.is_stale_a = True
248
+ state.did_sync_check = False # Never checked freshness
249
+ state.did_prep_a = True
250
+ state.did_prep_b = True
251
+ r = engine.evaluate(
252
+ "EXECUTE_JOIN_LEFT",
253
+ {"explosion_detected": False, "join_rows": 100, "fact_rows": 100,
254
+ "match_rate": 0.95, "quarantine_rows": 0},
255
+ state,
256
+ )
257
+ # Should include stale penalty on top of positive join reward
258
+ assert r < 25.0 # Stale penalty reduces it
259
+
260
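The additive reward composition implied by the tests above; the component values (step penalty -0.2, SCD-2 bonus +5) are read off the expected test values, not taken from RewardEngine internals:

```python
# Assumed shaping constants, inferred from the test expectations above.
STEP_PENALTY = -0.2   # applied to every action
SCD2_BONUS = 5.0      # extra reward for choosing SCD-2

# An APPLY_SCD_2 step with no other events composes additively.
reward = SCD2_BONUS + STEP_PENALTY
print(round(reward, 2))  # 4.8
```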
+
261
+ # ---------------------------------------------------------------------------
262
+ # Grader
263
+ # ---------------------------------------------------------------------------
264
+
265
+ from medusa_env.grader import Grader
266
+ from medusa_env.scenarios import Scenario
267
+
268
+
269
+ class TestMedusaGrader:
270
+ @pytest.fixture
271
+ def grader(self):
272
+ return Grader()
273
+
274
+ def _make_scenario(self):
275
+ a = pd.DataFrame({"entity_id": ["K1", "K2", "K3"], "val": [1, 2, 3],
276
+ "fact_category": ["A", "B", "C"],
277
+ "fact_value": [1.0, 2.0, 3.0],
278
+ "created_at": pd.date_range("2024-01-01", periods=3, freq="h")})
279
+ b = pd.DataFrame({"entity_id": ["K1", "K2"], "dim_name": ["N1", "N2"], "dim_status": ["x", "y"]})
280
+ return a, b
281
+
282
+ def test_volume_check_pass(self, grader):
283
+ a, b = self._make_scenario()
284
+ silver = pd.DataFrame({"entity_id": ["K1", "K2"], "val": [1, 2]})
285
+ scen = ScenarioGenerator(n_fact_rows=3, n_dim_rows=2).generate(seed=0)
286
+ r = grader.audit(silver, pd.DataFrame(), a, b, "entity_id", "left", 1, scen)
287
+ assert r.volume_ok is True
288
+
289
+ def test_volume_check_fail(self, grader):
290
+ a, b = self._make_scenario()
291
+ # Silver has way more rows than source A → violation
292
+ silver = pd.DataFrame({"entity_id": ["K1"] * 100})
293
+ scen = ScenarioGenerator(n_fact_rows=3, n_dim_rows=2).generate(seed=0)
294
+ r = grader.audit(silver, pd.DataFrame(), a, b, "entity_id", "left", 1, scen)
295
+ assert r.volume_ok is False
296
+
297
+ def test_integrity_check_quarantine_true_orphans(self, grader):
298
+ a, b = self._make_scenario()
299
+ # K3 is not in B → true orphan
300
+ quarantine = pd.DataFrame({"entity_id": ["K3"]})
301
+ scen = ScenarioGenerator(n_fact_rows=3, n_dim_rows=2).generate(seed=0)
302
+ silver = pd.DataFrame({"entity_id": ["K1", "K2"]})
303
+ r = grader.audit(silver, quarantine, a, b, "entity_id", "left", 1, scen)
304
+ assert r.integrity_ok is True
305
+
306
+ def test_integrity_check_fail_dirty_quarantine(self, grader):
307
+ a, b = self._make_scenario()
308
+ # K1 IS in B but ends up in quarantine (agent failed to clean it)
309
+ quarantine = pd.DataFrame({"entity_id": ["K1"]})
310
+ scen = ScenarioGenerator(n_fact_rows=3, n_dim_rows=2).generate(seed=0)
311
+ silver = pd.DataFrame({"entity_id": ["K2"]})
312
+ r = grader.audit(silver, quarantine, a, b, "entity_id", "left", 1, scen)
313
+ assert r.integrity_ok is False
314
+
315
+ def test_all_pass_gives_bonus(self, grader):
316
+ gen = ScenarioGenerator(n_fact_rows=3, n_dim_rows=2)
317
+ scen = gen.generate(seed=0)
318
+ a, b = scen.bronze_a, scen.bronze_b
319
+ # Simulate a perfect run
320
+ silver = a.merge(b, on="entity_id", how="left")
321
+ r = grader.audit(silver, pd.DataFrame(), a, b, "entity_id", "left", 1, scen)
322
+ assert r.bonus_reward > 0
323
+
324
+
325
+ # ---------------------------------------------------------------------------
326
+ # Full environment integration
327
+ # ---------------------------------------------------------------------------
328
+
329
+ from medusa_env.server import MedusaEnv
330
+ from medusa_env.models import MedusaActionType
331
+
332
+
333
+ class TestMedusaEnvironment:
334
+ @pytest.fixture
335
+ def env(self):
336
+ return MedusaEnv(n_fact_rows=50, n_dim_rows=40)
337
+
338
+ def test_reset_returns_observation(self, env):
339
+ obs = env.reset(seed=0)
340
+ assert isinstance(obs, MedusaObservation)
341
+ assert obs.done is False
342
+ assert len(obs.features) == 16
343
+ assert obs.reward is None
344
+
345
+ def test_state_after_reset(self, env):
346
+ env.reset(seed=0)
347
+ state = env.state
348
+ assert state.stage == "running"
349
+ assert state.step_idx == 0
350
+ assert state.source_a_row_count == 50
351
+
352
+ def test_happy_path_episode(self, env):
353
+ """Full pipeline: sync → evolve → prep both → dedup → join → scd → commit."""
354
+ env.reset(seed=0) # clean scenario
355
+
356
+ actions = [
357
+ MedusaActionType.SYNC_CHECK,
358
+ MedusaActionType.EVOLVE_SCHEMA,
359
+ MedusaActionType.PREP_KEYS_A,
360
+ MedusaActionType.PREP_KEYS_B,
361
+ MedusaActionType.DEDUPLICATE_B,
362
+ MedusaActionType.EXECUTE_JOIN_LEFT,
363
+ MedusaActionType.APPLY_SCD_2,
364
+ MedusaActionType.COMMIT,
365
+ ]
366
+ obs = None
367
+ for act_type in actions:
368
+ obs = env.step(MedusaAction(action=act_type))
369
+
370
+ assert obs is not None
371
+ assert obs.done is True
372
+ assert env.state.stage == "committed"
373
+ assert env.state.grader_passed # Clean scenario should pass grader
374
+
375
+ def test_row_explosion_gives_heavy_penalty(self, env):
376
+ """Joining on non-unique B keys should trigger explosion penalty."""
377
+ env.reset(seed=1)  # dirty_keys scenario — B has duplicate keys
378
+
379
+ # Skip prep & dedup — go straight to join
380
+ env.step(MedusaAction(action=MedusaActionType.SYNC_CHECK))
381
+
382
+ # Force the dimension to have many duplicates so explosion fires
383
+ import pandas as _pd
384
+
385
+ env._tables.bronze_b_prepped = _pd.DataFrame({
386
+ "entity_id": ["K001"] * 30,
387
+ "dim_name": ["X"] * 30,
388
+ "dim_status": ["x"] * 30,
389
+ })
390
+ env._tables.bronze_a_prepped = _pd.DataFrame({
391
+ "entity_id": ["K001"] * 10,
392
+ "fact_value": list(range(10)),
393
+ "fact_category": ["A"] * 10,
394
+ "created_at": _pd.date_range("2024-01-01", periods=10, freq="h"),
395
+ })
396
+
397
+ obs = env.step(MedusaAction(action=MedusaActionType.EXECUTE_JOIN_INNER))
398
+ assert obs.reward is not None
399
+ assert obs.reward < -50.0
400
+ assert env.state.explosion_detected is True
401
+
402
+ def test_dirty_join_penalty(self, env):
403
+ """Skipping PREP_KEYS and joining on null-heavy keys → dirty join."""
404
+ env.reset(seed=1) # dirty_keys scenario
405
+
406
+ # Skip PREP — join directly
407
+ obs = env.step(MedusaAction(action=MedusaActionType.EXECUTE_JOIN_LEFT))
408
+ # If all fact keys are null/non-matching → 0-row join → dirty join penalty
409
+ # (reward < base -0.2 if dirty join fired)
410
+ assert obs.reward is not None
411
+
412
+ def test_step_idx_increments(self, env):
413
+ env.reset(seed=0)
414
+ for _ in range(3):
415
+ env.step(MedusaAction(action=MedusaActionType.SYNC_CHECK))
416
+ assert env.state.step_idx == 3
417
+
418
+ def test_max_steps_terminates_episode(self):
419
+ env = MedusaEnv(n_fact_rows=10, n_dim_rows=10, max_steps=3)
420
+ env.reset(seed=0)
421
+ obs = None
422
+ for _ in range(4): # more than max_steps
423
+ obs = env.step(MedusaAction(action=MedusaActionType.SYNC_CHECK))
424
+ assert obs is not None
425
+ assert obs.done is True
426
+
427
+ def test_commit_without_join_grader_fails(self, env):
428
+ """Committing without joining should make the grader fail."""
429
+ env.reset(seed=0)
430
+ env.step(MedusaAction(action=MedusaActionType.SYNC_CHECK))
431
+ obs = env.step(MedusaAction(action=MedusaActionType.COMMIT))
432
+ assert obs.done is True
433
+ # Silver will be empty → schema check should fail or volume check fail
434
+ assert env.state.grader_report != ""
435
+
436
+ def test_features_vector_length(self, env):
437
+ env.reset(seed=0)
438
+ obs = env.step(MedusaAction(action=MedusaActionType.SYNC_CHECK))
439
+ assert len(obs.features) == 16
440
+ assert all(0.0 <= f <= 1.0 for f in obs.features)
441
+
442
+ def test_governance_log_populated(self, env):
443
+ env.reset(seed=0)
444
+ env.step(MedusaAction(action=MedusaActionType.SYNC_CHECK))
445
+ env.step(MedusaAction(action=MedusaActionType.PREP_KEYS_A))
446
+ log = env._tables.governance_log
447
+ assert len(log) == 2
448
+ assert log[0]["action"] == "SYNC_CHECK"
449
+
450
+
451
+ # ---------------------------------------------------------------------------
452
+ # Task Scorer
453
+ # ---------------------------------------------------------------------------
454
+
455
+ from medusa_env.tasks import TASKS, score_episode
456
+
457
+
458
+ class TestMedusaTasks:
459
+ """Tests for the 3 formal task definitions and 0.0–1.0 scorer."""
460
+
461
+ def test_three_tasks_defined(self):
462
+ assert "clean_pipeline" in TASKS
463
+ assert "dirty_integration" in TASKS
464
+ assert "full_medallion" in TASKS
465
+
466
+ def test_task_difficulties(self):
467
+ assert TASKS["clean_pipeline"].difficulty == "easy"
468
+ assert TASKS["dirty_integration"].difficulty == "medium"
469
+ assert TASKS["full_medallion"].difficulty == "hard"
470
+
471
+ def test_task_seeds_match_scenarios(self):
472
+ assert TASKS["clean_pipeline"].seed == 0
473
+ assert TASKS["dirty_integration"].seed == 1
474
+ assert TASKS["full_medallion"].seed == 2
475
+
476
+ def _run_happy_path(self, seed: int) -> MedusaState:
477
+ """Run the optimal action sequence for the given seed and return final state."""
478
+ env = MedusaEnv(n_fact_rows=50, n_dim_rows=40)
479
+ env.reset(seed=seed)
480
+ for act in [
481
+ MedusaActionType.SYNC_CHECK,
482
+ MedusaActionType.EVOLVE_SCHEMA,
483
+ MedusaActionType.PREP_KEYS_A,
484
+ MedusaActionType.PREP_KEYS_B,
485
+ MedusaActionType.DEDUPLICATE_B,
486
+ MedusaActionType.EXECUTE_JOIN_LEFT,
487
+ MedusaActionType.APPLY_SCD_2,
488
+ MedusaActionType.COMMIT,
489
+ ]:
490
+ env.step(MedusaAction(action=act))
491
+ return env.state
492
+
493
+ # ── clean_pipeline (easy) ───────────────────────────────────────────────
494
+
495
+ def test_clean_pipeline_score_is_in_range(self):
496
+ state = self._run_happy_path(seed=0)
497
+ result = score_episode("clean_pipeline", state)
498
+ assert 0.0 <= result.score <= 1.0
499
+
500
+ def test_clean_pipeline_happy_path_passes(self):
501
+ state = self._run_happy_path(seed=0)
502
+ result = score_episode("clean_pipeline", state)
503
+ assert result.passed is True
504
+ assert result.grade in ("S", "A", "B")
505
+
506
+ def test_clean_pipeline_uncommitted_scores_zero(self):
507
+ state = MedusaState(stage="running")
508
+ result = score_episode("clean_pipeline", state)
509
+ assert result.score == 0.0
510
+ assert result.grade == "F"
511
+
512
+ def test_clean_pipeline_explosion_detected_lowers_score(self):
513
+ state = MedusaState(
514
+ stage="committed",
515
+ explosion_detected=True,
516
+ silver_row_count=0,
517
+ source_a_row_count=50,
518
+ match_rate=0.0,
519
+ grader_passed=False,
520
+ )
521
+ result = score_episode("clean_pipeline", state)
522
+ assert result.breakdown["no_explosion"] == 0.0
523
+
524
+ # ── dirty_integration (medium) ──────────────────────────────────────────
525
+
526
+ def test_dirty_integration_score_is_in_range(self):
527
+ state = self._run_happy_path(seed=1)
528
+ result = score_episode("dirty_integration", state)
529
+ assert 0.0 <= result.score <= 1.0
530
+
531
+ def test_dirty_integration_without_prep_penalized(self):
532
+ state = MedusaState(
533
+ stage="committed",
534
+ did_prep_a=False,
535
+ did_prep_b=False,
536
+ did_dedup_b=False,
537
+ did_join=True,
538
+ explosion_detected=False,
539
+ grader_passed=False,
540
+ )
541
+ result = score_episode("dirty_integration", state)
542
+ assert result.breakdown["prepped_before_join"] == 0.0
543
+ assert result.breakdown["deduped_before_join"] == 0.0
544
+
545
+ def test_dirty_integration_with_all_prereqs_scores_higher(self):
546
+ state_no_prep = MedusaState(
547
+ stage="committed", did_prep_a=False, did_prep_b=False,
548
+ did_dedup_b=False, did_join=True, explosion_detected=False, grader_passed=False,
549
+ )
550
+ state_prepped = MedusaState(
551
+ stage="committed", did_prep_a=True, did_prep_b=True,
552
+ did_dedup_b=True, did_join=True, explosion_detected=False, grader_passed=True,
553
+ )
554
+ no_prep = score_episode("dirty_integration", state_no_prep)
555
+ prepped = score_episode("dirty_integration", state_prepped)
556
+ assert prepped.score > no_prep.score
557
+
558
+ # ── full_medallion (hard) ───────────────────────────────────────────────
559
+
560
+ def test_full_medallion_score_is_in_range(self):
561
+ state = self._run_happy_path(seed=2)
562
+ result = score_episode("full_medallion", state)
563
+ assert 0.0 <= result.score <= 1.0
564
+
565
+ def test_full_medallion_without_sync_penalized(self):
566
+ state = MedusaState(
567
+ stage="committed",
568
+ did_sync_check=False,
569
+ did_evolve_schema=True,
570
+ scd_type="SCD-2",
571
+ grader_passed=True,
572
+ )
573
+ result = score_episode("full_medallion", state)
574
+ assert result.breakdown["sync_checked"] == 0.0
575
+
576
+ def test_full_medallion_scd1_penalized(self):
577
+ state_scd1 = MedusaState(
578
+ stage="committed", did_sync_check=True,
579
+ did_evolve_schema=True, scd_type="SCD-1", grader_passed=False,
580
+ )
581
+ state_scd2 = MedusaState(
582
+ stage="committed", did_sync_check=True,
583
+ did_evolve_schema=True, scd_type="SCD-2", grader_passed=True,
584
+ )
585
+ r1 = score_episode("full_medallion", state_scd1)
586
+ r2 = score_episode("full_medallion", state_scd2)
587
+ assert r2.score > r1.score
588
+
589
+ def test_unknown_task_raises(self):
590
+ with pytest.raises(ValueError, match="Unknown task_id"):
591
+ score_episode("nonexistent_task", MedusaState(stage="committed"))
uv.lock ADDED
The diff for this file is too large to render. See raw diff