Spaces:
Sleeping
Sleeping
| """ | |
| MLOps Pipeline Debugger β Pydantic Models | |
| The agent acts as an ML engineer investigating a broken training run. | |
| It has access to training artifacts (logs, configs, dataset stats, preprocessing code) | |
| and must diagnose the root cause through systematic investigation. | |
| Action Space: | |
| read_config β Get training configuration (hyperparams, model arch, optimizer) | |
| read_logs β Get training logs (filterable by keyword/epoch range) | |
| check_dataset_stats β Get dataset split sizes, class distribution, feature statistics | |
| inspect_preprocessing β Read preprocessing pipeline code | |
| read_eval_results β Get validation and test set evaluation metrics | |
| run_sanity_check β Compute a specific diagnostic check (label overlap, class balance, etc.) | |
| query_artifact β Fetch a specific field from any artifact | |
| submit_diagnosis β Final answer β triggers grading | |
| Observation Space: | |
| task_id, task_description | |
| available_artifacts β list of artifacts the agent can inspect | |
| last_action_result β result of the most recent action | |
| artifacts_read β which artifacts have been read so far (exploration tracking) | |
| step_count, max_steps | |
| done | |
| """ | |
| from __future__ import annotations | |
| from typing import Any, Dict, List, Literal, Optional | |
| from pydantic import BaseModel, Field | |
| # βββ Action ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| class MLOpsAction(BaseModel): | |
| """ | |
| One action the agent can take per step. | |
| action_type determines which fields are used: | |
| read_config β (no extra fields) | |
| read_logs β log_filter (optional keyword or "epoch:N-M") | |
| check_dataset_stats β (no extra fields) | |
| inspect_preprocessing β (no extra fields) | |
| read_eval_results β (no extra fields) | |
| run_sanity_check β sanity_check_type (required) | |
| query_artifact β artifact_name + field_path (required) | |
| submit_diagnosis β all diagnosis fields (required) | |
| """ | |
| action_type: Literal[ | |
| "read_config", | |
| "read_logs", | |
| "check_dataset_stats", | |
| "inspect_preprocessing", | |
| "read_eval_results", | |
| "run_sanity_check", | |
| "query_artifact", | |
| "submit_diagnosis", | |
| ] = Field(..., description="Which action to perform") | |
| # read_logs | |
| log_filter: Optional[str] = Field( | |
| None, | |
| description="Filter logs by keyword (e.g. 'nan', 'warning', 'error') or epoch range 'epoch:1-5'" | |
| ) | |
| # run_sanity_check | |
| sanity_check_type: Optional[Literal[ | |
| "label_consistency", # Are train/eval label mappings identical? | |
| "data_leakage", # Is there train/val sample overlap? | |
| "gradient_norms", # Are gradient norms within normal range? | |
| "class_balance", # Are classes balanced across splits? | |
| "feature_statistics", # Do train/val feature distributions match? | |
| "encoder_version_match", # Do all pipeline stages use the same encoder version? | |
| "loss_trajectory", # Is the loss curve shape anomalous? | |
| "metric_gap_analysis", # Is val vs test metric gap suspiciously large? | |
| ]] = Field(None, description="Type of sanity check to run") | |
| # query_artifact | |
| artifact_name: Optional[Literal[ | |
| "config.yaml", | |
| "train.log", | |
| "dataset_stats.json", | |
| "preprocessing.py", | |
| "eval_results.json", | |
| "model_card.json", | |
| ]] = Field(None, description="Artifact to query a specific field from") | |
| field_path: Optional[str] = Field( | |
| None, | |
| description="Dot-notation field path, e.g. 'optimizer.learning_rate' or 'metrics.val_accuracy'" | |
| ) | |
| # submit_diagnosis | |
| failure_category: Optional[Literal[ | |
| "config_error", # Wrong hyperparameter value | |
| "data_leakage", # Train/val contamination | |
| "evaluation_bug", # Eval pipeline uses wrong artifacts | |
| "preprocessing_bug", # Data transformation applied incorrectly | |
| "label_mismatch", # Label encoding inconsistency | |
| "architecture_bug", # Model architecture misconfiguration | |
| ]] = Field(None, description="Category of the failure") | |
| root_cause_file: Optional[str] = Field( | |
| None, description="Which artifact file contains the root cause" | |
| ) | |
| root_cause_field: Optional[str] = Field( | |
| None, description="Specific parameter, function, or variable that is wrong" | |
| ) | |
| diagnosis: Optional[str] = Field( | |
| None, description="Natural language explanation of what went wrong and why" | |
| ) | |
| proposed_fix: Optional[str] = Field( | |
| None, description="Concrete change that would fix the issue" | |
| ) | |
| # βββ Observation βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| class ArtifactMeta(BaseModel): | |
| name: str | |
| description: str | |
| size_hint: str # e.g. "47 lines", "12 fields" | |
| last_modified: str | |
| class MLOpsObservation(BaseModel): | |
| """Everything the agent sees after each step / reset.""" | |
| task_id: str | |
| task_description: str | |
| # Run summary β always visible | |
| run_id: str | |
| run_summary: Dict[str, Any] = Field( | |
| description="High-level run info: model, dataset, final loss, training status" | |
| ) | |
| available_artifacts: List[ArtifactMeta] | |
| artifacts_read: List[str] = Field( | |
| default_factory=list, | |
| description="Names of artifacts the agent has already read" | |
| ) | |
| last_action_result: Dict[str, Any] = Field(default_factory=dict) | |
| step_count: int = 0 | |
| max_steps: int = 30 | |
| done: bool = False | |
| messages: List[str] = Field(default_factory=list) | |
| # βββ State βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| class MLOpsState(BaseModel): | |
| """Full internal state β for RL harness and debugging.""" | |
| task_id: str | |
| seed: int | |
| step_count: int | |
| max_steps: int | |
| episode_done: bool | |
| # Planted bug ground truth | |
| bug_type: str | |
| bug_category: str | |
| bug_file: str | |
| bug_field: str | |
| gold_fix: str | |
| # All generated artifacts (full text) | |
| artifacts: Dict[str, str] | |
| # Agent's investigation history | |
| artifacts_read: List[str] | |
| sanity_checks_run: List[str] | |
| duplicate_queries: int | |
| current_score: float | |