# whipstudio / models.py
# Amogh-kal1's picture
# Upload folder using huggingface_hub
# a8f49ae verified
from typing import Literal, Optional
from openenv.core.env_server.types import Action, Observation
from pydantic import Field
# ── Action Types ───────────────────────────────────────────────────────────
# Closed set of tool names accepted by `MLDebugAction.action_type`.
_ActionType = Literal[
    "submit_fix",
    "execute_snippet",
    "inspect_tensor",
    "run_training_probe",
    "get_variable_state",
    "inspect_diff",
]


class MLDebugAction(Action):
    """
    Single action model shared by every tool.

    `action_type` names the tool to invoke and falls back to `submit_fix`
    when omitted, for backward compatibility. Every other field has a
    default, so a caller only needs to set the fields its chosen tool reads
    (each field's description names the tool it belongs to).
    """

    action_type: _ActionType = Field(
        "submit_fix",
        description="Which tool to invoke. Default: submit_fix",
    )

    # Session tracking.
    episode_id: str = Field(
        "",
        description="Episode ID returned from reset(). Required for session tracking.",
    )

    # submit_fix inputs (legacy fields, kept for backward compatibility).
    fixed_code: str = Field(
        "", description="For submit_fix: The complete fixed Python script."
    )
    explanation: str = Field(
        "", description="For submit_fix: Optional explanation of bugs found."
    )
    # Validation bounds the attempt counter to the inclusive range 1..10.
    attempt_number: int = Field(
        1, ge=1, le=10, description="For submit_fix: Which attempt this is."
    )

    # execute_snippet inputs (the description also lists run_training_probe).
    code: str = Field(
        "",
        description="For execute_snippet/run_training_probe: Python code to run.",
    )

    # inspect_tensor inputs (setup_code is also listed for get_variable_state).
    setup_code: str = Field(
        "",
        description="For inspect_tensor/get_variable_state: Setup code (imports, definitions).",
    )
    target_expression: str = Field(
        "",
        description="For inspect_tensor: The expression to inspect (e.g., 'model.weight.grad').",
    )

    # run_training_probe inputs; validation bounds steps to 1..10 inclusive.
    steps: int = Field(
        5,
        ge=1,
        le=10,
        description="For run_training_probe: Number of training steps (max 10).",
    )

    # get_variable_state inputs. The "max 10" in the description is not
    # enforced by this model.
    expressions: list[str] = Field(
        default_factory=list,
        description="For get_variable_state: List of expressions to evaluate (max 10).",
    )

    # inspect_diff inputs.
    proposed_code: str = Field(
        "",
        description="For inspect_diff: The proposed fix to diff against original.",
    )
# ── Observation Types ──────────────────────────────────────────────────────
class VariableResult(Observation):
    """Outcome of evaluating one variable expression.

    All fields are optional or defaulted: the repr and type name default to
    empty strings, while the scalar value, shape, and error default to None.
    """

    repr_str: str = Field("", description="String representation of the value")
    type_str: str = Field("", description="Type name of the value")
    value: Optional[float | int | bool | str] = Field(
        None, description="Scalar value if applicable"
    )
    shape: Optional[list[int]] = Field(None, description="Shape for tensors/arrays")
    error: Optional[str] = Field(None, description="Error message if evaluation failed")
class MLDebugObservation(Observation):
    """
    Observation returned by step(), shared across all action types.

    Which fields carry meaningful values depends on the `action_type` that
    was invoked; every field has a default, so unrelated fields simply keep
    theirs.
    """

    # Fields common to every action.
    action_type: str = Field(
        "submit_fix", description="Which action produced this observation"
    )
    turn: int = Field(0, description="Current turn number in this episode")
    episode_done: bool = Field(False, description="Whether the episode is complete")
    reward: Optional[float] = Field(
        None, description="Reward (None for tools, grader score for submit_fix)"
    )
    error: Optional[str] = Field(None, description="Error message if action failed")
    episode_id: str = Field("", description="Episode ID for session tracking")

    # Task context (always present after reset).
    task_id: str = Field("", description="task1 | task2 | ... | task6")
    task_description: str = Field("", description="Plain English task instructions")
    buggy_code: str = Field("", description="The broken training script")

    # Output of execute_snippet / run_training_probe.
    stdout: str = Field("", description="Captured stdout from code execution")
    stderr: str = Field("", description="Captured stderr from code execution")
    exit_code: int = Field(0, description="Exit code from code execution")
    timed_out: bool = Field(False, description="Whether execution timed out")

    # Output of inspect_tensor; each stays None unless populated.
    shape: Optional[list[int]] = Field(None, description="Tensor shape")
    dtype: Optional[str] = Field(None, description="Tensor dtype")
    requires_grad: Optional[bool] = Field(
        None, description="Whether tensor requires grad"
    )
    grad_is_none: Optional[bool] = Field(None, description="Whether .grad is None")
    min_val: Optional[float] = Field(None, description="Tensor min value")
    max_val: Optional[float] = Field(None, description="Tensor max value")
    mean_val: Optional[float] = Field(None, description="Tensor mean value")
    is_nan: Optional[bool] = Field(None, description="Whether tensor contains NaN")
    is_inf: Optional[bool] = Field(None, description="Whether tensor contains Inf")

    # Output of run_training_probe.
    losses: list[float] = Field(
        default_factory=list, description="Loss values per training step"
    )
    grad_norms: dict[str, float] = Field(
        default_factory=dict, description="Layer name -> gradient norm"
    )
    optimizer_param_count: Optional[int] = Field(
        None, description="Number of optimizer parameters"
    )
    final_loss: Optional[float] = Field(None, description="Final loss value")
    loss_is_nan: bool = Field(False, description="Whether final loss is NaN")
    loss_is_inf: bool = Field(False, description="Whether final loss is Inf")

    # Output of get_variable_state; values are plain dicts rather than
    # VariableResult instances.
    results: dict[str, dict] = Field(
        default_factory=dict,
        description="Expression -> VariableResult dict for get_variable_state",
    )

    # Output of inspect_diff.
    diff: str = Field(
        "", description="Unified diff between original and proposed code"
    )
    lines_changed: int = Field(0, description="Number of lines changed")
    additions: int = Field(0, description="Number of lines added")
    deletions: int = Field(0, description="Number of lines deleted")

    # Output of submit_fix.
    success: bool = Field(False, description="Whether reward >= 0.7")
    grader_details: dict = Field(
        default_factory=dict, description="Detailed grader breakdown"
    )
    turns_used: int = Field(
        0, description="How many tool calls before this submission"
    )

    # Legacy fields retained for backward compatibility.
    error_log: str = Field("", description="Legacy: stdout+stderr combined")
    last_reward: float = Field(0.0, description="Legacy: same as reward")
    metrics: dict = Field(default_factory=dict, description="Legacy: structured metrics")
    done: bool = Field(False, description="Legacy: same as episode_done")