File size: 7,657 Bytes
ffd85e1
 
72ddcb6
 
 
ffd85e1
72ddcb6
 
ffd85e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72ddcb6
ffd85e1
 
72ddcb6
 
 
ffd85e1
72ddcb6
 
 
 
ffd85e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72ddcb6
ffd85e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72ddcb6
 
 
ffd85e1
 
 
 
72ddcb6
ffd85e1
 
 
 
a8f49ae
ffd85e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72ddcb6
ffd85e1
72ddcb6
ffd85e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
from typing import Literal, Optional

from openenv.core.env_server.types import Action, Observation
from pydantic import Field

# ── Action Types ───────────────────────────────────────────────────────────

class MLDebugAction(Action):
    """
    Single action model covering every tool the agent can call.

    `action_type` picks the tool. It defaults to `submit_fix` so callers that
    only send the legacy submit_fix fields keep working unchanged.

    Fields used by each tool (per the field descriptions):

    - submit_fix: fixed_code, explanation, attempt_number
    - execute_snippet: code
    - run_training_probe: code, steps
    - inspect_tensor: setup_code, target_expression
    - get_variable_state: setup_code, expressions
    - inspect_diff: proposed_code

    `episode_id` is shared by every action type.
    """

    action_type: Literal[
        "submit_fix",
        "execute_snippet",
        "inspect_tensor",
        "run_training_probe",
        "get_variable_state",
        "inspect_diff",
    ] = Field(
        "submit_fix",
        description="Which tool to invoke. Default: submit_fix",
    )

    # ── Shared by every action type ──
    episode_id: str = Field(
        "",
        description="Episode ID returned from reset(). Required for session tracking.",
    )

    # ── submit_fix (the original, pre-tool interface) ──
    fixed_code: str = Field(
        "",
        description="For submit_fix: The complete fixed Python script.",
    )
    explanation: str = Field(
        "",
        description="For submit_fix: Optional explanation of bugs found.",
    )
    # Bounded to 1..10 by the ge/le constraints.
    attempt_number: int = Field(
        1,
        ge=1,
        le=10,
        description="For submit_fix: Which attempt this is.",
    )

    # ── execute_snippet and run_training_probe: the code to run ──
    code: str = Field(
        "",
        description="For execute_snippet/run_training_probe: Python code to run.",
    )

    # ── inspect_tensor and get_variable_state: shared setup code ──
    setup_code: str = Field(
        "",
        description="For inspect_tensor/get_variable_state: Setup code (imports, definitions).",
    )

    # ── inspect_tensor only ──
    target_expression: str = Field(
        "",
        description="For inspect_tensor: The expression to inspect (e.g., 'model.weight.grad').",
    )

    # ── run_training_probe only; bounded to 1..10 by ge/le ──
    steps: int = Field(
        5,
        ge=1,
        le=10,
        description="For run_training_probe: Number of training steps (max 10).",
    )

    # ── get_variable_state only ──
    # NOTE(review): the description promises at most 10 expressions, but no
    # max_length is declared here (unlike `steps`, which enforces le=10);
    # confirm the limit is enforced by the tool implementation.
    expressions: list[str] = Field(
        default_factory=list,
        description="For get_variable_state: List of expressions to evaluate (max 10).",
    )

    # ── inspect_diff only ──
    proposed_code: str = Field(
        "",
        description="For inspect_diff: The proposed fix to diff against original.",
    )


# ── Observation Types ──────────────────────────────────────────────────────

class VariableResult(Observation):
    """Record describing the evaluation of a single expression.

    Carries the value's repr and type name, plus `value` when the result is a
    scalar, `shape` when it is a tensor/array, and `error` when evaluation failed.
    """

    # NOTE(review): subclasses Observation, so it also inherits whatever fields
    # Observation declares; confirm that is intended for a per-expression record.
    repr_str: str = Field("", description="String representation of the value")
    type_str: str = Field("", description="Type name of the value")
    value: float | int | bool | str | None = Field(None, description="Scalar value if applicable")
    shape: list[int] | None = Field(None, description="Shape for tensors/arrays")
    error: str | None = Field(None, description="Error message if evaluation failed")


class MLDebugObservation(Observation):
    """
    Observation returned by step(), shared by every action type.

    Which fields are populated depends on the action that produced the
    observation (recorded in `action_type`). Fields an action does not
    populate keep the defaults declared here. The fields are grouped by
    the action whose output they carry.
    """

    # ── Shared by every action type ──
    action_type: str = Field("submit_fix", description="Which action produced this observation")
    turn: int = Field(0, description="Current turn number in this episode")
    episode_done: bool = Field(False, description="Whether the episode is complete")
    # NOTE(review): `reward` (and `done` in the legacy group) may also be
    # declared on the Observation base; the declarations in this class take
    # precedence, so confirm the types stay compatible with the base.
    reward: Optional[float] = Field(
        None,
        description="Reward (None for tools, grader score for submit_fix)",
    )
    error: Optional[str] = Field(None, description="Error message if action failed")
    episode_id: str = Field("", description="Episode ID for session tracking")

    # ── Task context (present from reset() onward) ──
    task_id: str = Field("", description="task1 | task2 | ... | task6")
    task_description: str = Field("", description="Plain English task instructions")
    buggy_code: str = Field("", description="The broken training script")

    # ── Code execution output (execute_snippet, run_training_probe) ──
    stdout: str = Field("", description="Captured stdout from code execution")
    stderr: str = Field("", description="Captured stderr from code execution")
    exit_code: int = Field(0, description="Exit code from code execution")
    timed_out: bool = Field(False, description="Whether execution timed out")

    # ── inspect_tensor output (Optional, so they stay None for other actions) ──
    shape: Optional[list[int]] = Field(None, description="Tensor shape")
    dtype: Optional[str] = Field(None, description="Tensor dtype")
    requires_grad: Optional[bool] = Field(None, description="Whether tensor requires grad")
    grad_is_none: Optional[bool] = Field(None, description="Whether .grad is None")
    min_val: Optional[float] = Field(None, description="Tensor min value")
    max_val: Optional[float] = Field(None, description="Tensor max value")
    mean_val: Optional[float] = Field(None, description="Tensor mean value")
    is_nan: Optional[bool] = Field(None, description="Whether tensor contains NaN")
    is_inf: Optional[bool] = Field(None, description="Whether tensor contains Inf")

    # ── run_training_probe output ──
    losses: list[float] = Field(default_factory=list, description="Loss values per training step")
    grad_norms: dict[str, float] = Field(
        default_factory=dict,
        description="Layer name -> gradient norm",
    )
    optimizer_param_count: Optional[int] = Field(
        None,
        description="Number of optimizer parameters",
    )
    final_loss: Optional[float] = Field(None, description="Final loss value")
    loss_is_nan: bool = Field(False, description="Whether final loss is NaN")
    loss_is_inf: bool = Field(False, description="Whether final loss is Inf")

    # ── get_variable_state output ──
    # Keyed by the evaluated expression. Per the field description, each
    # value is a VariableResult in plain-dict form.
    results: dict[str, dict] = Field(
        default_factory=dict,
        description="Expression -> VariableResult dict for get_variable_state",
    )

    # ── inspect_diff output ──
    diff: str = Field("", description="Unified diff between original and proposed code")
    lines_changed: int = Field(0, description="Number of lines changed")
    additions: int = Field(0, description="Number of lines added")
    deletions: int = Field(0, description="Number of lines deleted")

    # ── submit_fix output ──
    success: bool = Field(False, description="Whether reward >= 0.7")
    grader_details: dict = Field(default_factory=dict, description="Detailed grader breakdown")
    turns_used: int = Field(0, description="How many tool calls before this submission")

    # ── Legacy mirrors of newer fields, kept for backward compatibility ──
    error_log: str = Field("", description="Legacy: stdout+stderr combined")
    last_reward: float = Field(0.0, description="Legacy: same as reward")
    metrics: dict = Field(default_factory=dict, description="Legacy: structured metrics")
    done: bool = Field(False, description="Legacy: same as episode_done")