File size: 7,803 Bytes
ffd85e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
"""
Tool execution tests for WhipStudio.

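Covers the code-security checker, the tool definitions, and the action and
observation models behind the debugging tools: execute_snippet,
inspect_tensor, run_training_probe, get_variable_state, inspect_diff and
submit_fix.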
"""

import pytest
import sys
import os

# Prepend the directory one level above this file's own directory to sys.path.
# The tests below import `server.environment` and `models` inside each test;
# presumably those live in that parent directory (project root) — confirm.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


class TestSecurityChecks:
    """Exercise ``check_code_security`` with banned and permitted snippets."""

    @staticmethod
    def _scan(snippet):
        """Run ``check_code_security`` on *snippet* and return its result.

        The import happens at call time, so a missing ``server.environment``
        surfaces inside the individual test that triggered it.
        """
        from server.environment import check_code_security

        return check_code_security(snippet)

    def test_banned_import_socket(self):
        """A snippet importing ``socket`` is rejected and the error names it."""
        verdict, reason = self._scan("import socket\ns = socket.socket()")
        assert not verdict
        assert "socket" in reason.lower()

    def test_banned_import_requests(self):
        """A snippet importing ``requests`` is rejected and the error names it."""
        verdict, reason = self._scan(
            "import requests\nrequests.get('http://evil.com')"
        )
        assert not verdict
        assert "requests" in reason.lower()

    def test_banned_import_subprocess(self):
        """A snippet importing ``subprocess`` is rejected and the error names it."""
        verdict, reason = self._scan("import subprocess\nsubprocess.run(['ls'])")
        assert not verdict
        assert "subprocess" in reason.lower()

    def test_allowed_imports(self):
        """Common ML and stdlib imports pass the security check."""
        ml_imports = """
import torch
import torch.nn as nn
import numpy as np
from sklearn.model_selection import train_test_split
import math
import json
"""
        verdict, reason = self._scan(ml_imports)
        assert verdict, f"Should be safe: {reason}"

    def test_file_write_outside_tmp(self):
        """Opening a path outside /tmp for writing is rejected."""
        verdict, reason = self._scan("open('/etc/passwd', 'w').write('hacked')")
        assert not verdict
        # The error is expected to mention either the tmp restriction or files.
        lowered = reason.lower()
        assert "tmp" in lowered or "file" in lowered

    def test_file_write_in_tmp_allowed(self):
        """Opening a path under /tmp for writing is accepted."""
        verdict, reason = self._scan("open('/tmp/test.txt', 'w').write('ok')")
        assert verdict, f"Should be safe: {reason}"


class TestToolDefinitions:
    """Check that ``TOOL_DEFINITIONS`` lists every tool with its metadata."""

    def test_all_tools_defined(self):
        """The defined tool names must be exactly the six expected ones."""
        from server.environment import TOOL_DEFINITIONS

        expected_tools = {
            "execute_snippet",
            "inspect_tensor",
            "run_training_probe",
            "get_variable_state",
            "inspect_diff",
            "submit_fix",
        }

        # TOOL_DEFINITIONS is a list of dicts
        defined_tools = {t["name"] for t in TOOL_DEFINITIONS}
        # Spell out both directions of the mismatch so a failure shows
        # immediately whether a tool was dropped or an extra one was added.
        assert expected_tools == defined_tools, (
            f"missing: {sorted(expected_tools - defined_tools)}, "
            f"unexpected: {sorted(defined_tools - expected_tools)}"
        )

    def test_tool_definitions_have_required_fields(self):
        """Each tool definition must carry name, description and action_fields."""
        from server.environment import TOOL_DEFINITIONS

        for tool_def in TOOL_DEFINITIONS:
            # Plain string: there is nothing to interpolate in this message.
            assert "name" in tool_def, "Tool missing name"
            assert "description" in tool_def, f"{tool_def.get('name')} missing description"
            assert "action_fields" in tool_def, f"{tool_def.get('name')} missing action_fields"


class TestActionParsing:
    """Check that ``MLDebugAction`` keeps the fields given for each action type.

    Every test builds an action with keyword arguments and asserts that the
    action type and a tool-specific payload field are readable afterwards.
    """

    def test_submit_fix_action(self):
        """submit_fix keeps its action type and the submitted code."""
        from models import MLDebugAction

        action = MLDebugAction(
            action_type="submit_fix",
            fixed_code="import torch\nprint('hello')"
        )
        assert action.action_type == "submit_fix"
        assert "import torch" in action.fixed_code

    def test_execute_snippet_action(self):
        """execute_snippet keeps its action type and the snippet code."""
        from models import MLDebugAction

        action = MLDebugAction(
            action_type="execute_snippet",
            code="print('test')"
        )
        assert action.action_type == "execute_snippet"
        assert action.code == "print('test')"

    def test_inspect_tensor_action(self):
        """inspect_tensor keeps its action type and the target expression."""
        from models import MLDebugAction

        action = MLDebugAction(
            action_type="inspect_tensor",
            setup_code="import torch; t = torch.randn(3, 4)",
            target_expression="t.shape"
        )
        assert action.action_type == "inspect_tensor"
        assert action.target_expression == "t.shape"

    def test_get_variable_state_action(self):
        """get_variable_state keeps its action type and both expressions."""
        from models import MLDebugAction

        action = MLDebugAction(
            action_type="get_variable_state",
            setup_code="x = 1",
            expressions=["x", "x + 1"]
        )
        assert action.action_type == "get_variable_state"
        assert len(action.expressions) == 2

    def test_run_training_probe_action(self):
        """run_training_probe keeps its action type and the step count."""
        from models import MLDebugAction

        action = MLDebugAction(
            action_type="run_training_probe",
            code="# training code",
            steps=5
        )
        assert action.action_type == "run_training_probe"
        assert action.steps == 5

    def test_inspect_diff_action(self):
        """inspect_diff keeps its action type and the proposed code."""
        from models import MLDebugAction

        action = MLDebugAction(
            action_type="inspect_diff",
            proposed_code="# fixed code"
        )
        assert action.action_type == "inspect_diff"
        # Check the payload like the sibling tests do; without this, a model
        # that silently dropped proposed_code would still pass.
        assert action.proposed_code == "# fixed code"


class TestObservationModel:
    """Verify the attributes exposed by a default ``MLDebugObservation``."""

    def test_observation_has_all_fields(self):
        """A default observation exposes the common and per-tool attributes."""
        from models import MLDebugObservation

        observation = MLDebugObservation()

        # Attribute names grouped by the part of the observation they belong
        # to; groups and names are checked in the order listed.
        expected_by_group = {
            "common": ("turn", "episode_done", "task_id"),
            "execute_snippet": ("stdout", "stderr", "exit_code", "timed_out"),
            "inspect_tensor": (
                "shape",
                "dtype",
                "requires_grad",
                "grad_is_none",
                "min_val",
                "max_val",
                "mean_val",
                "is_nan",
                "is_inf",
            ),
            "run_training_probe": (
                "losses",
                "grad_norms",
                "optimizer_param_count",
                "final_loss",
                "loss_is_nan",
                "loss_is_inf",
            ),
            "get_variable_state": ("results",),
            "inspect_diff": ("diff", "lines_changed", "additions", "deletions"),
        }

        for attribute_names in expected_by_group.values():
            for attribute in attribute_names:
                assert hasattr(observation, attribute)


if __name__ == "__main__":
    # pytest.main only returns the exit code; hand it to sys.exit so running
    # this file directly exits non-zero when a test fails.
    sys.exit(pytest.main([__file__, "-v"]))