""" Tool execution tests for WhipStudio. Tests the debugging tools: execute_snippet, inspect_tensor, run_training_probe, get_variable_state, inspect_diff. """ import pytest import sys import os # Add parent directory to path for imports sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) class TestSecurityChecks: """Test security validation for code execution.""" def test_banned_import_socket(self): """Socket import should be rejected.""" from server.environment import check_code_security code = "import socket\ns = socket.socket()" is_safe, error = check_code_security(code) assert not is_safe assert "socket" in error.lower() def test_banned_import_requests(self): """Requests import should be rejected.""" from server.environment import check_code_security code = "import requests\nrequests.get('http://evil.com')" is_safe, error = check_code_security(code) assert not is_safe assert "requests" in error.lower() def test_banned_import_subprocess(self): """Subprocess import should be rejected.""" from server.environment import check_code_security code = "import subprocess\nsubprocess.run(['ls'])" is_safe, error = check_code_security(code) assert not is_safe assert "subprocess" in error.lower() def test_allowed_imports(self): """Standard ML imports should be allowed.""" from server.environment import check_code_security code = """ import torch import torch.nn as nn import numpy as np from sklearn.model_selection import train_test_split import math import json """ is_safe, error = check_code_security(code) assert is_safe, f"Should be safe: {error}" def test_file_write_outside_tmp(self): """File writes outside /tmp should be rejected.""" from server.environment import check_code_security code = "open('/etc/passwd', 'w').write('hacked')" is_safe, error = check_code_security(code) assert not is_safe assert "tmp" in error.lower() or "file" in error.lower() def test_file_write_in_tmp_allowed(self): """File writes in /tmp should be allowed.""" from server.environment import check_code_security code = "open('/tmp/test.txt', 'w').write('ok')" is_safe, error = check_code_security(code) assert is_safe, f"Should be safe: {error}" class TestToolDefinitions: """Test that tool definitions are complete.""" def test_all_tools_defined(self): """All 6 tools should be defined.""" from server.environment import TOOL_DEFINITIONS expected_tools = { "execute_snippet", "inspect_tensor", "run_training_probe", "get_variable_state", "inspect_diff", "submit_fix" } # TOOL_DEFINITIONS is a list of dicts defined_tools = {t["name"] for t in TOOL_DEFINITIONS} assert expected_tools == defined_tools def test_tool_definitions_have_required_fields(self): """Each tool definition should have name, description, action_fields.""" from server.environment import TOOL_DEFINITIONS for tool_def in TOOL_DEFINITIONS: assert "name" in tool_def, f"Tool missing name" assert "description" in tool_def, f"{tool_def.get('name')} missing description" assert "action_fields" in tool_def, f"{tool_def.get('name')} missing action_fields" class TestActionParsing: """Test that actions are parsed correctly.""" def test_submit_fix_action(self): """submit_fix action should parse correctly.""" from models import MLDebugAction action = MLDebugAction( action_type="submit_fix", fixed_code="import torch\nprint('hello')" ) assert action.action_type == "submit_fix" assert "import torch" in action.fixed_code def test_execute_snippet_action(self): """execute_snippet action should parse correctly.""" from models import MLDebugAction action = MLDebugAction( action_type="execute_snippet", code="print('test')" ) assert action.action_type == "execute_snippet" assert action.code == "print('test')" def test_inspect_tensor_action(self): """inspect_tensor action should parse correctly.""" from models import MLDebugAction action = MLDebugAction( action_type="inspect_tensor", setup_code="import torch; t = torch.randn(3, 4)", target_expression="t.shape" ) assert action.action_type == "inspect_tensor" assert action.target_expression == "t.shape" def test_get_variable_state_action(self): """get_variable_state action should parse correctly.""" from models import MLDebugAction action = MLDebugAction( action_type="get_variable_state", setup_code="x = 1", expressions=["x", "x + 1"] ) assert action.action_type == "get_variable_state" assert len(action.expressions) == 2 def test_run_training_probe_action(self): """run_training_probe action should parse correctly.""" from models import MLDebugAction action = MLDebugAction( action_type="run_training_probe", code="# training code", steps=5 ) assert action.action_type == "run_training_probe" assert action.steps == 5 def test_inspect_diff_action(self): """inspect_diff action should parse correctly.""" from models import MLDebugAction action = MLDebugAction( action_type="inspect_diff", proposed_code="# fixed code" ) assert action.action_type == "inspect_diff" class TestObservationModel: """Test observation model fields.""" def test_observation_has_all_fields(self): """Observation should have fields for all tools.""" from models import MLDebugObservation obs = MLDebugObservation() # Common fields assert hasattr(obs, "turn") assert hasattr(obs, "episode_done") assert hasattr(obs, "task_id") # execute_snippet fields assert hasattr(obs, "stdout") assert hasattr(obs, "stderr") assert hasattr(obs, "exit_code") assert hasattr(obs, "timed_out") # inspect_tensor fields assert hasattr(obs, "shape") assert hasattr(obs, "dtype") assert hasattr(obs, "requires_grad") assert hasattr(obs, "grad_is_none") assert hasattr(obs, "min_val") assert hasattr(obs, "max_val") assert hasattr(obs, "mean_val") assert hasattr(obs, "is_nan") assert hasattr(obs, "is_inf") # run_training_probe fields assert hasattr(obs, "losses") assert hasattr(obs, "grad_norms") assert hasattr(obs, "optimizer_param_count") assert hasattr(obs, "final_loss") assert hasattr(obs, "loss_is_nan") assert hasattr(obs, "loss_is_inf") # get_variable_state fields assert hasattr(obs, "results") # inspect_diff fields assert hasattr(obs, "diff") assert hasattr(obs, "lines_changed") assert hasattr(obs, "additions") assert hasattr(obs, "deletions") if __name__ == "__main__": pytest.main([__file__, "-v"])