mlops-openenv / server /mlops_environment.py
trretretret's picture
Add uv.lock, pyproject.toml, server directory
6a71058
"""
MLOps Pipeline Debugger — Core Environment

Episode flow:
  1. reset(seed) → generates a broken training run with one planted bug
     (the task_id is fixed when the environment is constructed)
  2. Agent investigates using 8 actions (reads artifacts, runs sanity checks)
  3. Agent submits a structured diagnosis
  4. Grader compares against planted bug ground truth → score in [0.0, 1.0]

Reward design (dense, not sparse):
  +0.02  per new artifact read (first time — rewards exploration)
  +0.01  per new sanity check run (first time)
  -0.02  per duplicate artifact read (no new filter applied)
  -0.05  submitting diagnosis after reading < 3 distinct artifacts

At submit_diagnosis:
  +0.15  correct failure_category
  +0.25  correct root_cause_file
  +0.30  correct root_cause_field (majority of the field's keywords must
         appear in the submission, case-insensitive)
  up to +0.30  proposed_fix (proportional to keyword overlap with gold fix)

Task 3 (hard) penalty multiplier:
  wrong diagnosis → ×1.5 penalty on the missed components: when the score is
  below 0.70, an extra 0.5 × (0.70 − score) is subtracted
  (silent bugs that reach production are more costly)
"""
from __future__ import annotations
import random
from typing import Any, Dict, List, Optional, Tuple
from models import MLOpsAction, MLOpsObservation, MLOpsState, ArtifactMeta
from artifact_generator import (
ArtifactGenerator, BUG_CATALOGUE, TASK_BUG_POOLS,
run_sanity_check,
)
# Step budget per task; the keys double as the set of valid task ids.
TASK_MAX_STEPS = {"easy": 20, "medium": 30, "hard": 40}

# Agent-facing briefing for each task, returned verbatim in every observation.
# The three briefings share the same trailing list of available actions.
TASK_DESCRIPTIONS = {
    "easy": (
        "TASK 1 — CONFIG ERROR DIAGNOSIS (Easy)\n\n"
        "A training run has failed or produced clearly wrong results. The issue is in "
        "the training configuration — a hyperparameter is set to an incorrect value that "
        "causes immediate, visible degradation in training metrics.\n\n"
        "Your job: investigate the training artifacts, identify which configuration "
        "parameter is wrong, and propose the correct fix.\n\n"
        "Strategy: Start by reading the training logs to observe symptom patterns, "
        "then check the config to find the misconfigured parameter. "
        "Run sanity checks (loss_trajectory, gradient_norms) to confirm your hypothesis "
        "before submitting.\n\n"
        "Actions available: read_config | read_logs | check_dataset_stats | "
        "inspect_preprocessing | read_eval_results | run_sanity_check | "
        "query_artifact | submit_diagnosis"
    ),
    "medium": (
        "TASK 2 — DATA LEAKAGE DETECTION (Medium)\n\n"
        "Training metrics look suspiciously good — validation accuracy is anomalously "
        "high from the first epoch, but test performance tells a different story. "
        "The issue is in the data preprocessing pipeline.\n\n"
        "Your job: identify the exact source of data leakage — whether it's a scaler "
        "fitted on the full dataset, overlapping train/val splits from a non-deterministic "
        "split, or an inverted split ratio — and propose the correct fix.\n\n"
        "Strategy: Anomalous val accuracy in the logs is your first signal. "
        "Inspect preprocessing code to find how splits are constructed. "
        "Run the data_leakage and feature_statistics sanity checks to confirm. "
        "The val/test metric gap in eval results is another key clue.\n\n"
        "Actions available: read_config | read_logs | check_dataset_stats | "
        "inspect_preprocessing | read_eval_results | run_sanity_check | "
        "query_artifact | submit_diagnosis"
    ),
    "hard": (
        "TASK 3 — SILENT EVALUATION BUG (Hard)\n\n"
        "Training completed normally. Validation metrics look reasonable. "
        "But test set performance is catastrophically below validation — "
        "and there are NO error logs, NO warnings, NO exceptions thrown.\n\n"
        "Your job: find the silent bug in the evaluation pipeline. It could be "
        "a label encoder mismatch between train and eval (different class orderings), "
        "a metric assignment swap (val/test results mislabeled), or a tokenizer "
        "version drift (training used v2, evaluation uses v1).\n\n"
        "Strategy: The val/test metric gap in eval_results is your only initial signal. "
        "Run metric_gap_analysis first to quantify the anomaly. Then systematically "
        "check label_consistency, encoder_version_match, and inspect the preprocessing "
        "code carefully — the bug produces no error output and will only be visible "
        "by comparing train vs eval pipeline definitions.\n\n"
        "WARNING: Missing this bug in a deployed model means silent wrong predictions "
        "in production. Penalty for wrong diagnosis is weighted 1.5×.\n\n"
        "Actions available: read_config | read_logs | check_dataset_stats | "
        "inspect_preprocessing | read_eval_results | run_sanity_check | "
        "query_artifact | submit_diagnosis"
    ),
}

# artifact name -> (human-readable description, approximate size hint).
ARTIFACT_DESCRIPTIONS = {
    "config.yaml": ("Training configuration — hyperparameters, model, optimizer, scheduler", "~45 lines"),
    "train.log": ("Epoch-by-epoch training metrics — loss, accuracy, gradient norms", "~30–60 lines"),
    "dataset_stats.json": ("Dataset split sizes, class distribution, feature statistics", "~35 fields"),
    "preprocessing.py": ("Data preprocessing pipeline — splits, normalization, encoding", "~40–70 lines"),
    "eval_results.json": ("Final evaluation metrics — val and test loss/accuracy", "~15 fields"),
    "model_card.json": ("Model architecture summary, training config, preprocessing versions", "~20 fields"),
}
class MLOpsEnvironment:
    """OpenEnv-compatible MLOps Pipeline Debugging environment.

    An instance is bound to one ``task_id``. Every episode picks one bug from
    ``TASK_BUG_POOLS[task_id]``, generates the run's artifacts around it, and
    lets the agent inspect them through :meth:`step` until it submits a
    diagnosis (graded against the planted bug) or the step budget runs out.
    """

    def __init__(self, task_id: str = "easy"):
        """Create the environment and load a default episode (seed 42).

        Raises:
            ValueError: if ``task_id`` is not a key of ``TASK_MAX_STEPS``.
        """
        # Explicit raise instead of ``assert`` so validation survives ``python -O``.
        if task_id not in TASK_MAX_STEPS:
            raise ValueError(f"task_id must be one of {list(TASK_MAX_STEPS)}")
        self.task_id = task_id
        self._reset_internal(seed=42)

    def _reset_internal(self, seed: int):
        """Plant a bug, generate artifacts and clear all episode state for ``seed``."""
        rng = random.Random(seed)

        # Pick bug from this task's pool
        pool = TASK_BUG_POOLS[self.task_id]
        self.bug_type = rng.choice(pool)
        self.bug = BUG_CATALOGUE[self.bug_type]

        # Generate all artifacts
        gen = ArtifactGenerator(self.bug_type, seed)
        self._artifacts: Dict[str, str] = gen.generate_all()
        self._model_cfg = gen.model_cfg
        self._run_id = gen.run_id
        self._rng = rng
        self._seed = seed

        # Cache artifact metadata at reset time (avoids consuming RNG per step)
        self._artifact_meta: List[ArtifactMeta] = [
            ArtifactMeta(
                name=name,
                description=ARTIFACT_DESCRIPTIONS[name][0],
                size_hint=ARTIFACT_DESCRIPTIONS[name][1],
                last_modified=f"2024-03-{rng.randint(1,28):02d}",
            )
            for name in self._artifacts
        ]

        # Episode state
        self._step_count = 0
        self._max_steps = TASK_MAX_STEPS[self.task_id]
        self._done = False
        self._artifacts_read: List[str] = []
        self._last_read_filters: Dict[str, str] = {}
        self._sanity_checks_run: List[str] = []
        self._duplicate_queries = 0
        self._current_score = 0.0
        self._messages: List[str] = []

    # ── OpenEnv API ───────────────────────────────────────────────────────────

    def reset(self, seed: Optional[int] = None) -> MLOpsObservation:
        """Start a new episode and return its first observation.

        When ``seed`` is None a seed in [0, 100000) is derived from the
        current wall-clock time, so the episode is not reproducible.
        """
        import time

        actual_seed = seed if seed is not None else int(time.time() * 1000) % 100000
        self._reset_internal(actual_seed)
        return self._build_obs(
            {"status": "reset", "message": "New training run loaded. Begin investigation."},
        )

    def step(self, action: MLOpsAction) -> Tuple[MLOpsObservation, float, bool, Dict[str, Any]]:
        """Apply one agent action.

        Returns:
            ``(observation, reward, done, info)``. ``info`` is populated only
            on ``submit_diagnosis`` (score breakdown and ground truth) and on
            timeout (``score`` and ``reason``).
        """
        if self._done:
            return self._build_obs({"status": "done", "message": "Episode over. Call reset()."}), 0.0, True, {}

        self._step_count += 1
        reward = 0.0
        info: Dict[str, Any] = {}
        result: Dict[str, Any] = {}

        # NOTE(review): the step that reaches max_steps ends the episode without
        # executing its action (a submit_diagnosis on that step is not graded).
        # Kept as-is; confirm this is the intended budget semantics.
        if self._step_count >= self._max_steps:
            self._done = True
            score = self._current_score
            result = {"status": "timeout", "message": f"Max steps ({self._max_steps}) reached.", "score": score}
            return self._build_obs(result), 0.0, True, {"score": score, "reason": "timeout"}

        atype = action.action_type

        # ── read_config ───────────────────────────────────────────────────
        if atype == "read_config":
            reward, result = self._handle_artifact_read("config.yaml", None)

        # ── read_logs ─────────────────────────────────────────────────────
        elif atype == "read_logs":
            reward, result = self._handle_artifact_read("train.log", action.log_filter)

        # ── check_dataset_stats ───────────────────────────────────────────
        elif atype == "check_dataset_stats":
            reward, result = self._handle_artifact_read("dataset_stats.json", None)

        # ── inspect_preprocessing ─────────────────────────────────────────
        elif atype == "inspect_preprocessing":
            reward, result = self._handle_artifact_read("preprocessing.py", None)

        # ── read_eval_results ─────────────────────────────────────────────
        elif atype == "read_eval_results":
            reward, result = self._handle_artifact_read("eval_results.json", None)

        # ── run_sanity_check ──────────────────────────────────────────────
        elif atype == "run_sanity_check":
            check = action.sanity_check_type
            if not check:
                result = {"status": "error", "message": "sanity_check_type is required."}
            else:
                check_result = run_sanity_check(check, self.bug_type, self._artifacts, self._rng)
                if check not in self._sanity_checks_run:
                    self._sanity_checks_run.append(check)
                    reward += 0.01  # small reward for running new checks
                result = {"status": "ok", "sanity_check": check_result}

        # ── query_artifact ────────────────────────────────────────────────
        elif atype == "query_artifact":
            art = action.artifact_name
            field = action.field_path
            if not art or not field:
                result = {"status": "error", "message": "artifact_name and field_path are required."}
            elif art not in self._artifacts:
                result = {"status": "error", "message": f"Artifact '{art}' not found."}
            else:
                val = self._resolve_field(art, field)
                result = {"status": "ok", "artifact": art, "field": field, "value": val}

        # ── submit_diagnosis ──────────────────────────────────────────────
        elif atype == "submit_diagnosis":
            reward, info, result = self._handle_submit(action)
            self._done = True

        # ── anything else ─────────────────────────────────────────────────
        else:
            # Report the problem instead of returning an empty result; the
            # step is still counted and earns no reward.
            result = {"status": "error", "message": f"Unknown action_type '{atype}'."}

        obs = self._build_obs(result)
        return obs, reward, self._done, info

    # ── Internal handlers ──────────────────────────────────────────────────────

    @staticmethod
    def _apply_log_filter(content: str, log_filter: str) -> str:
        """Return the lines of ``content`` selected by ``log_filter``.

        ``epoch:N`` / ``epoch:N-M`` keeps lines mentioning ``EPOCH NNN`` for the
        given range plus every ``[INFO ]`` / ``[ERROR`` line; any other filter
        is a case-insensitive substring match. A malformed epoch range returns
        the log unfiltered.
        """
        lines = content.split("\n")

        if log_filter.startswith("epoch:"):
            try:
                parts = log_filter.split(":")[1].split("-")
                start = int(parts[0])
                end = int(parts[1]) if len(parts) > 1 else start
            except ValueError:
                # Non-numeric bounds: fall back to the full log.
                return content
            markers = [f"EPOCH {ep:03d}" for ep in range(start, end + 1)]
            kept = [
                line for line in lines
                if any(m in line for m in markers) or "[INFO ]" in line or "[ERROR" in line
            ]
            return "\n".join(kept) if kept else "No log lines match this epoch range."

        needle = log_filter.lower()
        kept = [line for line in lines if needle in line.lower()]
        return "\n".join(kept) if kept else f"No log lines contain '{log_filter}'."

    def _handle_artifact_read(self, artifact: str, log_filter: Optional[str]) -> Tuple[float, Dict]:
        """Return ``(reward, result)`` for reading ``artifact``.

        Reward is +0.02 on the first read of an artifact, -0.02 when it is
        re-read with the same filter as last time, and 0.0 otherwise.
        ``log_filter`` is only applied to ``train.log``.
        """
        current_filter = log_filter or ""
        # Must be evaluated before the bookkeeping below updates the state.
        is_duplicate = (
            artifact in self._artifacts_read
            and self._last_read_filters.get(artifact, "") == current_filter
        )

        content = self._artifacts[artifact]
        if artifact == "train.log" and log_filter:
            content = self._apply_log_filter(content, log_filter)

        reward = 0.0
        if artifact not in self._artifacts_read:
            self._artifacts_read.append(artifact)
            reward = 0.02  # first read reward
        elif is_duplicate:
            self._duplicate_queries += 1
            reward = -0.02  # duplicate penalty
            self._messages.append(f"⚠️ Duplicate read of {artifact} with same filter. Try a different filter or a new artifact.")

        self._last_read_filters[artifact] = current_filter

        return reward, {
            "status": "ok",
            "artifact": artifact,
            "content": content,
            "note": "Use log_filter='keyword' or 'epoch:N-M' for targeted log queries.",
        }

    def _handle_submit(self, action: MLOpsAction) -> Tuple[float, Dict, Dict]:
        """Grade a submitted diagnosis against the planted bug.

        Returns:
            ``(score, info, result)`` where ``score`` is clamped to [0.0, 1.0]
            and rounded to 4 decimals, ``info`` carries the breakdown, ground
            truth and investigation summary, and ``result`` is the payload
            placed in the observation.
        """
        if len(self._artifacts_read) < 3:
            # Penalty for submitting without adequate investigation
            base_penalty = -0.05
            self._messages.append("⚠️ Submitted diagnosis after reading fewer than 3 artifacts.")
        else:
            base_penalty = 0.0

        score = base_penalty
        breakdown: Dict[str, Any] = {}

        # 1. failure_category (+0.15)
        if action.failure_category == self.bug.category:
            score += 0.15
            breakdown["failure_category"] = {"awarded": 0.15, "correct": True}
        else:
            breakdown["failure_category"] = {
                "awarded": 0.0, "correct": False,
                "expected": self.bug.category, "got": action.failure_category,
            }

        # 2. root_cause_file (+0.25), compared case-insensitively
        if action.root_cause_file and action.root_cause_file.lower() == self.bug.file.lower():
            score += 0.25
            breakdown["root_cause_file"] = {"awarded": 0.25, "correct": True}
        else:
            breakdown["root_cause_file"] = {
                "awarded": 0.0, "correct": False,
                "expected": self.bug.file, "got": action.root_cause_file,
            }

        # 3. root_cause_field (+0.30) — require majority of keywords to match
        field_keywords = [kw.lower() for kw in self.bug.field.replace(".", " ").split() if len(kw) > 1]
        submitted_field = (action.root_cause_field or "").lower()
        field_matches = sum(1 for kw in field_keywords if kw in submitted_field)
        field_threshold = max(1, len(field_keywords) // 2 + 1)  # majority
        field_correct = len(field_keywords) > 0 and field_matches >= field_threshold
        if field_correct:
            score += 0.30
            breakdown["root_cause_field"] = {"awarded": 0.30, "correct": True}
        else:
            breakdown["root_cause_field"] = {
                "awarded": 0.0, "correct": False,
                "expected": self.bug.field, "got": action.root_cause_field,
                "matched_keywords": field_matches, "required": field_threshold,
            }

        # 4. proposed_fix (up to +0.30) — keyword match against gold fix
        import re as _re

        _stop = {"to", "the", "a", "an", "of", "in", "on", "from", "use", "with", "and", "or", "for", "is", "at", "by"}
        # Strip punctuation from keywords so "(fitted" becomes "fitted"
        fix_keywords = {
            _re.sub(r'[^a-z0-9_.]', '', w)
            for w in self.bug.gold_fix.lower().split()
        } - _stop
        fix_keywords.discard("")  # remove empty strings

        submitted_fix = (action.proposed_fix or "").lower()
        fix_overlap = sum(1 for kw in fix_keywords if kw in submitted_fix)
        fix_score = min(0.30, 0.30 * (fix_overlap / max(1, len(fix_keywords))))
        score += fix_score
        breakdown["proposed_fix"] = {
            "awarded": round(fix_score, 4),
            "correct": fix_score >= 0.20,
            "keyword_overlap": fix_overlap,
            "total_keywords": len(fix_keywords),
        }

        # Hard task penalty multiplier — silent bugs are more costly
        if self.task_id == "hard" and score < 0.70:
            missed = 0.70 - min(score, 0.70)
            score -= missed * 0.5  # 1.5× penalty on missed components
            breakdown["hard_task_penalty_applied"] = True

        score = round(max(0.0, min(1.0, score)), 4)
        self._current_score = score

        info = {
            "score": score,
            "breakdown": breakdown,
            "ground_truth": {
                "bug_type": self.bug_type,
                "category": self.bug.category,
                "file": self.bug.file,
                "field": self.bug.field,
                "gold_fix": self.bug.gold_fix,
            },
            "investigation": {
                # Copies, so callers cannot mutate the environment's own lists.
                "artifacts_read": list(self._artifacts_read),
                "sanity_checks_run": list(self._sanity_checks_run),
                "duplicate_queries": self._duplicate_queries,
                "steps_taken": self._step_count,
            },
        }

        result = {
            "status": "submitted",
            "score": score,
            "breakdown": breakdown,
            "message": f"Diagnosis submitted. Score: {score:.4f}/{1.0:.4f}",
        }
        return score, info, result

    def _resolve_field(self, artifact: str, field_path: str) -> Any:
        """Look up a dot-notation ``field_path`` inside ``artifact``.

        * ``.json``: walks nested objects; returns the value, or a message
          string when a key is missing, a non-object is reached, or the
          document does not parse.
        * ``.yaml``: returns the first stripped line for the last path
          segment, preferring a line whose key is exactly that segment.
        * anything else: returns up to five stripped lines containing the
          last path segment, or a not-found message.
        """
        import json as _json

        content = self._artifacts[artifact]

        if artifact.endswith(".json"):
            try:
                data = _json.loads(content)
            except ValueError as e:  # json.JSONDecodeError subclasses ValueError
                return f"Parse error: {e}"
            val = data
            for p in field_path.split("."):
                if not isinstance(val, dict):
                    return f"Cannot traverse into non-dict at '{p}'"
                if p not in val:
                    # Stop at the segment that is actually missing rather than
                    # failing one level deeper with a misleading message.
                    return f"Field '{p}' not found"
                val = val[p]
            return val

        if artifact.endswith(".yaml"):
            # Simple key search for YAML
            key_prefix = f"{field_path.split('.')[-1]}:"
            stripped_lines = [line.strip() for line in content.split("\n")]
            # Exact key first, so "lr" is not answered with a "min_lr:" line.
            for line in stripped_lines:
                if line.startswith(key_prefix):
                    return line
            # Fall back to a substring match (e.g. "- key: value" list items).
            for line in stripped_lines:
                if key_prefix in line:
                    return line
            return f"Field '{field_path}' not found in config"

        # For other text artifacts (e.g. .py), return lines containing the field name
        target = field_path.split(".")[-1]
        matches = [line.strip() for line in content.split("\n") if target in line]
        return matches[:5] if matches else f"'{target}' not found in {artifact}"

    def _build_obs(self, last_result: Dict[str, Any]) -> MLOpsObservation:
        """Assemble the observation for the current state around ``last_result``."""
        return MLOpsObservation(
            task_id=self.task_id,
            task_description=TASK_DESCRIPTIONS[self.task_id],
            run_id=self._run_id,
            run_summary={
                "model": self._model_cfg["name"],
                "dataset": self._model_cfg["dataset"],
                "task": self._model_cfg["type"],
                "status": "FAILED" if self.task_id == "easy" else "COMPLETED_WITH_ANOMALIES",
                "note": "Investigate artifacts to determine root cause.",
            },
            available_artifacts=list(self._artifact_meta),
            artifacts_read=list(self._artifacts_read),
            last_action_result=last_result,
            step_count=self._step_count,
            max_steps=self._max_steps,
            done=self._done,
            messages=list(self._messages),
        )

    @property
    def state(self) -> MLOpsState:
        """Full internal state of the episode, including the planted bug's ground truth."""
        return MLOpsState(
            task_id=self.task_id,
            seed=self._seed,
            step_count=self._step_count,
            max_steps=self._max_steps,
            episode_done=self._done,
            bug_type=self.bug_type,
            bug_category=self.bug.category,
            bug_file=self.bug.file,
            bug_field=self.bug.field,
            gold_fix=self.bug.gold_fix,
            artifacts=self._artifacts,
            artifacts_read=list(self._artifacts_read),
            sanity_checks_run=list(self._sanity_checks_run),
            duplicate_queries=self._duplicate_queries,
            current_score=self._current_score,
        )
# ─── Standalone grader ────────────────────────────────────────────────────────
def grade_task(task_id: str, seed: int, diagnosis: Dict[str, Any]) -> float:
    """Deterministic grader callable by OpenEnv validation framework.

    Builds the episode for ``(task_id, seed)``, submits ``diagnosis`` as a
    single ``submit_diagnosis`` action and returns the resulting score.

    Bypasses the artifact-read penalty since the grader only evaluates
    diagnosis quality, not investigation thoroughness.

    Args:
        task_id: task whose bug pool the episode is drawn from.
        seed: episode seed passed to ``reset``.
        diagnosis: keyword fields for ``MLOpsAction``. An ``action_type``
            entry, if present, is ignored: the grader always submits.

    Returns:
        The ``score`` reported by the environment, or 0.0 if absent.
    """
    env = MLOpsEnvironment(task_id=task_id)
    env.reset(seed=seed)

    # Pre-populate artifact reads to avoid the < 3 artifacts penalty
    env._artifacts_read = list(env._artifacts)

    # A full action dict would otherwise pass ``action_type`` twice and make
    # the constructor call below raise TypeError.
    fields = {key: value for key, value in diagnosis.items() if key != "action_type"}
    action = MLOpsAction(action_type="submit_diagnosis", **fields)

    _, _, _, info = env.step(action)
    return info.get("score", 0.0)