# whipstudio/tests/test_tools.py
# Amogh-kal1's picture
# Upload folder using huggingface_hub
# ffd85e1 verified
"""
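Tool tests for WhipStudio.

Tests the debugging tools: execute_snippet, inspect_tensor,
run_training_probe, get_variable_state, inspect_diff (plus submit_fix).
Covers code security checks, tool definitions, action parsing and
observation model fields.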
"""
import pytest
import sys
import os
# Make the parent of this file's directory importable.
_here = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.dirname(_here))
class TestSecurityChecks:
    """Validation rules applied to code before it is executed."""

    @staticmethod
    def _validate(source):
        """Run ``check_code_security`` on *source* and return its result pair."""
        from server.environment import check_code_security

        return check_code_security(source)

    def test_banned_import_socket(self):
        """Code importing ``socket`` must fail validation."""
        ok, reason = self._validate("import socket\ns = socket.socket()")
        assert not ok
        assert "socket" in reason.lower()

    def test_banned_import_requests(self):
        """Code importing ``requests`` must fail validation."""
        ok, reason = self._validate(
            "import requests\nrequests.get('http://evil.com')"
        )
        assert not ok
        assert "requests" in reason.lower()

    def test_banned_import_subprocess(self):
        """Code importing ``subprocess`` must fail validation."""
        ok, reason = self._validate("import subprocess\nsubprocess.run(['ls'])")
        assert not ok
        assert "subprocess" in reason.lower()

    def test_allowed_imports(self):
        """Common ML and stdlib imports must pass validation."""
        # Built line by line so the snippet stays unindented at runtime.
        snippet = (
            "\n"
            "import torch\n"
            "import torch.nn as nn\n"
            "import numpy as np\n"
            "from sklearn.model_selection import train_test_split\n"
            "import math\n"
            "import json\n"
        )
        ok, reason = self._validate(snippet)
        assert ok, f"Should be safe: {reason}"

    def test_file_write_outside_tmp(self):
        """Opening a file for writing outside /tmp must fail validation."""
        ok, reason = self._validate("open('/etc/passwd', 'w').write('hacked')")
        assert not ok
        lowered = reason.lower()
        assert "tmp" in lowered or "file" in lowered

    def test_file_write_in_tmp_allowed(self):
        """Opening a file for writing under /tmp must pass validation."""
        ok, reason = self._validate("open('/tmp/test.txt', 'w').write('ok')")
        assert ok, f"Should be safe: {reason}"
class TestToolDefinitions:
    """Check that ``TOOL_DEFINITIONS`` is complete and well-formed."""

    def test_all_tools_defined(self):
        """The defined tool names must be exactly the six expected ones."""
        from server.environment import TOOL_DEFINITIONS

        expected_tools = {
            "execute_snippet",
            "inspect_tensor",
            "run_training_probe",
            "get_variable_state",
            "inspect_diff",
            "submit_fix",
        }
        # TOOL_DEFINITIONS is a list of dicts; collect each entry's name.
        defined_tools = {t["name"] for t in TOOL_DEFINITIONS}
        assert expected_tools == defined_tools

    def test_tool_definitions_have_required_fields(self):
        """Each tool definition must have name, description, action_fields."""
        from server.environment import TOOL_DEFINITIONS

        for tool_def in TOOL_DEFINITIONS:
            # An entry without a name cannot be identified by name, so the
            # failure message shows the whole entry instead.
            assert "name" in tool_def, f"Tool missing name: {tool_def!r}"
            name = tool_def["name"]
            assert "description" in tool_def, f"{name} missing description"
            assert "action_fields" in tool_def, f"{name} missing action_fields"
class TestActionParsing:
    """Construction checks for ``MLDebugAction``, one per action type."""

    @staticmethod
    def _make_action(**fields):
        """Build an ``MLDebugAction`` from the given keyword fields."""
        from models import MLDebugAction

        return MLDebugAction(**fields)

    def test_submit_fix_action(self):
        """A submit_fix action keeps its type and the submitted code."""
        parsed = self._make_action(
            action_type="submit_fix",
            fixed_code="import torch\nprint('hello')",
        )
        assert parsed.action_type == "submit_fix"
        assert "import torch" in parsed.fixed_code

    def test_execute_snippet_action(self):
        """An execute_snippet action keeps its type and snippet code."""
        parsed = self._make_action(
            action_type="execute_snippet",
            code="print('test')",
        )
        assert parsed.action_type == "execute_snippet"
        assert parsed.code == "print('test')"

    def test_inspect_tensor_action(self):
        """An inspect_tensor action keeps its type and target expression."""
        parsed = self._make_action(
            action_type="inspect_tensor",
            setup_code="import torch; t = torch.randn(3, 4)",
            target_expression="t.shape",
        )
        assert parsed.action_type == "inspect_tensor"
        assert parsed.target_expression == "t.shape"

    def test_get_variable_state_action(self):
        """A get_variable_state action keeps its type and both expressions."""
        parsed = self._make_action(
            action_type="get_variable_state",
            setup_code="x = 1",
            expressions=["x", "x + 1"],
        )
        assert parsed.action_type == "get_variable_state"
        assert len(parsed.expressions) == 2

    def test_run_training_probe_action(self):
        """A run_training_probe action keeps its type and step count."""
        parsed = self._make_action(
            action_type="run_training_probe",
            code="# training code",
            steps=5,
        )
        assert parsed.action_type == "run_training_probe"
        assert parsed.steps == 5

    def test_inspect_diff_action(self):
        """An inspect_diff action keeps its type."""
        parsed = self._make_action(
            action_type="inspect_diff",
            proposed_code="# fixed code",
        )
        assert parsed.action_type == "inspect_diff"
class TestObservationModel:
    """Attribute coverage checks for ``MLDebugObservation``."""

    def test_observation_has_all_fields(self):
        """A default-constructed observation exposes every tool's fields."""
        from models import MLDebugObservation

        observation = MLDebugObservation()
        # Attribute names grouped by the tool they belong to; checked in order.
        field_groups = (
            # Common fields
            ("turn", "episode_done", "task_id"),
            # execute_snippet fields
            ("stdout", "stderr", "exit_code", "timed_out"),
            # inspect_tensor fields
            (
                "shape",
                "dtype",
                "requires_grad",
                "grad_is_none",
                "min_val",
                "max_val",
                "mean_val",
                "is_nan",
                "is_inf",
            ),
            # run_training_probe fields
            (
                "losses",
                "grad_norms",
                "optimizer_param_count",
                "final_loss",
                "loss_is_nan",
                "loss_is_inf",
            ),
            # get_variable_state fields
            ("results",),
            # inspect_diff fields
            ("diff", "lines_changed", "additions", "deletions"),
        )
        for group in field_groups:
            for attr in group:
                assert hasattr(observation, attr)
if __name__ == "__main__":
    # pytest.main() returns the run's exit code; pass it on so running this
    # file directly exits non-zero when tests fail.
    raise SystemExit(pytest.main([__file__, "-v"]))