File size: 5,438 Bytes
"""Unit tests for the Grader class — OpenAI calls are always mocked."""
import pathlib
import pytest
from unittest.mock import MagicMock, patch
from deceit_env.server.grader import Grader, GraderResult
@pytest.fixture
def tmp_grader(tmp_path):
    """Provide a Grader built without an OpenAI key, caching under ``tmp_path``."""
    cache_file = tmp_path / "cache.json"
    return Grader(cache_path=cache_file, openai_api_key=None)
@pytest.fixture
def api_grader(tmp_path):
    """Provide a Grader built with a placeholder OpenAI key, caching under ``tmp_path``."""
    cache_file = tmp_path / "cache.json"
    return Grader(cache_path=cache_file, openai_api_key="fake-key")
class TestExactMatch:
    """Tests for ``Grader.check`` on a grader configured without an API key."""

    @staticmethod
    def _assert_exact(outcome, *, expected):
        """Assert *outcome* has ``correct is expected`` and method ``"exact"``."""
        assert outcome.correct is expected
        assert outcome.method == "exact"

    def test_identical_strings(self, tmp_grader):
        """Identical answer and reference are graded correct."""
        self._assert_exact(tmp_grader.check("Canberra", "Canberra"), expected=True)

    def test_case_insensitive(self, tmp_grader):
        """A lower-cased answer is still graded correct."""
        self._assert_exact(tmp_grader.check("canberra", "Canberra"), expected=True)

    def test_trailing_punctuation_stripped(self, tmp_grader):
        """A trailing full stop on the answer does not prevent a match."""
        self._assert_exact(tmp_grader.check("Canberra.", "Canberra"), expected=True)

    def test_extra_whitespace_stripped(self, tmp_grader):
        """Surrounding spaces on the answer do not prevent a match."""
        self._assert_exact(tmp_grader.check(" Canberra ", "Canberra"), expected=True)

    def test_wrong_answer_fails_exact(self, tmp_grader):
        """A non-matching answer with no API key is expected to raise RuntimeError."""
        with pytest.raises(RuntimeError, match="no OpenAI API key"):
            tmp_grader.check("Sydney", "Canberra")

    def test_empty_answer_returns_incorrect(self, tmp_grader):
        """An empty answer is graded incorrect, still via the "exact" method."""
        self._assert_exact(tmp_grader.check("", "Canberra"), expected=False)
class TestSemanticMatch:
    """Tests for ``Grader.check`` when the answer does not match exactly.

    ``deceit_env.server.grader.OpenAI`` is patched wherever a client could be
    built, so the tests only ever talk to a ``MagicMock``.
    """

    def _mock_openai_response(self, verdict: str):
        """Build a mock client whose chat completion replies with *verdict*."""
        fake_client = MagicMock()
        reply = MagicMock()
        reply.message.content = verdict
        create_call = fake_client.chat.completions.create
        create_call.return_value.choices = [reply]
        return fake_client

    def _patch_openai(self, fake_client):
        """Return a patcher that makes the grader module's ``OpenAI`` yield *fake_client*."""
        return patch("deceit_env.server.grader.OpenAI", return_value=fake_client)

    def test_semantic_called_when_exact_fails(self, api_grader):
        """A non-matching answer results in one completion request and method "semantic"."""
        fake_client = self._mock_openai_response("YES")
        with self._patch_openai(fake_client):
            outcome = api_grader.check("The Australian capital", "Canberra")

        assert outcome.method == "semantic"
        assert outcome.correct is True
        fake_client.chat.completions.create.assert_called_once()

    def test_semantic_no_called_when_exact_matches(self, api_grader):
        """An exactly matching answer must not issue any completion request."""
        fake_client = self._mock_openai_response("YES")
        with self._patch_openai(fake_client):
            api_grader.check("Canberra", "Canberra")

        fake_client.chat.completions.create.assert_not_called()

    def test_semantic_returns_false_on_no(self, api_grader):
        """A "NO" verdict from the mocked client is graded incorrect."""
        fake_client = self._mock_openai_response("NO")
        with self._patch_openai(fake_client):
            outcome = api_grader.check("Sydney", "Canberra")

        assert outcome.correct is False

    def test_cache_prevents_duplicate_api_calls(self, api_grader):
        """Checking the same pair twice on one grader issues a single request."""
        fake_client = self._mock_openai_response("YES")
        with self._patch_openai(fake_client):
            first = api_grader.check("The Australian capital", "Canberra")
            second = api_grader.check("The Australian capital", "Canberra")

        assert fake_client.chat.completions.create.call_count == 1
        assert first.correct == second.correct

    def test_cache_persists_to_disk(self, tmp_path):
        """A second grader sharing the cache file answers without a new request."""
        cache_file = tmp_path / "cache.json"
        fake_client = self._mock_openai_response("YES")

        first_grader = Grader(cache_path=cache_file, openai_api_key="fake-key")
        with self._patch_openai(fake_client):
            first_grader.check("The Australian capital", "Canberra")

        second_grader = Grader(cache_path=cache_file, openai_api_key="fake-key")
        with self._patch_openai(fake_client):
            outcome = second_grader.check("The Australian capital", "Canberra")

        # Both graders share the same mock, so a count of 1 means only the
        # first grader triggered a completion request.
        assert fake_client.chat.completions.create.call_count == 1
        assert outcome.correct is True

    def test_error_raised_without_api_key(self, tmp_grader):
        """A non-matching answer with no API key is expected to raise RuntimeError."""
        with pytest.raises(RuntimeError, match="no OpenAI API key"):
            tmp_grader.check("Sydney", "Canberra")
class TestRateLimitRetry:
    """Tests for how ``Grader.check`` reacts to an OpenAI rate-limit error."""

    def test_retries_on_429_then_succeeds(self, api_grader):
        """A single 429 is followed by one sleep and a successful second attempt.

        The mocked client raises ``RateLimitError`` on the first completion
        request and returns a "YES" reply on the second. ``time.sleep`` is
        patched so the test does not actually wait.
        """
        from openai import RateLimitError
        import httpx

        # Successful completion handed back on the second attempt.
        mock_choice = MagicMock()
        mock_choice.message.content = "YES"
        ok_response = MagicMock()
        ok_response.choices = [mock_choice]

        # The error is built around a real httpx response carrying status 429.
        dummy_request = httpx.Request(
            "POST", "https://api.openai.com/v1/chat/completions"
        )
        rate_err = RateLimitError(
            "rate limited",
            response=httpx.Response(429, request=dummy_request),
            body={},
        )

        mock_client = MagicMock()
        # side_effect list: the first call raises rate_err, the second returns
        # ok_response.
        mock_client.chat.completions.create.side_effect = [rate_err, ok_response]

        with patch("deceit_env.server.grader.OpenAI", return_value=mock_client):
            with patch("time.sleep") as mock_sleep:
                result = api_grader.check("The Australian capital", "Canberra")

        assert result.correct is True
        assert mock_client.chat.completions.create.call_count == 2
        # Exactly one back-off, requested as sleep(25).
        mock_sleep.assert_called_once_with(25)
|