# tests/test_endpoints.py # Basic endpoint tests for the environment. # Run: python -m pytest tests/ -v import requests import pytest BASE_URL = 'http://localhost:7860' def test_health_check(): """GET / should return 200 with status ok.""" r = requests.get(f'{BASE_URL}/') assert r.status_code == 200 data = r.json() assert data['status'] == 'ok' assert data['tasks'] == 9 def test_reset_valid_task(): """POST /reset with valid task_id should return episode_id and observation.""" r = requests.post(f'{BASE_URL}/reset', json={'task_id': 'sec_easy'}) assert r.status_code == 200 data = r.json() assert 'episode_id' in data assert 'observation' in data assert data['observation']['task_type'] == 'security' def test_reset_all_tasks(): """POST /reset should work for all 9 task IDs.""" tasks = [ 'sec_easy', 'sec_medium', 'sec_hard', 'dep_easy', 'dep_medium', 'dep_hard', 'cli_easy', 'cli_medium', 'cli_hard', ] for task_id in tasks: r = requests.post(f'{BASE_URL}/reset', json={'task_id': task_id}) assert r.status_code == 200 data = r.json() assert 'episode_id' in data, f'No episode_id for {task_id}' assert 'observation' in data, f'No observation for {task_id}' def test_reset_invalid_task(): """POST /reset with invalid task_id should still return 200.""" r = requests.post(f'{BASE_URL}/reset', json={'task_id': 'nonexistent'}) assert r.status_code == 200 def test_step_valid_action(): """POST /step with valid action should return reward and observation.""" # Reset first r = requests.post(f'{BASE_URL}/reset', json={'task_id': 'sec_easy'}) ep_id = r.json()['episode_id'] # Step action = { 'episode_id': ep_id, 'action_type': 'identify_vulnerability', 'vuln_type': 'sql_injection', 'cvss_score': 9.1, 'severity': 'critical', 'affected_line': 1, } r = requests.post(f'{BASE_URL}/step', json=action) assert r.status_code == 200 data = r.json() assert 'reward' in data assert 'done' in data assert 'observation' in data assert 0.0 <= data['reward'] <= 1.0 def test_step_invalid_episode(): """POST /step with invalid episode_id should return 200 with done=True.""" r = requests.post(f'{BASE_URL}/step', json={ 'episode_id': 'nonexistent', 'action_type': 'identify_vulnerability', }) assert r.status_code == 200 data = r.json() assert data['done'] is True def test_state_endpoint(): """GET /state should return episode info.""" r = requests.post(f'{BASE_URL}/reset', json={'task_id': 'sec_easy'}) ep_id = r.json()['episode_id'] r = requests.get(f'{BASE_URL}/state', params={'episode_id': ep_id}) assert r.status_code == 200 data = r.json() assert data['episode_id'] == ep_id assert data['done'] is False def test_reward_range(): """Rewards should always be in [0.0, 1.0].""" tasks = ['sec_easy', 'dep_easy', 'cli_easy'] for task_id in tasks: r = requests.post(f'{BASE_URL}/reset', json={'task_id': task_id}) ep_id = r.json()['episode_id'] # Send an invalid action r = requests.post(f'{BASE_URL}/step', json={ 'episode_id': ep_id, 'action_type': 'invalid_action_type', }) data = r.json() assert 0.0 <= data['reward'] <= 1.0, f'Reward out of range for {task_id}' def test_step_enriched_observation(): """Step observations should include task context fields.""" r = requests.post(f'{BASE_URL}/reset', json={'task_id': 'sec_easy'}) ep_id = r.json()['episode_id'] action = { 'episode_id': ep_id, 'action_type': 'identify_vulnerability', 'vuln_type': 'sql_injection', 'cvss_score': 9.1, 'severity': 'critical', 'affected_line': 1, } r = requests.post(f'{BASE_URL}/step', json=action) obs = r.json()['observation'] assert 'task_type' in obs assert 'max_steps' in obs assert 'steps_remaining' in obs