# mlops-openenv/server/models.py
# Author: trretretret — commit 6a71058: "Add uv.lock, pyproject.toml, server directory"
"""
MLOps Pipeline Debugger β€” Pydantic Models
The agent acts as an ML engineer investigating a broken training run.
It has access to training artifacts (logs, configs, dataset stats, preprocessing code)
and must diagnose the root cause through systematic investigation.
Action Space:
read_config β†’ Get training configuration (hyperparams, model arch, optimizer)
read_logs β†’ Get training logs (filterable by keyword/epoch range)
check_dataset_stats β†’ Get dataset split sizes, class distribution, feature statistics
inspect_preprocessing β†’ Read preprocessing pipeline code
read_eval_results β†’ Get validation and test set evaluation metrics
run_sanity_check β†’ Compute a specific diagnostic check (label overlap, class balance, etc.)
query_artifact β†’ Fetch a specific field from any artifact
submit_diagnosis β†’ Final answer β€” triggers grading
Observation Space:
task_id, task_description
available_artifacts β€” list of artifacts the agent can inspect
last_action_result β€” result of the most recent action
artifacts_read β€” which artifacts have been read so far (exploration tracking)
step_count, max_steps
done
"""
from __future__ import annotations
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field
# ─── Action ──────────────────────────────────────────────────────────────────
class MLOpsAction(BaseModel):
    """
    One action the agent can take per step.

    action_type determines which of the other fields are used:

        read_config            → (no extra fields)
        read_logs              → log_filter (optional keyword or "epoch:N-M")
        check_dataset_stats    → (no extra fields)
        inspect_preprocessing  → (no extra fields)
        read_eval_results      → (no extra fields)
        run_sanity_check       → sanity_check_type (required)
        query_artifact         → artifact_name + field_path (required)
        submit_diagnosis       → all diagnosis fields (required)

    NOTE: "required" above describes the per-action contract. This model does
    not enforce it. Only action_type is mandatory when the model is
    constructed. Every other field is Optional and defaults to None, so a
    missing per-action field has to be detected by the code that handles the
    action.
    """

    # Mandatory discriminator; the `...` default makes this field required.
    action_type: Literal[
        "read_config",
        "read_logs",
        "check_dataset_stats",
        "inspect_preprocessing",
        "read_eval_results",
        "run_sanity_check",
        "query_artifact",
        "submit_diagnosis",
    ] = Field(..., description="Which action to perform")

    # ── read_logs ──
    # Free-form string. The keyword / "epoch:N-M" format is not parsed or
    # validated by this model.
    log_filter: Optional[str] = Field(
        None,
        description="Filter logs by keyword (e.g. 'nan', 'warning', 'error') or epoch range 'epoch:1-5'"
    )

    # ── run_sanity_check ──
    sanity_check_type: Optional[Literal[
        "label_consistency",      # Are train/eval label mappings identical?
        "data_leakage",           # Is there train/val sample overlap?
        "gradient_norms",         # Are gradient norms within normal range?
        "class_balance",          # Are classes balanced across splits?
        "feature_statistics",     # Do train/val feature distributions match?
        "encoder_version_match",  # Do all pipeline stages use the same encoder version?
        "loss_trajectory",        # Is the loss curve shape anomalous?
        "metric_gap_analysis",    # Is val vs test metric gap suspiciously large?
    ]] = Field(None, description="Type of sanity check to run")

    # ── query_artifact ──
    # Closed set of artifact file names that can be queried.
    artifact_name: Optional[Literal[
        "config.yaml",
        "train.log",
        "dataset_stats.json",
        "preprocessing.py",
        "eval_results.json",
        "model_card.json",
    ]] = Field(None, description="Artifact to query a specific field from")
    # Dot-notation path into the chosen artifact. The path is not checked
    # against the artifact's contents here.
    field_path: Optional[str] = Field(
        None,
        description="Dot-notation field path, e.g. 'optimizer.learning_rate' or 'metrics.val_accuracy'"
    )

    # ── submit_diagnosis ──
    failure_category: Optional[Literal[
        "config_error",       # Wrong hyperparameter value
        "data_leakage",       # Train/val contamination
        "evaluation_bug",     # Eval pipeline uses wrong artifacts
        "preprocessing_bug",  # Data transformation applied incorrectly
        "label_mismatch",     # Label encoding inconsistency
        "architecture_bug",   # Model architecture misconfiguration
    ]] = Field(None, description="Category of the failure")
    # NOTE(review): this is a free-form str, unlike the closed artifact_name
    # Literal above. It is presumably meant to name one of those artifact
    # files; confirm against the grader.
    root_cause_file: Optional[str] = Field(
        None, description="Which artifact file contains the root cause"
    )
    root_cause_field: Optional[str] = Field(
        None, description="Specific parameter, function, or variable that is wrong"
    )
    diagnosis: Optional[str] = Field(
        None, description="Natural language explanation of what went wrong and why"
    )
    proposed_fix: Optional[str] = Field(
        None, description="Concrete change that would fix the issue"
    )
# ─── Observation ─────────────────────────────────────────────────────────────
class ArtifactMeta(BaseModel):
    """Descriptive metadata for one artifact listed in an observation's available_artifacts."""

    name: str  # artifact name
    description: str  # human-readable summary of the artifact
    size_hint: str  # e.g. "47 lines", "12 fields"
    last_modified: str  # free-form string; format is not fixed by this model (TODO confirm expected format)
class MLOpsObservation(BaseModel):
    """
    Everything the agent sees after each step / reset.

    task_id, task_description, run_id, run_summary and available_artifacts
    are required. The remaining fields have defaults.
    """

    task_id: str
    task_description: str

    # Run summary — always visible
    run_id: str
    # Field() without a default, so this field is required.
    run_summary: Dict[str, Any] = Field(
        description="High-level run info: model, dataset, final loss, training status"
    )
    available_artifacts: List[ArtifactMeta]
    # default_factory gives each instance its own list, not a shared mutable default.
    artifacts_read: List[str] = Field(
        default_factory=list,
        description="Names of artifacts the agent has already read"
    )
    # Result payload of the most recent action; its keys are not fixed by this model.
    last_action_result: Dict[str, Any] = Field(default_factory=dict)
    step_count: int = 0
    max_steps: int = 30  # default step budget per episode
    done: bool = False  # episode-termination flag (presumably set after submit_diagnosis or max_steps; confirm)
    messages: List[str] = Field(default_factory=list)  # free-form notes for the agent; semantics not defined here
# ─── State ───────────────────────────────────────────────────────────────────
class MLOpsState(BaseModel):
    """
    Full internal state — for RL harness and debugging.

    Unlike MLOpsObservation, this carries the planted-bug ground truth and the
    full text of every generated artifact, so it is presumably not meant to be
    shown to the agent. Every field is required (none has a default).
    """

    task_id: str
    seed: int  # presumably the RNG seed used to generate this episode; confirm
    step_count: int
    max_steps: int
    episode_done: bool

    # Planted bug ground truth
    bug_type: str
    bug_category: str  # NOTE(review): plain str; presumably one of MLOpsAction.failure_category's values; confirm
    bug_file: str
    bug_field: str
    gold_fix: str

    # All generated artifacts (full text)
    artifacts: Dict[str, str]  # presumably keyed by artifact name; confirm

    # Agent's investigation history
    artifacts_read: List[str]
    sanity_checks_run: List[str]
    duplicate_queries: int  # count of repeated queries (exact definition lives in the environment; confirm)
    current_score: float