Priyansh Saxena commited on
Commit
b0fdd8b
·
1 Parent(s): 1700927

test: assert task scores stay in (0,1)

Browse files
Files changed (1) hide show
  1. tests/test_task_score_bounds.py +67 -0
tests/test_task_score_bounds.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+
3
+ from src.pytorch_debug_env.bug_library import BUG_TEMPLATES
4
+ from src.pytorch_debug_env.environment import PyTorchDebugEnv
5
+ from src.pytorch_debug_env.graders import grade_easy, grade_medium, grade_hard
6
+ from src.pytorch_debug_env.models import FinalDiagnosis, Hypothesis, PyTorchDebugAction
7
+ from src.pytorch_debug_env.scenario_generator import ScenarioGenerator
8
+
9
+
10
+ def _build_action_from_gt(gt: dict) -> PyTorchDebugAction:
11
+ hypothesis = Hypothesis(
12
+ bug_type=gt["bug_type"],
13
+ affected_file=gt["primary_bug_file"],
14
+ confidence=0.9,
15
+ )
16
+ final = FinalDiagnosis(
17
+ bug_type=gt["bug_type"],
18
+ affected_file=gt["primary_bug_file"],
19
+ line_range=gt["line_range"],
20
+ fix_strategy=gt["fix_strategy"],
21
+ confidence=0.9,
22
+ )
23
+ return PyTorchDebugAction(
24
+ current_hypothesis=hypothesis,
25
+ commit_diagnosis=True,
26
+ final_diagnosis=final,
27
+ )
28
+
29
+
30
+ @pytest.mark.parametrize(
31
+ "task_id,grader",
32
+ [
33
+ ("easy", grade_easy),
34
+ ("medium", grade_medium),
35
+ ("hard", grade_hard),
36
+ ],
37
+ )
38
+ @pytest.mark.asyncio
39
+ async def test_task_scores_strict_bounds(task_id, grader):
40
+ env = PyTorchDebugEnv(generator=ScenarioGenerator(BUG_TEMPLATES))
41
+ await env.reset(task_id, seed=7)
42
+ scenario = env.runtime.scenario
43
+ action = _build_action_from_gt(scenario.ground_truth)
44
+
45
+ score = grader(action.final_diagnosis.model_dump(), scenario.ground_truth)
46
+ assert 0.0 < score < 1.0
47
+
48
+ result = await env.step(action)
49
+ assert 0.0 < result["reward"] < 1.0
50
+ state = await env.state()
51
+ assert 0.0 < state.final_score < 1.0
52
+
53
+
54
+ @pytest.mark.parametrize(
55
+ "grader",
56
+ [grade_easy, grade_medium, grade_hard],
57
+ )
58
+ def test_empty_action_is_clamped(grader):
59
+ gt = {
60
+ "bug_type": "missing_zero_grad",
61
+ "primary_bug_file": "train.py",
62
+ "related_files": [],
63
+ "line_range": [10, 12],
64
+ "fix_strategy": "Call optimizer.zero_grad() before loss.backward()",
65
+ }
66
+ score = grader({}, gt)
67
+ assert 0.0 < score < 1.0