Priyansh Saxena committed on
Commit ·
b0fdd8b
1
Parent(s): 1700927
test: assert task scores stay in (0,1)
Browse files
tests/test_task_score_bounds.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
|
| 3 |
+
from src.pytorch_debug_env.bug_library import BUG_TEMPLATES
|
| 4 |
+
from src.pytorch_debug_env.environment import PyTorchDebugEnv
|
| 5 |
+
from src.pytorch_debug_env.graders import grade_easy, grade_medium, grade_hard
|
| 6 |
+
from src.pytorch_debug_env.models import FinalDiagnosis, Hypothesis, PyTorchDebugAction
|
| 7 |
+
from src.pytorch_debug_env.scenario_generator import ScenarioGenerator
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _build_action_from_gt(gt: dict) -> PyTorchDebugAction:
    """Build a committing action whose diagnosis mirrors the ground truth *gt*.

    Both the running hypothesis and the final diagnosis copy the bug type and
    primary file straight from the ground-truth dict, with a fixed 0.9
    confidence, so graders see an essentially "perfect" answer.
    """
    # Fields shared between the hypothesis and the final diagnosis.
    location = {
        "bug_type": gt["bug_type"],
        "affected_file": gt["primary_bug_file"],
    }
    return PyTorchDebugAction(
        current_hypothesis=Hypothesis(confidence=0.9, **location),
        commit_diagnosis=True,
        final_diagnosis=FinalDiagnosis(
            confidence=0.9,
            line_range=gt["line_range"],
            fix_strategy=gt["fix_strategy"],
            **location,
        ),
    )
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@pytest.mark.parametrize(
    "task_id,grader",
    [
        ("easy", grade_easy),
        ("medium", grade_medium),
        ("hard", grade_hard),
    ],
)
@pytest.mark.asyncio
async def test_task_scores_strict_bounds(task_id, grader):
    """Even a ground-truth-perfect diagnosis must score strictly inside (0, 1)."""
    env = PyTorchDebugEnv(generator=ScenarioGenerator(BUG_TEMPLATES))
    await env.reset(task_id, seed=7)

    truth = env.runtime.scenario.ground_truth
    action = _build_action_from_gt(truth)

    # The grader invoked directly must never saturate to either bound.
    assert 0.0 < grader(action.final_diagnosis.model_dump(), truth) < 1.0

    # The same invariant must hold for the reward surfaced by the env step
    # and for the final score recorded in the environment state.
    step_result = await env.step(action)
    assert 0.0 < step_result["reward"] < 1.0
    assert 0.0 < (await env.state()).final_score < 1.0
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@pytest.mark.parametrize(
    "grader",
    [grade_easy, grade_medium, grade_hard],
)
def test_empty_action_is_clamped(grader):
    """An empty diagnosis dict must yield a score strictly inside (0, 1)."""
    # Minimal but complete ground truth for a canonical bug scenario.
    ground_truth = {
        "bug_type": "missing_zero_grad",
        "primary_bug_file": "train.py",
        "related_files": [],
        "line_range": [10, 12],
        "fix_strategy": "Call optimizer.zero_grad() before loss.backward()",
    }
    # Grading a totally empty answer must be clamped away from 0.0 and 1.0.
    assert 0.0 < grader({}, ground_truth) < 1.0
|