File size: 5,438 Bytes
f577d1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b44d7b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""Unit tests for the Grader class — OpenAI calls are always mocked."""

import pathlib
import pytest
from unittest.mock import MagicMock, patch

from deceit_env.server.grader import Grader, GraderResult


@pytest.fixture
def tmp_grader(tmp_path):
    """Provide a ``Grader`` built without an OpenAI key, caching under ``tmp_path``."""
    cache_file = tmp_path / "cache.json"
    return Grader(cache_path=cache_file, openai_api_key=None)


@pytest.fixture
def api_grader(tmp_path):
    """Provide a ``Grader`` built with a placeholder OpenAI key, caching under ``tmp_path``."""
    grader_kwargs = {
        "cache_path": tmp_path / "cache.json",
        "openai_api_key": "fake-key",
    }
    return Grader(**grader_kwargs)


class TestExactMatch:
    """Checks on answers expected to be settled with method ``"exact"``.

    Every test uses ``tmp_grader``, which is constructed without an API key.
    """

    @staticmethod
    def _assert_exact(outcome, expected_correct):
        """Assert *outcome* carries the given verdict and method ``"exact"``."""
        assert outcome.correct is expected_correct
        assert outcome.method == "exact"

    def test_identical_strings(self, tmp_grader):
        """An answer identical to the truth is accepted."""
        self._assert_exact(tmp_grader.check("Canberra", "Canberra"), True)

    def test_case_insensitive(self, tmp_grader):
        """A lower-cased answer is still accepted."""
        self._assert_exact(tmp_grader.check("canberra", "Canberra"), True)

    def test_trailing_punctuation_stripped(self, tmp_grader):
        """A trailing full stop on the answer does not prevent acceptance."""
        self._assert_exact(tmp_grader.check("Canberra.", "Canberra"), True)

    def test_extra_whitespace_stripped(self, tmp_grader):
        """Leading and trailing spaces on the answer do not prevent acceptance."""
        self._assert_exact(tmp_grader.check("  Canberra  ", "Canberra"), True)

    def test_wrong_answer_fails_exact(self, tmp_grader):
        """A non-matching answer raises, since this grader has no API key."""
        expect_failure = pytest.raises(RuntimeError, match="no OpenAI API key")
        with expect_failure:
            tmp_grader.check("Sydney", "Canberra")

    def test_empty_answer_returns_incorrect(self, tmp_grader):
        """An empty answer is reported incorrect with method ``"exact"``, not raised."""
        self._assert_exact(tmp_grader.check("", "Canberra"), False)


class TestSemanticMatch:
    """Checks on the OpenAI-backed fallback and its cache, using a mocked client."""

    # Name patched in the grader module so that constructing a client yields our mock.
    _OPENAI_TARGET = "deceit_env.server.grader.OpenAI"

    def _mock_openai_response(self, verdict: str):
        """Build a fake client whose chat completion replies with *verdict*."""
        choice = MagicMock()
        choice.message.content = verdict
        completion = MagicMock(choices=[choice])
        client = MagicMock()
        client.chat.completions.create.return_value = completion
        return client

    def _patched_openai(self, client):
        """Return a patcher that makes the grader module's ``OpenAI`` return *client*."""
        return patch(self._OPENAI_TARGET, return_value=client)

    def test_semantic_called_when_exact_fails(self, api_grader):
        """A paraphrased answer goes to the API once and is judged by its verdict."""
        client = self._mock_openai_response("YES")
        with self._patched_openai(client):
            outcome = api_grader.check("The Australian capital", "Canberra")
        assert outcome.method == "semantic"
        assert outcome.correct is True
        client.chat.completions.create.assert_called_once()

    def test_semantic_no_called_when_exact_matches(self, api_grader):
        """An exact match must not trigger any API call."""
        client = self._mock_openai_response("YES")
        with self._patched_openai(client):
            api_grader.check("Canberra", "Canberra")
        client.chat.completions.create.assert_not_called()

    def test_semantic_returns_false_on_no(self, api_grader):
        """A "NO" verdict from the API yields an incorrect result."""
        client = self._mock_openai_response("NO")
        with self._patched_openai(client):
            outcome = api_grader.check("Sydney", "Canberra")
        assert outcome.correct is False

    def test_cache_prevents_duplicate_api_calls(self, api_grader):
        """Repeating the same check on one grader reaches the API only once."""
        client = self._mock_openai_response("YES")
        with self._patched_openai(client):
            first = api_grader.check("The Australian capital", "Canberra")
            second = api_grader.check("The Australian capital", "Canberra")
        assert client.chat.completions.create.call_count == 1
        assert first.correct == second.correct

    def test_cache_persists_to_disk(self, tmp_path):
        """A second grader sharing the cache file reuses the first grader's verdict."""
        shared_cache = tmp_path / "cache.json"
        client = self._mock_openai_response("YES")

        writer = Grader(cache_path=shared_cache, openai_api_key="fake-key")
        with self._patched_openai(client):
            writer.check("The Australian capital", "Canberra")

        reader = Grader(cache_path=shared_cache, openai_api_key="fake-key")
        with self._patched_openai(client):
            outcome = reader.check("The Australian capital", "Canberra")

        # Only the first grader's call should have reached the mocked API.
        assert client.chat.completions.create.call_count == 1
        assert outcome.correct is True

    def test_error_raised_without_api_key(self, tmp_grader):
        """Without an API key, a non-exact answer raises instead of being graded."""
        expect_failure = pytest.raises(RuntimeError, match="no OpenAI API key")
        with expect_failure:
            tmp_grader.check("Sydney", "Canberra")


class TestRateLimitRetry:
    """Checks on how the grader reacts to an OpenAI rate-limit error."""

    def test_retries_on_429_then_succeeds(self, api_grader):
        """A single ``RateLimitError`` is followed by one sleep and one retry.

        The mocked client raises on the first ``create`` call and returns a
        "YES" completion on the second. The test expects a correct result,
        exactly two API calls, and exactly one ``time.sleep(25)``.
        """
        # Imported inside the test rather than at module level.
        from openai import RateLimitError
        import httpx

        # Successful completion returned on the second attempt.
        mock_choice = MagicMock()
        mock_choice.message.content = "YES"
        ok_response = MagicMock()
        ok_response.choices = [mock_choice]

        # RateLimitError is built from a real httpx response carrying a request.
        _dummy_request = httpx.Request("POST", "https://api.openai.com/v1/chat/completions")
        rate_err = RateLimitError(
            "rate limited",
            response=httpx.Response(429, request=_dummy_request),
            body={},
        )

        # side_effect list: an exception instance is raised, other items are returned.
        mock_client = MagicMock()
        mock_client.chat.completions.create.side_effect = [rate_err, ok_response]

        with patch("deceit_env.server.grader.OpenAI", return_value=mock_client):
            # Stub time.sleep so the test does not wait, and record the delay.
            with patch("time.sleep") as mock_sleep:
                result = api_grader.check("The Australian capital", "Canberra")

        assert result.correct is True
        assert mock_client.chat.completions.create.call_count == 2
        mock_sleep.assert_called_once_with(25)