medusa_env / grader.py
rampluto's picture
Upload folder using huggingface_hub
fd09b74 verified
"""MEDUSA deterministic post-commit grader.
Runs a four-check audit after the agent issues COMMIT and returns a
``GraderResult`` that feeds a bonus/penalty into the terminal reward.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List
import pandas as pd
if TYPE_CHECKING:
from .scenarios import Scenario
# ---------------------------------------------------------------------------
# GraderResult
# ---------------------------------------------------------------------------
@dataclass
class GraderResult:
    """Outcome of the post-commit audit.

    Every flag starts out ``False`` and the numeric/text fields start out
    empty; the grader fills them in while it runs its checks.

    Attributes:
        passed: Overall verdict of the audit.
        volume_ok: Silver rows <= Source A rows (no duplicates from join).
        integrity_ok: Quarantine holds only true orphans.
        schema_ok: Silver has the union of required columns.
        history_ok: SCD-2 timestamps are non-overlapping.
        failures: One human-readable message per failed check.
        bonus_reward: Bonus (positive) or penalty (negative) for the
            terminal reward.
        report: Multi-line text summary of the audit.
    """

    passed: bool = False
    volume_ok: bool = False
    integrity_ok: bool = False
    schema_ok: bool = False
    history_ok: bool = False
    # default_factory gives each instance its own list instead of a shared one.
    failures: List[str] = field(default_factory=list)
    bonus_reward: float = 0.0
    report: str = ""
# Reward tuning.
# Sign convention: bonuses are stored as positive numbers and penalties as
# negative numbers.
_BONUS_ALL_PASS = 15.0  # used when every check passes
_PENALTY_ALL_FAIL = -20.0  # used when every check fails
_BONUS_PER_CHECK = 3.0  # per passed check when the outcome is mixed
_PENALTY_PER_FAIL = -5.0  # per failed check when the outcome is mixed
# ---------------------------------------------------------------------------
# Grader
# ---------------------------------------------------------------------------
class Grader:
    """Post-commit deterministic audit following MEDUSA spec section 4.

    ``audit`` is the entry point.  Each individual check lives in a private
    static helper that returns ``None`` when the check passes (or does not
    apply) and a failure message otherwise.
    """

    def audit(
        self,
        silver: pd.DataFrame,
        quarantine: pd.DataFrame,
        bronze_a: pd.DataFrame,
        bronze_b: pd.DataFrame,
        join_key: str,
        join_type: str,
        scd_type: int,
        scenario: "Scenario",
    ) -> GraderResult:
        """Run all four grader checks and compute the bonus reward.

        Args:
            silver: The final Silver DataFrame after SCD merge.
            quarantine: Rows from A that did not match B.
            bronze_a: Original fact source (pre-cleaning).
            bronze_b: Original dimension source (pre-cleaning).
            join_key: Column used for the join.
            join_type: "inner" | "left" | "anti"
            scd_type: 1 or 2
            scenario: The current episode's scenario; ``new_cols_a`` and
                ``new_cols_b`` are read from it for the schema check.

        Returns:
            GraderResult with individual check statuses, the failure
            messages (in volume, integrity, schema, history order),
            ``bonus_reward`` and a text ``report``.
        """
        result = GraderResult()

        volume_fail = self._volume_failure(silver, bronze_a, join_key, join_type)
        integrity_fail = self._integrity_failure(quarantine, bronze_b, join_key)
        schema_fail = self._schema_failure(
            silver, bronze_a, bronze_b, join_key, scenario
        )
        history_fail = self._history_failure(silver, join_key, scd_type)

        result.volume_ok = volume_fail is None
        result.integrity_ok = integrity_fail is None
        result.schema_ok = schema_fail is None
        result.history_ok = history_fail is None
        result.failures.extend(
            message
            for message in (volume_fail, integrity_fail, schema_fail, history_fail)
            if message is not None
        )

        # -- Compute bonus -------------------------------------------------
        checks = [
            result.volume_ok,
            result.integrity_ok,
            result.schema_ok,
            result.history_ok,
        ]
        passed_count = sum(checks)
        failed_count = len(checks) - passed_count
        result.passed = all(checks)
        if result.passed:
            result.bonus_reward = _BONUS_ALL_PASS
        elif passed_count == 0:
            result.bonus_reward = _PENALTY_ALL_FAIL
        else:
            # _PENALTY_PER_FAIL is already negative, so it must be ADDED.
            # Subtracting it (as the previous code did) rewarded failures:
            # 1 pass + 3 fails scored higher than a clean pass.
            result.bonus_reward = (
                passed_count * _BONUS_PER_CHECK + failed_count * _PENALTY_PER_FAIL
            )

        result.report = _build_report(result)
        return result

    # -- 1. Volume check ---------------------------------------------------
    @staticmethod
    def _volume_failure(
        silver: pd.DataFrame,
        bronze_a: pd.DataFrame,
        join_key: str,
        join_type: str,
    ) -> str | None:
        """Return a failure message if a left join inflated Silver, else None."""
        if join_type != "left":
            return None  # Not applicable for inner/anti joins

        # Rows of A with a null join key are excluded from the baseline.
        source_a_rows = len(bronze_a.dropna(subset=[join_key]))
        if "is_current" in silver.columns:
            # Only rows flagged current are counted when the flag exists.
            silver_rows = len(silver[silver["is_current"] == True])  # noqa: E712
        else:
            silver_rows = len(silver)

        if silver_rows <= source_a_rows * 1.05:  # 5% tolerance
            return None
        return f"VOLUME_FAIL: Silver {silver_rows} rows > Source A {source_a_rows} rows"

    # -- 2. Integrity check ------------------------------------------------
    @staticmethod
    def _integrity_failure(
        quarantine: pd.DataFrame,
        bronze_b: pd.DataFrame,
        join_key: str,
    ) -> str | None:
        """Return a failure message if quarantined keys exist in B, else None."""
        if quarantine.empty or join_key not in quarantine.columns:
            return None  # Empty quarantine is fine

        # Keys are compared as whitespace-stripped strings; nulls are ignored.
        dim_keys = set(bronze_b[join_key].dropna().astype(str).str.strip())
        quarantine_keys = set(quarantine[join_key].dropna().astype(str).str.strip())

        # A true orphan has no counterpart in the dimension source.
        could_join = quarantine_keys & dim_keys
        if not could_join:
            return None
        # NOTE: the number reported is the count of distinct keys.
        return (
            f"INTEGRITY_FAIL: {len(could_join)} quarantine row(s) could have "
            f"been joined if keys were cleaned."
        )

    # -- 3. Schema check ---------------------------------------------------
    @staticmethod
    def _schema_failure(
        silver: pd.DataFrame,
        bronze_a: pd.DataFrame,
        bronze_b: pd.DataFrame,
        join_key: str,
        scenario: "Scenario",
    ) -> str | None:
        """Return a failure message if Silver lacks a required column, else None."""
        # Required = every non-key column of A and B plus the scenario's
        # declared new columns.
        required = {col for col in bronze_a.columns if col != join_key}
        required.update(col for col in bronze_b.columns if col != join_key)
        required.update(scenario.new_cols_a)
        required.update(scenario.new_cols_b)

        missing = required - set(silver.columns)
        if not missing:
            return None
        return f"SCHEMA_FAIL: Missing columns in Silver: {sorted(missing)}"

    # -- 4. History check (SCD-2 only) -------------------------------------
    @staticmethod
    def _history_failure(
        silver: pd.DataFrame,
        join_key: str,
        scd_type: int,
    ) -> str | None:
        """Return a failure message if closed SCD-2 intervals overlap, else None."""
        has_validity_cols = (
            "valid_from" in silver.columns and "valid_to" in silver.columns
        )
        if scd_type != 2 or not has_validity_cols:
            return None  # Not applicable for SCD-1

        for _, group in silver.groupby(join_key):
            if len(group) < 2:
                continue
            # Only closed versions (non-null valid_to) are compared, in
            # valid_from order; each one must end before the next begins.
            closed = group[group["valid_to"].notna()].sort_values("valid_from")
            ends = closed["valid_to"].iloc[:-1]
            next_starts = closed["valid_from"].iloc[1:]
            for valid_to, next_from in zip(ends, next_starts):
                # valid_to is non-null by construction; valid_from may be null.
                if pd.notna(next_from) and valid_to > next_from:
                    return "HISTORY_FAIL: SCD-2 timestamps overlap for some keys."
        return None
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _build_report(result: "GraderResult") -> str:
    """Render a human-readable, multi-line summary of an audit result.

    Args:
        result: The audit outcome.  Its ``volume_ok``, ``integrity_ok``,
            ``schema_ok``, ``history_ok``, ``bonus_reward``, ``failures``
            and ``passed`` attributes are read.

    Returns:
        Newline-joined report: a header, one status line per check, the
        signed bonus reward, an optional list of failure messages, and a
        final PASS/FAIL verdict line.
    """
    # The glyphs are written as escapes so the source stays pure ASCII; the
    # literal characters had been corrupted into mojibake by a wrong decode.
    check_mark = "\u2713"  # CHECK MARK
    cross_mark = "\u2717"  # BALLOT X

    def mark(ok: bool) -> str:
        return check_mark if ok else cross_mark

    lines = [
        "=== MEDUSA Grader Audit ===",
        f" Volume OK: {mark(result.volume_ok)}",
        f" Integrity OK: {mark(result.integrity_ok)}",
        f" Schema OK: {mark(result.schema_ok)}",
        f" History OK: {mark(result.history_ok)}",
        f" Bonus Reward: {result.bonus_reward:+.1f}",
    ]
    if result.failures:
        lines.append(" Failures:")
        lines.extend(f" - {failure}" for failure in result.failures)

    verdict = f"PASS {check_mark}" if result.passed else f"FAIL {cross_mark}"
    lines.append(f" {verdict}")
    return "\n".join(lines)