Spaces:
Sleeping
Sleeping
Commit Β·
791664b
1
Parent(s): ca9160b
change in reward system to strict it between the 0-1
Browse files- graders/attacks.py +28 -59
- graders/clamp.py +17 -0
- graders/consistency.py +26 -54
- graders/correctness.py +76 -108
- graders/documentation.py +52 -118
- graders/performance.py +64 -34
- graders/reward_aggregator.py +64 -74
- graders/static_analysis.py +64 -122
- inference.py +85 -37
graders/attacks.py
CHANGED
|
@@ -1,14 +1,10 @@
|
|
| 1 |
"""
|
| 2 |
-
SecureCodeEnv - Attack Resistance Grader
|
| 3 |
-
|
| 4 |
-
FIXED:
|
| 5 |
-
- All check lambdas handle list OR tuple (JSON round-trip)
|
| 6 |
-
- Path traversal check: blocks absolute paths + encoded paths, not just '..'
|
| 7 |
-
- SQL check: verifies placeholder present AND payload not injected
|
| 8 |
-
- Consistent scoring regardless of seed
|
| 9 |
"""
|
| 10 |
from sandbox.executor import safe_exec
|
| 11 |
from sandbox.payload_gen import get_payloads
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
def _is_seq(v):
|
|
@@ -16,37 +12,25 @@ def _is_seq(v):
|
|
| 16 |
|
| 17 |
|
| 18 |
def _sql_safe(result, payload):
|
| 19 |
-
"""SQL query output is safe when: is 2-element seq, query has placeholder, raw payload not in query."""
|
| 20 |
if not _is_seq(result) or len(result) < 2:
|
| 21 |
return False
|
| 22 |
query = str(result[0])
|
| 23 |
-
# Must use parameterized placeholder
|
| 24 |
if not any(p in query for p in ("%s", "?", ":param", "%(", ":%")):
|
| 25 |
return False
|
| 26 |
-
# The raw payload (stripped of whitespace) must NOT appear verbatim in the query
|
| 27 |
payload_stripped = payload.strip().replace(" ", "").lower()
|
| 28 |
query_lower = query.lower().replace(" ", "")
|
| 29 |
-
|
| 30 |
-
return False
|
| 31 |
-
return True
|
| 32 |
|
| 33 |
|
| 34 |
def _path_safe(result, payload):
|
| 35 |
-
"""
|
| 36 |
-
Path output is safe when resolved path stays inside /var/uploads.
|
| 37 |
-
Also catches: absolute paths (/etc/), encoded paths (%2e), null bytes.
|
| 38 |
-
"""
|
| 39 |
if result is None:
|
| 40 |
-
return True
|
| 41 |
if not isinstance(result, str):
|
| 42 |
return True
|
| 43 |
r = result.lower()
|
| 44 |
-
# Must stay inside base
|
| 45 |
if not result.startswith("/var/uploads"):
|
| 46 |
return False
|
| 47 |
-
|
| 48 |
-
bad = ["/etc/", "/root/", "/proc/", "/var/log", "passwd", "shadow",
|
| 49 |
-
"%2e", "%2f", ".."]
|
| 50 |
return not any(b in r for b in bad)
|
| 51 |
|
| 52 |
|
|
@@ -56,14 +40,14 @@ ATTACK_CONFIGS = {
|
|
| 56 |
"fn_args_template": ["{payload}", "admin"],
|
| 57 |
"attack_type": "injection",
|
| 58 |
"check": _sql_safe,
|
| 59 |
-
"exception_is_pass": False,
|
| 60 |
},
|
| 61 |
"medium_file_path_handler": {
|
| 62 |
"fn_name": "resolve_safe_path",
|
| 63 |
"fn_args_template": ["/var/uploads", "{payload}"],
|
| 64 |
"attack_type": "traversal",
|
| 65 |
"check": _path_safe,
|
| 66 |
-
"exception_is_pass": True,
|
| 67 |
},
|
| 68 |
"hard_file_upload_handler": {
|
| 69 |
"fn_name": "validate_upload",
|
|
@@ -97,18 +81,8 @@ ATTACK_CONFIGS = {
|
|
| 97 |
"attack_type": "xss",
|
| 98 |
"check": lambda result, payload: (
|
| 99 |
isinstance(result, str) and
|
| 100 |
-
not any(b in result.lower()
|
| 101 |
-
["<script", "onerror", "javascript:", "<img ", "<svg", "onload"])
|
| 102 |
-
),
|
| 103 |
-
"exception_is_pass": False,
|
| 104 |
-
},
|
| 105 |
-
"easy_input_sanitizer_filename": {
|
| 106 |
-
"fn_name": "sanitize_filename",
|
| 107 |
-
"fn_args_template": ["{payload}"],
|
| 108 |
-
"attack_type": "traversal",
|
| 109 |
-
"check": lambda result, payload: (
|
| 110 |
-
isinstance(result, str) and
|
| 111 |
-
".." not in result and "/" not in result and "\\" not in result
|
| 112 |
),
|
| 113 |
"exception_is_pass": False,
|
| 114 |
},
|
|
@@ -117,56 +91,51 @@ ATTACK_CONFIGS = {
|
|
| 117 |
|
| 118 |
def grade_attacks(code: str, task: dict, seed: int) -> dict:
|
| 119 |
task_id = task.get("id", "")
|
| 120 |
-
config
|
| 121 |
if not config:
|
| 122 |
-
return {
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
"feedback": "No attack grading defined for this task type",
|
| 126 |
-
}
|
| 127 |
|
| 128 |
payloads = get_payloads(config["attack_type"], seed=seed, count=8)
|
| 129 |
if not payloads:
|
| 130 |
-
return {"score":
|
| 131 |
-
"attack_type": config["attack_type"],
|
|
|
|
| 132 |
|
| 133 |
-
blocked
|
| 134 |
exception_is_pass = config.get("exception_is_pass", True)
|
| 135 |
|
| 136 |
for payload in payloads:
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
for a in raw_args]
|
| 140 |
result = safe_exec(code, args, function_name=config["fn_name"], timeout=3)
|
| 141 |
|
| 142 |
if not result["ok"]:
|
| 143 |
if exception_is_pass:
|
| 144 |
blocked += 1
|
| 145 |
else:
|
| 146 |
-
output = result.get("output")
|
| 147 |
try:
|
| 148 |
-
if config["check"](output, payload):
|
| 149 |
blocked += 1
|
| 150 |
except Exception:
|
| 151 |
pass
|
| 152 |
|
| 153 |
-
|
| 154 |
return {
|
| 155 |
-
"score":
|
| 156 |
"blocked": blocked,
|
| 157 |
"total": len(payloads),
|
| 158 |
"attack_type": config["attack_type"],
|
| 159 |
-
"feedback": _feedback(
|
| 160 |
}
|
| 161 |
|
| 162 |
|
| 163 |
def _feedback(score: float, attack_type: str) -> str:
|
| 164 |
-
names = {
|
| 165 |
-
|
| 166 |
-
"auth_bypass": "authentication bypass", "xss": "XSS",
|
| 167 |
-
}
|
| 168 |
name = names.get(attack_type, attack_type)
|
| 169 |
if score >= 0.875: return f"Excellent β {name} attacks blocked ({score:.0%})"
|
| 170 |
elif score >= 0.625: return f"Good β most {name} attacks blocked ({score:.0%})"
|
| 171 |
-
elif score >= 0.375: return f"Partial β
|
| 172 |
-
else:
|
|
|
|
| 1 |
"""
|
| 2 |
+
SecureCodeEnv - Attack Resistance Grader v4
|
| 3 |
+
All scores clamped to (0.001, 0.999).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
"""
|
| 5 |
from sandbox.executor import safe_exec
|
| 6 |
from sandbox.payload_gen import get_payloads
|
| 7 |
+
from graders.clamp import clamp
|
| 8 |
|
| 9 |
|
| 10 |
def _is_seq(v):
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
def _sql_safe(result, payload):
|
|
|
|
| 15 |
if not _is_seq(result) or len(result) < 2:
|
| 16 |
return False
|
| 17 |
query = str(result[0])
|
|
|
|
| 18 |
if not any(p in query for p in ("%s", "?", ":param", "%(", ":%")):
|
| 19 |
return False
|
|
|
|
| 20 |
payload_stripped = payload.strip().replace(" ", "").lower()
|
| 21 |
query_lower = query.lower().replace(" ", "")
|
| 22 |
+
return payload_stripped not in query_lower
|
|
|
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
def _path_safe(result, payload):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
if result is None:
|
| 27 |
+
return True
|
| 28 |
if not isinstance(result, str):
|
| 29 |
return True
|
| 30 |
r = result.lower()
|
|
|
|
| 31 |
if not result.startswith("/var/uploads"):
|
| 32 |
return False
|
| 33 |
+
bad = ["/etc/", "/root/", "/proc/", "passwd", "shadow", "%2e", "%2f", ".."]
|
|
|
|
|
|
|
| 34 |
return not any(b in r for b in bad)
|
| 35 |
|
| 36 |
|
|
|
|
| 40 |
"fn_args_template": ["{payload}", "admin"],
|
| 41 |
"attack_type": "injection",
|
| 42 |
"check": _sql_safe,
|
| 43 |
+
"exception_is_pass": False,
|
| 44 |
},
|
| 45 |
"medium_file_path_handler": {
|
| 46 |
"fn_name": "resolve_safe_path",
|
| 47 |
"fn_args_template": ["/var/uploads", "{payload}"],
|
| 48 |
"attack_type": "traversal",
|
| 49 |
"check": _path_safe,
|
| 50 |
+
"exception_is_pass": True,
|
| 51 |
},
|
| 52 |
"hard_file_upload_handler": {
|
| 53 |
"fn_name": "validate_upload",
|
|
|
|
| 81 |
"attack_type": "xss",
|
| 82 |
"check": lambda result, payload: (
|
| 83 |
isinstance(result, str) and
|
| 84 |
+
not any(b in result.lower()
|
| 85 |
+
for b in ["<script", "onerror", "javascript:", "<img ", "<svg", "onload"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
),
|
| 87 |
"exception_is_pass": False,
|
| 88 |
},
|
|
|
|
| 91 |
|
| 92 |
def grade_attacks(code: str, task: dict, seed: int) -> dict:
|
| 93 |
task_id = task.get("id", "")
|
| 94 |
+
config = ATTACK_CONFIGS.get(task_id)
|
| 95 |
if not config:
|
| 96 |
+
return {"score": clamp(0.5), "blocked": 0, "total": 0,
|
| 97 |
+
"attack_type": "none",
|
| 98 |
+
"feedback": "No attack grading defined β neutral score"}
|
|
|
|
|
|
|
| 99 |
|
| 100 |
payloads = get_payloads(config["attack_type"], seed=seed, count=8)
|
| 101 |
if not payloads:
|
| 102 |
+
return {"score": clamp(0.5), "blocked": 0, "total": 0,
|
| 103 |
+
"attack_type": config["attack_type"],
|
| 104 |
+
"feedback": "No payloads generated β neutral score"}
|
| 105 |
|
| 106 |
+
blocked = 0
|
| 107 |
exception_is_pass = config.get("exception_is_pass", True)
|
| 108 |
|
| 109 |
for payload in payloads:
|
| 110 |
+
args = [a.replace("{payload}", payload) if isinstance(a, str) else a
|
| 111 |
+
for a in config["fn_args_template"]]
|
|
|
|
| 112 |
result = safe_exec(code, args, function_name=config["fn_name"], timeout=3)
|
| 113 |
|
| 114 |
if not result["ok"]:
|
| 115 |
if exception_is_pass:
|
| 116 |
blocked += 1
|
| 117 |
else:
|
|
|
|
| 118 |
try:
|
| 119 |
+
if config["check"](result.get("output"), payload):
|
| 120 |
blocked += 1
|
| 121 |
except Exception:
|
| 122 |
pass
|
| 123 |
|
| 124 |
+
raw = blocked / len(payloads)
|
| 125 |
return {
|
| 126 |
+
"score": clamp(raw),
|
| 127 |
"blocked": blocked,
|
| 128 |
"total": len(payloads),
|
| 129 |
"attack_type": config["attack_type"],
|
| 130 |
+
"feedback": _feedback(raw, config["attack_type"]),
|
| 131 |
}
|
| 132 |
|
| 133 |
|
| 134 |
def _feedback(score: float, attack_type: str) -> str:
|
| 135 |
+
names = {"injection": "SQL injection", "traversal": "path traversal",
|
| 136 |
+
"auth_bypass": "auth bypass", "xss": "XSS"}
|
|
|
|
|
|
|
| 137 |
name = names.get(attack_type, attack_type)
|
| 138 |
if score >= 0.875: return f"Excellent β {name} attacks blocked ({score:.0%})"
|
| 139 |
elif score >= 0.625: return f"Good β most {name} attacks blocked ({score:.0%})"
|
| 140 |
+
elif score >= 0.375: return f"Partial β {score:.0%} of {name} attacks blocked"
|
| 141 |
+
else: return f"Vulnerable β {score:.0%} of {name} attacks blocked"
|
graders/clamp.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Shared epsilon-clamping utility.
|
| 3 |
+
Validator requires scores strictly between 0 and 1: (0.001 β¦ 0.999)
|
| 4 |
+
"""
|
| 5 |
+
EPSILON = 0.001
|
| 6 |
+
SCORE_MIN = EPSILON # 0.001
|
| 7 |
+
SCORE_MAX = 1.0 - EPSILON # 0.999
|
| 8 |
+
|
| 9 |
+
def clamp(score: float) -> float:
|
| 10 |
+
"""Clamp any score to (0.001, 0.999) β never exactly 0 or 1."""
|
| 11 |
+
try:
|
| 12 |
+
v = float(score)
|
| 13 |
+
except (TypeError, ValueError):
|
| 14 |
+
return 0.5 # safe default for bad inputs
|
| 15 |
+
if v != v: # NaN guard
|
| 16 |
+
return 0.5
|
| 17 |
+
return max(SCORE_MIN, min(SCORE_MAX, v))
|
graders/consistency.py
CHANGED
|
@@ -1,46 +1,31 @@
|
|
| 1 |
-
"""
|
| 2 |
-
SecureCodeEnv - Consistency Grader v3
|
| 3 |
-
FIXED: Step 0 no longer gives free 1.0 β rewards ESTABLISHING good practices
|
| 4 |
-
"""
|
| 5 |
from codegraph.graph import CodeGraph
|
| 6 |
from codegraph.extractor import extract_metadata
|
|
|
|
| 7 |
|
| 8 |
-
|
| 9 |
-
# Minimum quality bar for first submission (establishing conventions)
|
| 10 |
GOOD_PRACTICES = {
|
| 11 |
-
"uses_type_hints":
|
| 12 |
-
"uses_docstrings":
|
| 13 |
-
"uses_try_catch":
|
| 14 |
-
"no_print_stmts":
|
| 15 |
-
"no_hardcoded_secrets":
|
| 16 |
}
|
| 17 |
|
| 18 |
|
| 19 |
def grade_consistency(code: str, filename: str, graph: CodeGraph, step: int) -> dict:
|
| 20 |
new_meta = extract_metadata(code, filename, step)
|
| 21 |
-
conv
|
| 22 |
|
| 23 |
if not graph.components:
|
| 24 |
-
|
| 25 |
-
checks
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
score = sum(checks.values()) / max(len(checks), 1)
|
| 30 |
-
# Minimum 0.5 so this doesn't destroy reward on first step
|
| 31 |
-
score = max(0.5, score)
|
| 32 |
-
|
| 33 |
-
return {
|
| 34 |
-
"score": round(score, 4),
|
| 35 |
-
"checks": checks,
|
| 36 |
-
"feedback": _first_step_feedback(score, checks),
|
| 37 |
-
}
|
| 38 |
|
| 39 |
-
# Step 1+: check consistency with established conventions
|
| 40 |
established = graph.conventions
|
| 41 |
-
checks
|
| 42 |
|
| 43 |
-
# Naming convention
|
| 44 |
naming = established.get("naming")
|
| 45 |
if naming and naming != "mixed" and new_meta.functions:
|
| 46 |
fns = new_meta.functions
|
|
@@ -51,53 +36,40 @@ def grade_consistency(code: str, filename: str, graph: CodeGraph, step: int) ->
|
|
| 51 |
and any(c.isupper() for c in f["name"]))
|
| 52 |
checks["naming_convention"] = correct / len(fns)
|
| 53 |
|
| 54 |
-
# Error handling
|
| 55 |
if established.get("error_handling") == "try_catch":
|
| 56 |
checks["error_handling"] = 1.0 if conv.get("uses_try_catch") else 0.3
|
| 57 |
-
|
| 58 |
-
# Type hints
|
| 59 |
if established.get("uses_type_hints"):
|
| 60 |
checks["type_hints"] = 1.0 if conv.get("uses_type_hints") else 0.4
|
| 61 |
-
|
| 62 |
-
# Docstrings
|
| 63 |
if established.get("uses_docstrings"):
|
| 64 |
checks["docstrings"] = 1.0 if conv.get("uses_docstrings") else 0.5
|
| 65 |
|
| 66 |
-
# No print drift
|
| 67 |
existing_no_print = all(c.conventions.get("no_print_stmts", True)
|
| 68 |
for c in graph.components.values())
|
| 69 |
if existing_no_print:
|
| 70 |
checks["no_print_drift"] = 1.0 if conv.get("no_print_stmts", True) else 0.3
|
| 71 |
|
| 72 |
-
# Component reuse
|
| 73 |
reuse_opp = reuse_taken = 0
|
| 74 |
-
for
|
| 75 |
-
if
|
| 76 |
reuse_opp += 1
|
| 77 |
-
if
|
| 78 |
reuse_taken += 1
|
| 79 |
if reuse_opp > 0:
|
| 80 |
checks["component_reuse"] = reuse_taken / reuse_opp
|
| 81 |
|
| 82 |
-
|
| 83 |
-
return {
|
| 84 |
-
|
| 85 |
-
"checks": checks,
|
| 86 |
-
"feedback": _consistency_feedback(score, checks),
|
| 87 |
-
}
|
| 88 |
|
| 89 |
|
| 90 |
-
def
|
| 91 |
missing = [k for k, v in checks.items() if v == 0.0]
|
| 92 |
if not missing:
|
| 93 |
-
return f"Good conventions established (
|
| 94 |
-
return f"Missing
|
| 95 |
|
| 96 |
|
| 97 |
-
def
|
| 98 |
-
if score >= 0.
|
| 99 |
-
return "Excellent consistency with existing codebase conventions"
|
| 100 |
failing = [k for k, v in checks.items() if isinstance(v, float) and v < 0.5]
|
| 101 |
-
if failing:
|
| 102 |
-
return f"Convention drift in: {', '.join(failing)}"
|
| 103 |
-
return f"Minor convention drift (score: {score:.2f})"
|
|
|
|
| 1 |
+
"""SecureCodeEnv - Consistency Grader v4 β clamped scores"""
|
|
|
|
|
|
|
|
|
|
| 2 |
from codegraph.graph import CodeGraph
|
| 3 |
from codegraph.extractor import extract_metadata
|
| 4 |
+
from graders.clamp import clamp
|
| 5 |
|
|
|
|
|
|
|
| 6 |
GOOD_PRACTICES = {
|
| 7 |
+
"uses_type_hints": 0.15,
|
| 8 |
+
"uses_docstrings": 0.15,
|
| 9 |
+
"uses_try_catch": 0.10,
|
| 10 |
+
"no_print_stmts": 0.10,
|
| 11 |
+
"no_hardcoded_secrets": 0.10,
|
| 12 |
}
|
| 13 |
|
| 14 |
|
| 15 |
def grade_consistency(code: str, filename: str, graph: CodeGraph, step: int) -> dict:
|
| 16 |
new_meta = extract_metadata(code, filename, step)
|
| 17 |
+
conv = new_meta.conventions
|
| 18 |
|
| 19 |
if not graph.components:
|
| 20 |
+
checks = {k: 1.0 if conv.get(k, False) else 0.0 for k in GOOD_PRACTICES}
|
| 21 |
+
raw = sum(checks.values()) / max(len(checks), 1)
|
| 22 |
+
raw = max(0.45, raw) # floor so first step never crushes reward
|
| 23 |
+
return {"score": clamp(raw), "checks": checks,
|
| 24 |
+
"feedback": _first_feedback(raw, checks)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
|
|
|
| 26 |
established = graph.conventions
|
| 27 |
+
checks = {}
|
| 28 |
|
|
|
|
| 29 |
naming = established.get("naming")
|
| 30 |
if naming and naming != "mixed" and new_meta.functions:
|
| 31 |
fns = new_meta.functions
|
|
|
|
| 36 |
and any(c.isupper() for c in f["name"]))
|
| 37 |
checks["naming_convention"] = correct / len(fns)
|
| 38 |
|
|
|
|
| 39 |
if established.get("error_handling") == "try_catch":
|
| 40 |
checks["error_handling"] = 1.0 if conv.get("uses_try_catch") else 0.3
|
|
|
|
|
|
|
| 41 |
if established.get("uses_type_hints"):
|
| 42 |
checks["type_hints"] = 1.0 if conv.get("uses_type_hints") else 0.4
|
|
|
|
|
|
|
| 43 |
if established.get("uses_docstrings"):
|
| 44 |
checks["docstrings"] = 1.0 if conv.get("uses_docstrings") else 0.5
|
| 45 |
|
|
|
|
| 46 |
existing_no_print = all(c.conventions.get("no_print_stmts", True)
|
| 47 |
for c in graph.components.values())
|
| 48 |
if existing_no_print:
|
| 49 |
checks["no_print_drift"] = 1.0 if conv.get("no_print_stmts", True) else 0.3
|
| 50 |
|
|
|
|
| 51 |
reuse_opp = reuse_taken = 0
|
| 52 |
+
for name in graph.components:
|
| 53 |
+
if name.lower() in code.lower():
|
| 54 |
reuse_opp += 1
|
| 55 |
+
if name in code:
|
| 56 |
reuse_taken += 1
|
| 57 |
if reuse_opp > 0:
|
| 58 |
checks["component_reuse"] = reuse_taken / reuse_opp
|
| 59 |
|
| 60 |
+
raw = sum(checks.values()) / max(len(checks), 1) if checks else 0.7
|
| 61 |
+
return {"score": clamp(raw), "checks": checks,
|
| 62 |
+
"feedback": _feedback(raw, checks)}
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
|
| 65 |
+
def _first_feedback(score, checks):
|
| 66 |
missing = [k for k, v in checks.items() if v == 0.0]
|
| 67 |
if not missing:
|
| 68 |
+
return f"Good conventions established ({score:.2f})"
|
| 69 |
+
return f"Missing practices: {', '.join(missing)}"
|
| 70 |
|
| 71 |
|
| 72 |
+
def _feedback(score, checks):
|
| 73 |
+
if score >= 0.85: return "Excellent consistency with codebase"
|
|
|
|
| 74 |
failing = [k for k, v in checks.items() if isinstance(v, float) and v < 0.5]
|
| 75 |
+
return f"Convention drift in: {', '.join(failing)}" if failing else f"Minor drift ({score:.2f})"
|
|
|
|
|
|
graders/correctness.py
CHANGED
|
@@ -1,56 +1,45 @@
|
|
| 1 |
"""
|
| 2 |
-
SecureCodeEnv - Correctness Grader
|
| 3 |
-
|
| 4 |
-
|
| 5 |
"""
|
| 6 |
from sandbox.executor import safe_exec
|
|
|
|
|
|
|
| 7 |
|
| 8 |
def _is_seq(v):
|
| 9 |
return isinstance(v, (list, tuple))
|
| 10 |
|
| 11 |
|
| 12 |
def grade_correctness(code: str, task: dict) -> dict:
|
| 13 |
-
"""
|
| 14 |
-
Runs the task's test cases against the agent's code.
|
| 15 |
-
|
| 16 |
-
Returns:
|
| 17 |
-
{
|
| 18 |
-
"score": float 0.0-1.0,
|
| 19 |
-
"passed": int,
|
| 20 |
-
"total": int,
|
| 21 |
-
"details": list of per-test results
|
| 22 |
-
}
|
| 23 |
-
"""
|
| 24 |
test_cases = task.get("test_cases", [])
|
| 25 |
if not test_cases:
|
| 26 |
-
return {"score":
|
|
|
|
| 27 |
|
| 28 |
passed = 0
|
| 29 |
details = []
|
| 30 |
-
|
| 31 |
for tc in test_cases:
|
| 32 |
result = _run_test_case(code, tc)
|
| 33 |
if result["passed"]:
|
| 34 |
passed += 1
|
| 35 |
details.append(result)
|
| 36 |
|
| 37 |
-
|
| 38 |
return {
|
| 39 |
-
"score":
|
| 40 |
"passed": passed,
|
| 41 |
"total": len(test_cases),
|
| 42 |
"details": details,
|
| 43 |
-
"feedback":
|
| 44 |
}
|
| 45 |
|
| 46 |
|
| 47 |
def _run_test_case(code: str, tc: dict) -> dict:
|
| 48 |
-
"""Execute a single test case and evaluate the result."""
|
| 49 |
fn_name = tc.get("fn", "solution")
|
| 50 |
-
inputs
|
| 51 |
-
|
| 52 |
|
| 53 |
-
# Handle class-based tasks
|
| 54 |
if "fn_class" in tc:
|
| 55 |
return _run_class_test(code, tc)
|
| 56 |
|
|
@@ -58,102 +47,92 @@ def _run_test_case(code: str, tc: dict) -> dict:
|
|
| 58 |
|
| 59 |
if not exec_result["ok"]:
|
| 60 |
expected_exc = tc.get("expected_exception")
|
| 61 |
-
error_str
|
| 62 |
-
exc_type
|
| 63 |
if expected_exc:
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
return {"passed": False, "description": description, "error": error_str[:200]}
|
| 72 |
|
| 73 |
output = exec_result.get("output")
|
| 74 |
|
| 75 |
-
#
|
| 76 |
-
if "expected_not_none" in tc:
|
| 77 |
-
ok = output is not None
|
| 78 |
-
return {"passed": ok, "description": description}
|
| 79 |
-
|
| 80 |
-
# SQL injection safety check: payload must NOT appear in query, placeholder must exist
|
| 81 |
if tc.get("sql_injection_check"):
|
| 82 |
if not _is_seq(output) or len(output) < 2:
|
| 83 |
-
return {"passed": False, "description":
|
| 84 |
-
query
|
| 85 |
payload_val = inputs[0] if inputs else ""
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
-
#
|
| 93 |
if "expected" in tc:
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
return {"passed": ok, "description": description, "got": output, "expected": expected}
|
| 97 |
|
| 98 |
-
# Type check (JSON
|
| 99 |
if "expected_type" in tc:
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
ok = actual_type == type_name or (actual_type, type_name) in equivalent or (type_name, actual_type) in equivalent
|
| 105 |
if ok and "expected_len" in tc:
|
| 106 |
ok = hasattr(output, "__len__") and len(output) == tc["expected_len"]
|
| 107 |
-
return {"passed": ok, "description":
|
| 108 |
|
| 109 |
-
# Contains
|
| 110 |
if "expected_contains" in tc:
|
| 111 |
-
|
| 112 |
-
return {"passed": ok, "description": description}
|
| 113 |
|
| 114 |
-
# Not-contains
|
| 115 |
if "expected_not_contains" in tc:
|
| 116 |
forbidden = tc["expected_not_contains"]
|
| 117 |
if isinstance(forbidden, list):
|
| 118 |
ok = not any(f in str(output) for f in forbidden)
|
| 119 |
else:
|
| 120 |
ok = forbidden not in str(output)
|
| 121 |
-
return {"passed": ok, "description":
|
| 122 |
|
| 123 |
-
# Min length
|
| 124 |
if "expected_min_len" in tc:
|
| 125 |
-
|
| 126 |
-
|
| 127 |
|
| 128 |
-
# Max length
|
| 129 |
if "expected_max_len" in tc:
|
| 130 |
-
|
| 131 |
-
|
| 132 |
|
| 133 |
-
# Ok-flag
|
| 134 |
if "expected_ok" in tc:
|
| 135 |
-
|
| 136 |
-
|
| 137 |
|
| 138 |
-
|
| 139 |
-
return {"passed": True, "description": description, "note": "No assertion defined"}
|
| 140 |
|
| 141 |
|
| 142 |
def _run_class_test(code: str, tc: dict) -> dict:
|
| 143 |
-
"""Run a test against a class-based task (e.g. RateLimiter)."""
|
| 144 |
class_name = tc.get("fn_class", "Solution")
|
| 145 |
-
init_args
|
| 146 |
-
method
|
| 147 |
-
inputs
|
| 148 |
-
|
| 149 |
|
| 150 |
-
|
| 151 |
{code}
|
| 152 |
|
| 153 |
def run_task(args):
|
| 154 |
-
init_args = args[0]
|
| 155 |
-
method = args[1]
|
| 156 |
-
inputs = args[2]
|
| 157 |
obj = {class_name}(*init_args)
|
| 158 |
if method == "is_allowed_multi":
|
| 159 |
result = None
|
|
@@ -161,34 +140,23 @@ def run_task(args):
|
|
| 161 |
result = obj.is_allowed(inputs[0])
|
| 162 |
return result
|
| 163 |
if method == "independent_clients":
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
return r1 == r2 == True
|
| 167 |
-
fn = getattr(obj, method)
|
| 168 |
-
return fn(*inputs)
|
| 169 |
"""
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
if not result["ok"]:
|
| 174 |
-
return {"passed": False, "description":
|
| 175 |
-
|
| 176 |
output = result.get("output")
|
| 177 |
if "expected" in tc:
|
| 178 |
-
|
| 179 |
-
return {"passed": ok, "description": description}
|
| 180 |
if "expected_last" in tc:
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
return f"Good β {passed}/{total} tests passed. Minor edge cases missing"
|
| 191 |
-
elif score >= 0.5:
|
| 192 |
-
return f"Partial β {passed}/{total} tests passed. Fix failing cases"
|
| 193 |
-
else:
|
| 194 |
-
return f"Poor β {passed}/{total} tests passed. Core logic incorrect"
|
|
|
|
| 1 |
"""
|
| 2 |
+
SecureCodeEnv - Correctness Grader v4
|
| 3 |
+
Weight: 25% of total reward.
|
| 4 |
+
All scores clamped to (0.001, 0.999).
|
| 5 |
"""
|
| 6 |
from sandbox.executor import safe_exec
|
| 7 |
+
from graders.clamp import clamp
|
| 8 |
+
|
| 9 |
|
| 10 |
def _is_seq(v):
|
| 11 |
return isinstance(v, (list, tuple))
|
| 12 |
|
| 13 |
|
| 14 |
def grade_correctness(code: str, task: dict) -> dict:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
test_cases = task.get("test_cases", [])
|
| 16 |
if not test_cases:
|
| 17 |
+
return {"score": clamp(0.5), "passed": 0, "total": 0,
|
| 18 |
+
"details": [], "feedback": "No test cases defined"}
|
| 19 |
|
| 20 |
passed = 0
|
| 21 |
details = []
|
|
|
|
| 22 |
for tc in test_cases:
|
| 23 |
result = _run_test_case(code, tc)
|
| 24 |
if result["passed"]:
|
| 25 |
passed += 1
|
| 26 |
details.append(result)
|
| 27 |
|
| 28 |
+
raw = passed / len(test_cases)
|
| 29 |
return {
|
| 30 |
+
"score": clamp(raw),
|
| 31 |
"passed": passed,
|
| 32 |
"total": len(test_cases),
|
| 33 |
"details": details,
|
| 34 |
+
"feedback": _feedback(raw, passed, len(test_cases)),
|
| 35 |
}
|
| 36 |
|
| 37 |
|
| 38 |
def _run_test_case(code: str, tc: dict) -> dict:
|
|
|
|
| 39 |
fn_name = tc.get("fn", "solution")
|
| 40 |
+
inputs = tc.get("input", [])
|
| 41 |
+
desc = tc.get("description", "")
|
| 42 |
|
|
|
|
| 43 |
if "fn_class" in tc:
|
| 44 |
return _run_class_test(code, tc)
|
| 45 |
|
|
|
|
| 47 |
|
| 48 |
if not exec_result["ok"]:
|
| 49 |
expected_exc = tc.get("expected_exception")
|
| 50 |
+
error_str = exec_result.get("error", "")
|
| 51 |
+
exc_type = exec_result.get("type", "")
|
| 52 |
if expected_exc:
|
| 53 |
+
if (exc_type == expected_exc or
|
| 54 |
+
expected_exc.lower() in error_str.lower() or
|
| 55 |
+
expected_exc.lower() in exc_type.lower()):
|
| 56 |
+
return {"passed": True, "description": desc,
|
| 57 |
+
"note": f"Expected {expected_exc} raised"}
|
| 58 |
+
return {"passed": False, "description": desc,
|
| 59 |
+
"error": error_str[:200]}
|
|
|
|
| 60 |
|
| 61 |
output = exec_result.get("output")
|
| 62 |
|
| 63 |
+
# SQL injection parameterization check
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
if tc.get("sql_injection_check"):
|
| 65 |
if not _is_seq(output) or len(output) < 2:
|
| 66 |
+
return {"passed": False, "description": desc, "error": "Not a 2-element sequence"}
|
| 67 |
+
query = str(output[0])
|
| 68 |
payload_val = inputs[0] if inputs else ""
|
| 69 |
+
has_ph = any(p in query for p in ("%s", "?", ":param", "%(username"))
|
| 70 |
+
safe = str(payload_val).strip() not in query
|
| 71 |
+
return {"passed": has_ph and safe, "description": desc,
|
| 72 |
+
"note": f"placeholder={has_ph} payload_safe={safe}"}
|
| 73 |
+
|
| 74 |
+
# Not-None
|
| 75 |
+
if "expected_not_none" in tc:
|
| 76 |
+
return {"passed": output is not None, "description": desc}
|
| 77 |
|
| 78 |
+
# Equality
|
| 79 |
if "expected" in tc:
|
| 80 |
+
return {"passed": output == tc["expected"], "description": desc,
|
| 81 |
+
"got": output, "expected": tc["expected"]}
|
|
|
|
| 82 |
|
| 83 |
+
# Type check (JSON converts tupleβlist)
|
| 84 |
if "expected_type" in tc:
|
| 85 |
+
tname = tc["expected_type"]
|
| 86 |
+
atype = type(output).__name__
|
| 87 |
+
equiv = {("tuple","list"),("list","tuple")}
|
| 88 |
+
ok = atype == tname or (atype, tname) in equiv or (tname, atype) in equiv
|
|
|
|
| 89 |
if ok and "expected_len" in tc:
|
| 90 |
ok = hasattr(output, "__len__") and len(output) == tc["expected_len"]
|
| 91 |
+
return {"passed": ok, "description": desc, "got_type": atype}
|
| 92 |
|
| 93 |
+
# Contains
|
| 94 |
if "expected_contains" in tc:
|
| 95 |
+
return {"passed": tc["expected_contains"] in str(output), "description": desc}
|
|
|
|
| 96 |
|
| 97 |
+
# Not-contains
|
| 98 |
if "expected_not_contains" in tc:
|
| 99 |
forbidden = tc["expected_not_contains"]
|
| 100 |
if isinstance(forbidden, list):
|
| 101 |
ok = not any(f in str(output) for f in forbidden)
|
| 102 |
else:
|
| 103 |
ok = forbidden not in str(output)
|
| 104 |
+
return {"passed": ok, "description": desc, "got": str(output)[:100]}
|
| 105 |
|
| 106 |
+
# Min length
|
| 107 |
if "expected_min_len" in tc:
|
| 108 |
+
return {"passed": output is not None and len(str(output)) >= tc["expected_min_len"],
|
| 109 |
+
"description": desc}
|
| 110 |
|
| 111 |
+
# Max length
|
| 112 |
if "expected_max_len" in tc:
|
| 113 |
+
return {"passed": output is not None and len(str(output)) <= tc["expected_max_len"],
|
| 114 |
+
"description": desc}
|
| 115 |
|
| 116 |
+
# Ok-flag (dict with "ok" key)
|
| 117 |
if "expected_ok" in tc:
|
| 118 |
+
return {"passed": isinstance(output, dict) and output.get("ok") == tc["expected_ok"],
|
| 119 |
+
"description": desc}
|
| 120 |
|
| 121 |
+
return {"passed": True, "description": desc, "note": "No assertion"}
|
|
|
|
| 122 |
|
| 123 |
|
| 124 |
def _run_class_test(code: str, tc: dict) -> dict:
|
|
|
|
| 125 |
class_name = tc.get("fn_class", "Solution")
|
| 126 |
+
init_args = tc.get("init_args", [])
|
| 127 |
+
method = tc.get("method", "is_allowed")
|
| 128 |
+
inputs = tc.get("input", [])
|
| 129 |
+
desc = tc.get("description", "")
|
| 130 |
|
| 131 |
+
harness = f"""
|
| 132 |
{code}
|
| 133 |
|
| 134 |
def run_task(args):
|
| 135 |
+
init_args = args[0]; method = args[1]; inputs = args[2]
|
|
|
|
|
|
|
| 136 |
obj = {class_name}(*init_args)
|
| 137 |
if method == "is_allowed_multi":
|
| 138 |
result = None
|
|
|
|
| 140 |
result = obj.is_allowed(inputs[0])
|
| 141 |
return result
|
| 142 |
if method == "independent_clients":
|
| 143 |
+
return obj.is_allowed("client_a") == obj.is_allowed("client_b") == True
|
| 144 |
+
return getattr(obj, method)(*inputs)
|
|
|
|
|
|
|
|
|
|
| 145 |
"""
|
| 146 |
+
result = safe_exec(harness, [[init_args, method, inputs]],
|
| 147 |
+
function_name="run_task", timeout=5)
|
|
|
|
| 148 |
if not result["ok"]:
|
| 149 |
+
return {"passed": False, "description": desc, "error": result.get("error","")[:200]}
|
|
|
|
| 150 |
output = result.get("output")
|
| 151 |
if "expected" in tc:
|
| 152 |
+
return {"passed": output == tc["expected"], "description": desc}
|
|
|
|
| 153 |
if "expected_last" in tc:
|
| 154 |
+
return {"passed": output == tc["expected_last"], "description": desc}
|
| 155 |
+
return {"passed": True, "description": desc}
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def _feedback(score: float, passed: int, total: int) -> str:
|
| 159 |
+
if score >= 0.9: return f"Excellent β {passed}/{total} tests passed"
|
| 160 |
+
elif score >= 0.7: return f"Good β {passed}/{total} tests passed"
|
| 161 |
+
elif score >= 0.5: return f"Partial β {passed}/{total} tests passed"
|
| 162 |
+
else: return f"Poor β {passed}/{total} tests passed"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graders/documentation.py
CHANGED
|
@@ -1,142 +1,76 @@
|
|
| 1 |
-
"""
|
| 2 |
-
SecureCodeEnv - Documentation & Code Structure Graders
|
| 3 |
-
Documentation weight: 5% | Code Structure weight: 5%
|
| 4 |
-
"""
|
| 5 |
import ast
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
def grade_documentation(code: str) -> dict:
|
| 9 |
-
"""
|
| 10 |
-
Grade docstring and type hint coverage.
|
| 11 |
-
Rewards: functions with docstrings, full type annotations, module docstring.
|
| 12 |
-
|
| 13 |
-
Returns:
|
| 14 |
-
{"score": float, "documented_fns": int, "total_fns": int, "feedback": str}
|
| 15 |
-
"""
|
| 16 |
try:
|
| 17 |
tree = ast.parse(code)
|
| 18 |
except SyntaxError:
|
| 19 |
-
return {"score": 0.0, "documented_fns": 0, "total_fns": 0,
|
| 20 |
-
|
| 21 |
-
functions = [
|
| 22 |
-
n for n in ast.walk(tree)
|
| 23 |
-
if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))
|
| 24 |
-
]
|
| 25 |
|
|
|
|
|
|
|
| 26 |
if not functions:
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
"documented_fns": 0,
|
| 32 |
-
"total_fns": 0,
|
| 33 |
-
"feedback": "No functions found β module-level code only",
|
| 34 |
-
}
|
| 35 |
|
| 36 |
-
documented = 0
|
| 37 |
-
typed = 0
|
| 38 |
scores = []
|
| 39 |
-
|
| 40 |
for fn in functions:
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
documented += 1
|
| 49 |
-
fn_score += 0.5
|
| 50 |
-
|
| 51 |
-
if has_any_types:
|
| 52 |
-
typed += 1
|
| 53 |
-
fn_score += 0.5
|
| 54 |
-
|
| 55 |
-
scores.append(fn_score)
|
| 56 |
|
| 57 |
total = len(functions)
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
"documented_fns": documented,
|
| 63 |
-
"typed_fns": typed,
|
| 64 |
-
"total_fns": total,
|
| 65 |
-
"feedback": _doc_feedback(score, documented, typed, total),
|
| 66 |
-
}
|
| 67 |
|
| 68 |
|
| 69 |
def grade_code_structure(code: str) -> dict:
|
| 70 |
-
"""
|
| 71 |
-
Grade code structure quality:
|
| 72 |
-
- No bare print() statements
|
| 73 |
-
- Exception handling present where needed
|
| 74 |
-
- No bare except clauses
|
| 75 |
-
- No hardcoded magic strings
|
| 76 |
-
- Functions not excessively long (>50 lines)
|
| 77 |
-
|
| 78 |
-
Returns:
|
| 79 |
-
{"score": float, "checks": dict, "feedback": str}
|
| 80 |
-
"""
|
| 81 |
try:
|
| 82 |
tree = ast.parse(code)
|
| 83 |
except SyntaxError:
|
| 84 |
-
return {"score": 0.0, "checks": {}, "feedback": "Syntax error"}
|
| 85 |
|
| 86 |
-
checks: dict[str, bool] = {}
|
| 87 |
lines = code.splitlines()
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
checks["
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
# Check 5: Handles None inputs (basic check)
|
| 118 |
-
checks["handles_none"] = "None" in code or "is not None" in code or "if not " in code
|
| 119 |
-
|
| 120 |
-
score = sum(1 for v in checks.values() if v) / max(len(checks), 1)
|
| 121 |
-
|
| 122 |
-
return {
|
| 123 |
-
"score": round(score, 4),
|
| 124 |
-
"checks": checks,
|
| 125 |
-
"feedback": _structure_feedback(score, checks),
|
| 126 |
-
}
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
def _doc_feedback(score: float, documented: int, typed: int, total: int) -> str:
|
| 130 |
-
if score >= 0.9:
|
| 131 |
-
return f"Well documented β {documented}/{total} functions have docstrings, {typed}/{total} typed"
|
| 132 |
-
elif score >= 0.6:
|
| 133 |
-
return f"Partial documentation β {documented}/{total} docstrings, {typed}/{total} type hints"
|
| 134 |
-
else:
|
| 135 |
-
return f"Poor documentation β add docstrings and type hints to all {total} functions"
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
def _structure_feedback(score: float, checks: dict) -> str:
|
| 139 |
-
if score >= 0.9:
|
| 140 |
-
return "Clean code structure"
|
| 141 |
failing = [k for k, v in checks.items() if not v]
|
| 142 |
return f"Structure issues: {', '.join(failing)}"
|
|
|
|
| 1 |
+
"""SecureCodeEnv - Documentation & Structure Graders v4 β clamped scores"""
|
|
|
|
|
|
|
|
|
|
| 2 |
import ast
|
| 3 |
+
from graders.clamp import clamp
|
| 4 |
|
| 5 |
|
| 6 |
def grade_documentation(code: str) -> dict:
    """Score docstring and type-hint coverage over every function in *code*.

    Each function earns 0.5 for carrying a docstring and 0.5 for having any
    annotation (a return annotation or at least one annotated positional
    parameter).  The mean over all functions is clamped before being
    returned alongside per-category counts and a feedback string.
    """
    try:
        tree = ast.parse(code)
    except SyntaxError:
        return {"score": clamp(0.0), "documented_fns": 0, "total_fns": 0,
                "feedback": "Syntax error β cannot parse"}

    fns = [node for node in ast.walk(tree)
           if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))]
    if not fns:
        # Module-level-only code: slight bonus when a module docstring exists.
        has_module_doc = bool(ast.get_docstring(tree))
        return {
            "score": clamp(0.65 if has_module_doc else 0.5),
            "documented_fns": 0,
            "total_fns": 0,
            "feedback": "No functions β module-level code only",
        }

    documented = 0
    typed = 0
    per_fn = []
    for node in fns:
        earned = 0.0
        if ast.get_docstring(node):
            documented += 1
            earned += 0.5
        # Any annotation (return OR any positional arg) counts as "typed".
        annotated = node.returns is not None or any(
            arg.annotation is not None for arg in node.args.args)
        if annotated:
            typed += 1
            earned += 0.5
        per_fn.append(earned)

    total = len(fns)
    raw = sum(per_fn) / total
    return {
        "score": clamp(raw),
        "documented_fns": documented,
        "typed_fns": typed,
        "total_fns": total,
        "feedback": _doc_feedback(raw, documented, typed, total),
    }
| 39 |
def grade_code_structure(code: str) -> dict:
    """Evaluate structural hygiene of *code*.

    Five boolean checks (no bare print, no bare except, no oversized
    functions, no TODO/FIXME/HACK markers, some None handling) are run;
    the score is the clamped fraction that pass.
    """
    try:
        tree = ast.parse(code)
    except SyntaxError:
        return {"score": clamp(0.0), "checks": {}, "feedback": "Syntax error"}

    nodes = list(ast.walk(tree))
    source_lines = code.splitlines()

    bare_excepts = [n for n in nodes
                    if isinstance(n, ast.ExceptHandler) and n.type is None]
    # Functions longer than 50 lines are considered oversized.
    oversized = [n for n in nodes
                 if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))
                 and (n.end_lineno or 0) - n.lineno > 50]
    todo_lines = [ln for ln in source_lines
                  if any(marker in ln.upper()
                         for marker in ["# TODO", "# FIXME", "# HACK"])]
    none_tokens = ["None", "is not None", "if not ", "Optional", "is None"]

    checks = {
        "no_bare_print": "print(" not in code,
        "no_bare_except": not bare_excepts,
        "reasonable_fn_size": not oversized,
        "no_todo_comments": not todo_lines,
        "handles_none": any(tok in code for tok in none_tokens),
    }

    raw = sum(1 for passed in checks.values() if passed) / max(len(checks), 1)
    return {"score": clamp(raw), "checks": checks,
            "feedback": _struct_feedback(raw, checks)}
def _doc_feedback(score, documented, typed, total):
|
| 68 |
+
if score >= 0.85: return f"Well documented β {documented}/{total} docstrings, {typed}/{total} typed"
|
| 69 |
+
elif score >= 0.55: return f"Partial β {documented}/{total} docstrings, {typed}/{total} type hints"
|
| 70 |
+
return f"Poor β add docstrings and type hints to all {total} functions"
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _struct_feedback(score, checks):
|
| 74 |
+
if score >= 0.85: return "Clean code structure"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
failing = [k for k, v in checks.items() if not v]
|
| 76 |
return f"Structure issues: {', '.join(failing)}"
|
graders/performance.py
CHANGED
|
@@ -1,63 +1,93 @@
|
|
| 1 |
"""
|
| 2 |
-
SecureCodeEnv - Performance Grader
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
"""
|
| 5 |
import sys, tempfile, os, json, subprocess
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
def grade_performance(code: str, task: dict) -> dict:
|
| 9 |
-
test_cases
|
| 10 |
-
naive_code
|
| 11 |
optimal_code = task.get("optimal_code", "")
|
| 12 |
|
| 13 |
if not test_cases or not naive_code or not optimal_code:
|
| 14 |
-
return {"score":
|
| 15 |
-
"feedback": "No baselines
|
| 16 |
|
| 17 |
tc = next((t for t in test_cases
|
| 18 |
if "fn" in t and "input" in t
|
| 19 |
and "fn_class" not in t
|
| 20 |
and "expected_exception" not in t), None)
|
| 21 |
if not tc:
|
| 22 |
-
return {"score":
|
| 23 |
-
"feedback": "No
|
| 24 |
|
| 25 |
fn_name = tc["fn"]
|
| 26 |
-
inputs
|
| 27 |
|
| 28 |
try:
|
| 29 |
-
agent_ms = _measure_ms(code,
|
| 30 |
-
naive_ms = _measure_ms(naive_code,
|
| 31 |
optimal_ms = _measure_ms(optimal_code, fn_name, inputs)
|
| 32 |
|
| 33 |
-
#
|
| 34 |
-
if
|
| 35 |
-
return
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
-
time_range =
|
| 42 |
-
raw
|
| 43 |
-
|
| 44 |
-
|
| 45 |
|
| 46 |
return {
|
| 47 |
-
"score":
|
| 48 |
-
"time_score":
|
| 49 |
-
"memory_score":
|
| 50 |
-
"agent_ms":
|
| 51 |
-
"naive_ms":
|
| 52 |
"optimal_ms": round(optimal_ms, 3),
|
| 53 |
-
"feedback": _feedback(
|
| 54 |
}
|
| 55 |
except Exception as e:
|
| 56 |
-
return {"score":
|
|
|
|
| 57 |
"feedback": f"Measurement error: {str(e)[:60]}"}
|
| 58 |
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
def _measure_ms(code: str, fn_name: str, inputs: list, runs: int = 50) -> float:
|
|
|
|
| 61 |
script = f"""
|
| 62 |
import timeit, json, sys
|
| 63 |
{code}
|
|
@@ -79,7 +109,7 @@ sys.stdout.flush()
|
|
| 79 |
line = line.strip()
|
| 80 |
if line.startswith("{"):
|
| 81 |
return json.loads(line)["ms"]
|
| 82 |
-
return -1.0
|
| 83 |
except Exception:
|
| 84 |
return -1.0
|
| 85 |
finally:
|
|
@@ -89,7 +119,7 @@ sys.stdout.flush()
|
|
| 89 |
|
| 90 |
|
| 91 |
def _feedback(score: float) -> str:
|
| 92 |
-
if score >= 0.
|
| 93 |
-
elif score >= 0.
|
| 94 |
-
elif score >= 0.
|
| 95 |
-
else:
|
|
|
|
| 1 |
"""
|
| 2 |
+
SecureCodeEnv - Performance Grader v4
|
| 3 |
+
|
| 4 |
+
FIXES:
|
| 5 |
+
- Inverted baseline (naive faster than optimal) β return neutral 0.5
|
| 6 |
+
- Unmeasurable (-1.0) β return neutral 0.5
|
| 7 |
+
- Both timings identical β return neutral 0.5
|
| 8 |
+
- Agent faster than optimal β clamp to max 0.999 (not >1.0)
|
| 9 |
+
- All scores clamped to (0.001, 0.999)
|
| 10 |
"""
|
| 11 |
import sys, tempfile, os, json, subprocess
|
| 12 |
+
from graders.clamp import clamp
|
| 13 |
+
|
| 14 |
+
NEUTRAL = 0.5 # returned when measurement is unreliable
|
| 15 |
|
| 16 |
|
| 17 |
def grade_performance(code: str, task: dict) -> dict:
    """Score the runtime efficiency of *code* against naive/optimal baselines.

    Returns a neutral (0.5, clamped) result whenever timing cannot be
    trusted: missing baselines, no usable test case, an unmeasurable run,
    indistinguishable baseline timings, or an inverted baseline where the
    naive solution outruns the optimal one.
    """
    cases = task.get("test_cases", [])
    naive_src = task.get("naive_code", "")
    optimal_src = task.get("optimal_code", "")

    if not (cases and naive_src and optimal_src):
        return {"score": clamp(NEUTRAL), "time_score": clamp(NEUTRAL),
                "memory_score": clamp(NEUTRAL),
                "feedback": "No baselines β neutral score"}

    # Pick the first plain function case (no class wrapper, no expected raise).
    usable = next((c for c in cases
                   if "fn" in c and "input" in c
                   and "fn_class" not in c
                   and "expected_exception" not in c), None)
    if usable is None:
        return {"score": clamp(NEUTRAL), "time_score": clamp(NEUTRAL),
                "memory_score": clamp(NEUTRAL),
                "feedback": "No usable test case β neutral score"}

    fn_name = usable["fn"]
    fn_input = usable["input"]

    try:
        agent_ms = _measure_ms(code, fn_name, fn_input)
        naive_ms = _measure_ms(naive_src, fn_name, fn_input)
        optimal_ms = _measure_ms(optimal_src, fn_name, fn_input)

        # Any unmeasurable (-1.0) timing invalidates the comparison.
        if min(agent_ms, naive_ms, optimal_ms) < 0:
            return _neutral(agent_ms, naive_ms, optimal_ms, "Unmeasurable timing")
        # Baselines too close together to rank the agent between them.
        if abs(naive_ms - optimal_ms) < 0.05:
            return _neutral(agent_ms, naive_ms, optimal_ms, "Timings indistinguishable")
        # A "naive" baseline that beats the optimal one cannot anchor a score.
        if naive_ms < optimal_ms:
            return _neutral(agent_ms, naive_ms, optimal_ms,
                            "Baseline inverted (naive faster than optimal) β neutral")

        spread = naive_ms - optimal_ms
        # Linear interpolation: 1.0 at optimal speed, 0.0 at naive speed;
        # clamp() absorbs an agent that is faster than the optimal baseline.
        time_score = clamp(1.0 - ((agent_ms - optimal_ms) / spread))

        return {
            "score": time_score,
            "time_score": time_score,
            "memory_score": time_score,
            "agent_ms": round(agent_ms, 3),
            "naive_ms": round(naive_ms, 3),
            "optimal_ms": round(optimal_ms, 3),
            "feedback": _feedback(time_score),
        }
    except Exception as e:
        return {"score": clamp(NEUTRAL), "time_score": clamp(NEUTRAL),
                "memory_score": clamp(NEUTRAL),
                "feedback": f"Measurement error: {str(e)[:60]}"}
| 75 |
|
| 76 |
|
| 77 |
+
def _neutral(agent_ms, naive_ms, optimal_ms, reason: str) -> dict:
    """Build a neutral-score result; negative (unmeasured) timings become None."""
    def _ms(value):
        # -1.0 marks an unmeasurable run β report it as None.
        return round(value, 3) if value >= 0 else None

    neutral = clamp(NEUTRAL)
    return {
        "score": neutral,
        "time_score": neutral,
        "memory_score": neutral,
        "agent_ms": _ms(agent_ms),
        "naive_ms": _ms(naive_ms),
        "optimal_ms": _ms(optimal_ms),
        "feedback": reason,
    }
| 87 |
+
|
| 88 |
+
|
| 89 |
def _measure_ms(code: str, fn_name: str, inputs: list, runs: int = 50) -> float:
|
| 90 |
+
"""Returns ms or -1.0 if unmeasurable."""
|
| 91 |
script = f"""
|
| 92 |
import timeit, json, sys
|
| 93 |
{code}
|
|
|
|
| 109 |
line = line.strip()
|
| 110 |
if line.startswith("{"):
|
| 111 |
return json.loads(line)["ms"]
|
| 112 |
+
return -1.0
|
| 113 |
except Exception:
|
| 114 |
return -1.0
|
| 115 |
finally:
|
|
|
|
| 119 |
|
| 120 |
|
| 121 |
def _feedback(score: float) -> str:
|
| 122 |
+
if score >= 0.85: return "Excellent β near-optimal efficiency"
|
| 123 |
+
elif score >= 0.65: return "Good β minor optimisation possible"
|
| 124 |
+
elif score >= 0.45: return "Acceptable β room for improvement"
|
| 125 |
+
else: return "Poor β significant performance gap"
|
graders/reward_aggregator.py
CHANGED
|
@@ -1,58 +1,52 @@
|
|
| 1 |
"""
|
| 2 |
-
SecureCodeEnv - Reward Aggregator
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
3. DONE threshold raised to 0.92.
|
| 9 |
-
4. Security floor penalty: if attack_resist < 0.5 OR static_security < 0.5,
|
| 10 |
-
total reward is capped at 0.65 (cannot fool the system with correctness alone).
|
| 11 |
"""
|
| 12 |
-
from graders.correctness
|
| 13 |
-
from graders.attacks
|
| 14 |
from graders.static_analysis import grade_static_analysis
|
| 15 |
-
from graders.performance
|
| 16 |
-
from graders.consistency
|
| 17 |
-
from graders.documentation
|
| 18 |
-
from
|
| 19 |
-
from codegraph.
|
|
|
|
| 20 |
|
| 21 |
-
# REBALANCED weights β security dimensions raised
|
| 22 |
WEIGHTS = {
|
| 23 |
-
"correctness": 0.25,
|
| 24 |
-
"attack_resist": 0.25,
|
| 25 |
-
"static_security": 0.20,
|
| 26 |
-
"consistency": 0.10,
|
| 27 |
-
"performance": 0.08,
|
| 28 |
-
"documentation": 0.07,
|
| 29 |
-
"code_structure": 0.05,
|
| 30 |
}
|
| 31 |
assert abs(sum(WEIGHTS.values()) - 1.0) < 1e-9
|
| 32 |
|
| 33 |
-
# Security gate: these minimums must ALL be met for episode to be DONE
|
| 34 |
SECURITY_GATE = {
|
| 35 |
-
"attack_resist": 0.75,
|
| 36 |
-
"static_security": 0.70,
|
| 37 |
-
"correctness": 0.80,
|
| 38 |
}
|
| 39 |
-
DONE_THRESHOLD
|
| 40 |
-
|
| 41 |
-
# Security floor: if security is critically low, cap total reward
|
| 42 |
SECURITY_FLOOR_DIMS = ["attack_resist", "static_security"]
|
| 43 |
-
|
| 44 |
-
SECURITY_FLOOR_CAP = 0.65
|
| 45 |
|
| 46 |
|
| 47 |
def grade_submission(code, filename, task, graph, step, seed):
|
| 48 |
-
corr
|
| 49 |
-
atk
|
| 50 |
-
stat
|
| 51 |
-
perf
|
| 52 |
-
cons
|
| 53 |
-
doc
|
| 54 |
struct = grade_code_structure(code)
|
| 55 |
|
|
|
|
| 56 |
scores = {
|
| 57 |
"correctness": corr["score"],
|
| 58 |
"attack_resist": atk["score"],
|
|
@@ -63,21 +57,19 @@ def grade_submission(code, filename, task, graph, step, seed):
|
|
| 63 |
"code_structure": struct["score"],
|
| 64 |
}
|
| 65 |
|
| 66 |
-
|
|
|
|
| 67 |
|
| 68 |
-
#
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
)
|
| 73 |
-
if security_critical_fail:
|
| 74 |
-
raw_reward = min(raw_reward, SECURITY_FLOOR_CAP)
|
| 75 |
|
| 76 |
-
|
|
|
|
| 77 |
|
| 78 |
-
#
|
| 79 |
-
gate_passed
|
| 80 |
-
for dim, threshold in SECURITY_GATE.items())
|
| 81 |
done_eligible = total_reward >= DONE_THRESHOLD and gate_passed
|
| 82 |
|
| 83 |
feedback = {
|
|
@@ -94,45 +86,43 @@ def grade_submission(code, filename, task, graph, step, seed):
|
|
| 94 |
|
| 95 |
details = {
|
| 96 |
"correctness": {"passed": corr.get("passed"), "total": corr.get("total")},
|
| 97 |
-
"attacks":
|
| 98 |
-
|
| 99 |
-
"static":
|
| 100 |
-
|
| 101 |
-
|
| 102 |
"security_gate_passed": gate_passed,
|
| 103 |
-
"done_eligible":
|
| 104 |
}
|
| 105 |
|
| 106 |
return {
|
| 107 |
-
"scores":
|
| 108 |
-
"total_reward":
|
| 109 |
"done_eligible": done_eligible,
|
| 110 |
-
"feedback":
|
| 111 |
-
"details":
|
| 112 |
-
"agent_ms":
|
| 113 |
-
"naive_ms":
|
| 114 |
-
"optimal_ms":
|
| 115 |
-
"new_metadata":
|
| 116 |
}
|
| 117 |
|
| 118 |
|
| 119 |
-
def _gate_status(scores
|
| 120 |
-
failing = [f"{
|
| 121 |
-
for
|
| 122 |
-
|
| 123 |
-
return f"BLOCKED β security gate not met: {', '.join(failing)}"
|
| 124 |
|
| 125 |
|
| 126 |
def _summary(reward, scores, gate_passed):
|
| 127 |
if reward >= DONE_THRESHOLD and gate_passed:
|
| 128 |
-
return f"β
Excellent ({reward:.3f}) β
|
| 129 |
if not gate_passed:
|
| 130 |
-
|
| 131 |
-
return f"π {gate_msg} (reward: {reward:.3f})"
|
| 132 |
if reward >= 0.75:
|
| 133 |
weakest = min(scores, key=scores.get)
|
| 134 |
return f"π‘ Good ({reward:.3f}) β improve: {weakest} ({scores[weakest]:.2f})"
|
| 135 |
if reward >= 0.55:
|
| 136 |
weak = [k for k, v in scores.items() if v < 0.5]
|
| 137 |
return f"π Needs work ({reward:.3f}) β fix: {', '.join(weak[:3])}"
|
| 138 |
-
return f"π΄ Poor ({reward:.3f}) β major
|
|
|
|
| 1 |
"""
|
| 2 |
+
SecureCodeEnv - Reward Aggregator v4
|
| 3 |
+
|
| 4 |
+
ALL scores are epsilon-clamped to (0.001 β¦ 0.999).
|
| 5 |
+
Security gate: episode done only when attackβ₯0.75, staticβ₯0.70, correctnessβ₯0.80.
|
| 6 |
+
Security floor: if attack<0.5 OR static<0.5, total capped at 0.65.
|
| 7 |
+
DONE threshold: 0.92 (after clamping).
|
|
|
|
|
|
|
|
|
|
| 8 |
"""
|
| 9 |
+
from graders.correctness import grade_correctness
|
| 10 |
+
from graders.attacks import grade_attacks
|
| 11 |
from graders.static_analysis import grade_static_analysis
|
| 12 |
+
from graders.performance import grade_performance
|
| 13 |
+
from graders.consistency import grade_consistency
|
| 14 |
+
from graders.documentation import grade_documentation, grade_code_structure
|
| 15 |
+
from graders.clamp import clamp
|
| 16 |
+
from codegraph.extractor import extract_metadata
|
| 17 |
+
from codegraph.graph import CodeGraph
|
| 18 |
|
|
|
|
| 19 |
WEIGHTS = {
|
| 20 |
+
"correctness": 0.25,
|
| 21 |
+
"attack_resist": 0.25,
|
| 22 |
+
"static_security": 0.20,
|
| 23 |
+
"consistency": 0.10,
|
| 24 |
+
"performance": 0.08,
|
| 25 |
+
"documentation": 0.07,
|
| 26 |
+
"code_structure": 0.05,
|
| 27 |
}
|
| 28 |
assert abs(sum(WEIGHTS.values()) - 1.0) < 1e-9
|
| 29 |
|
|
|
|
| 30 |
SECURITY_GATE = {
|
| 31 |
+
"attack_resist": 0.75,
|
| 32 |
+
"static_security": 0.70,
|
| 33 |
+
"correctness": 0.80,
|
| 34 |
}
|
| 35 |
+
DONE_THRESHOLD = 0.92
|
|
|
|
|
|
|
| 36 |
SECURITY_FLOOR_DIMS = ["attack_resist", "static_security"]
|
| 37 |
+
SECURITY_FLOOR_CAP = 0.65 # unclamped; after clamping β β€ 0.649
|
|
|
|
| 38 |
|
| 39 |
|
| 40 |
def grade_submission(code, filename, task, graph, step, seed):
|
| 41 |
+
corr = grade_correctness(code, task)
|
| 42 |
+
atk = grade_attacks(code, task, seed)
|
| 43 |
+
stat = grade_static_analysis(code, task)
|
| 44 |
+
perf = grade_performance(code, task)
|
| 45 |
+
cons = grade_consistency(code, filename, graph, step)
|
| 46 |
+
doc = grade_documentation(code)
|
| 47 |
struct = grade_code_structure(code)
|
| 48 |
|
| 49 |
+
# All individual dimension scores are already clamped by each grader.
|
| 50 |
scores = {
|
| 51 |
"correctness": corr["score"],
|
| 52 |
"attack_resist": atk["score"],
|
|
|
|
| 57 |
"code_structure": struct["score"],
|
| 58 |
}
|
| 59 |
|
| 60 |
+
# Weighted sum
|
| 61 |
+
raw = sum(scores[k] * WEIGHTS[k] for k in WEIGHTS)
|
| 62 |
|
| 63 |
+
# Security floor
|
| 64 |
+
security_fail = any(scores[d] < 0.5 for d in SECURITY_FLOOR_DIMS)
|
| 65 |
+
if security_fail:
|
| 66 |
+
raw = min(raw, SECURITY_FLOOR_CAP)
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
+
# Final clamp β guarantees (0.001 β¦ 0.999)
|
| 69 |
+
total_reward = clamp(raw)
|
| 70 |
|
| 71 |
+
# Security gate for done
|
| 72 |
+
gate_passed = all(scores[d] >= thr for d, thr in SECURITY_GATE.items())
|
|
|
|
| 73 |
done_eligible = total_reward >= DONE_THRESHOLD and gate_passed
|
| 74 |
|
| 75 |
feedback = {
|
|
|
|
| 86 |
|
| 87 |
details = {
|
| 88 |
"correctness": {"passed": corr.get("passed"), "total": corr.get("total")},
|
| 89 |
+
"attacks": {"blocked": atk.get("blocked"), "total": atk.get("total"),
|
| 90 |
+
"type": atk.get("attack_type")},
|
| 91 |
+
"static": {"bandit_score": stat.get("bandit_score"),
|
| 92 |
+
"hard_fail": stat.get("hard_fail", False),
|
| 93 |
+
"issues": stat.get("issues", [])[:3]},
|
| 94 |
"security_gate_passed": gate_passed,
|
| 95 |
+
"done_eligible": done_eligible,
|
| 96 |
}
|
| 97 |
|
| 98 |
return {
|
| 99 |
+
"scores": scores,
|
| 100 |
+
"total_reward": total_reward,
|
| 101 |
"done_eligible": done_eligible,
|
| 102 |
+
"feedback": feedback,
|
| 103 |
+
"details": details,
|
| 104 |
+
"agent_ms": perf.get("agent_ms"),
|
| 105 |
+
"naive_ms": perf.get("naive_ms"),
|
| 106 |
+
"optimal_ms": perf.get("optimal_ms"),
|
| 107 |
+
"new_metadata": extract_metadata(code, filename, step),
|
| 108 |
}
|
| 109 |
|
| 110 |
|
| 111 |
+
def _gate_status(scores):
    """Describe which security-gate dimensions fall below their thresholds."""
    below = []
    for dim, threshold in SECURITY_GATE.items():
        if scores[dim] < threshold:
            below.append(f"{dim}({scores[dim]:.2f}<{threshold})")
    return f"BLOCKED β {', '.join(below)}"
|
|
|
| 115 |
|
| 116 |
|
| 117 |
def _summary(reward, scores, gate_passed):
    """One-line verdict combining total reward with the security gate."""
    if gate_passed and reward >= DONE_THRESHOLD:
        return f"β Excellent ({reward:.3f}) β security gate passed"
    if not gate_passed:
        # Gate failures override the reward-based bands.
        return f"π {_gate_status(scores)} (reward: {reward:.3f})"
    if reward >= 0.75:
        weakest = min(scores, key=scores.get)
        return f"π‘ Good ({reward:.3f}) β improve: {weakest} ({scores[weakest]:.2f})"
    if reward >= 0.55:
        weak = [k for k, v in scores.items() if v < 0.5]
        return f"π Needs work ({reward:.3f}) β fix: {', '.join(weak[:3])}"
    return f"π΄ Poor ({reward:.3f}) β major failures"
|
graders/static_analysis.py
CHANGED
|
@@ -1,33 +1,28 @@
|
|
| 1 |
"""
|
| 2 |
-
SecureCodeEnv - Static Analysis Grader
|
| 3 |
-
|
| 4 |
-
FIXED:
|
| 5 |
-
- HIGH severity issues now cap the score at 0.40 max (was just subtracting 0.30)
|
| 6 |
-
- Task-specific security checks have hard caps when violated
|
| 7 |
-
- bandit penalty curve is steeper
|
| 8 |
"""
|
| 9 |
-
import subprocess, json, tempfile, os,
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
def grade_static_analysis(code: str, task: dict) -> dict:
|
| 13 |
bandit = _run_bandit(code)
|
| 14 |
custom = _run_custom_checks(code, task)
|
| 15 |
|
| 16 |
-
# If a HARD security requirement is violated, cap at 0.40 regardless of bandit
|
| 17 |
if custom.get("hard_fail"):
|
| 18 |
-
|
| 19 |
else:
|
| 20 |
-
|
| 21 |
|
| 22 |
all_issues = bandit.get("issues", []) + custom.get("issues", [])
|
| 23 |
-
|
| 24 |
return {
|
| 25 |
-
"score":
|
| 26 |
-
"bandit_score": bandit["score"],
|
| 27 |
-
"ast_score": custom["score"],
|
| 28 |
"hard_fail": custom.get("hard_fail", False),
|
| 29 |
"issues": all_issues[:10],
|
| 30 |
-
"feedback": _feedback(
|
| 31 |
}
|
| 32 |
|
| 33 |
|
|
@@ -37,182 +32,129 @@ def _run_bandit(code: str) -> dict:
|
|
| 37 |
with tempfile.NamedTemporaryFile(mode="w", suffix=".py",
|
| 38 |
delete=False, prefix="sce_ban_") as f:
|
| 39 |
f.write(code); tmp = f.name
|
| 40 |
-
|
| 41 |
res = subprocess.run(
|
| 42 |
["bandit", "-r", tmp, "-f", "json", "-q", "--exit-zero"],
|
| 43 |
-
capture_output=True, text=True, timeout=15
|
| 44 |
-
)
|
| 45 |
-
data = json.loads(res.stdout or '{"results":[]}')
|
| 46 |
issues = data.get("results", [])
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
if sev == "HIGH": penalty += 0.40
|
| 53 |
-
elif sev == "MEDIUM": penalty += 0.20
|
| 54 |
-
else: penalty += 0.05
|
| 55 |
-
|
| 56 |
-
score = max(0.0, 1.0 - penalty)
|
| 57 |
return {
|
| 58 |
-
"score":
|
| 59 |
"issues": [{"severity": i.get("issue_severity"),
|
| 60 |
"text": i.get("issue_text", "")[:100],
|
| 61 |
"line": i.get("line_number")} for i in issues[:5]],
|
| 62 |
}
|
| 63 |
except FileNotFoundError:
|
| 64 |
-
return {"score":
|
| 65 |
except Exception as e:
|
| 66 |
-
return {"score":
|
| 67 |
finally:
|
| 68 |
if tmp and os.path.exists(tmp):
|
| 69 |
try: os.unlink(tmp)
|
| 70 |
except OSError: pass
|
| 71 |
|
| 72 |
|
| 73 |
-
# Security requirements per task β violation = hard_fail (cap at 0.40)
|
| 74 |
HARD_REQUIREMENTS = {
|
| 75 |
"easy_password_validator": [
|
| 76 |
-
{"type": "forbidden_any",
|
| 77 |
-
|
| 78 |
-
"message": "Weak hash
|
| 79 |
-
{"type": "
|
| 80 |
-
"message": "
|
| 81 |
],
|
| 82 |
"easy_token_generator": [
|
| 83 |
-
{"type": "forbidden_any",
|
| 84 |
-
|
| 85 |
-
"message": "Weak PRNG
|
| 86 |
{"type": "required_any", "patterns": ["secrets"],
|
| 87 |
"message": "secrets module not imported"},
|
| 88 |
],
|
| 89 |
"medium_sql_query_builder": [
|
| 90 |
-
{"type": "forbidden_pattern",
|
| 91 |
-
"
|
| 92 |
-
"message": "f-string SQL query β SQL injection vulnerability"},
|
| 93 |
{"type": "forbidden_pattern",
|
| 94 |
"regex": r'["\']\s*\+\s*(username|role|user_input|query)',
|
| 95 |
-
"message": "String
|
| 96 |
-
{"type": "forbidden_pattern",
|
| 97 |
-
"regex": r'%\s*\(',
|
| 98 |
-
"message": "% formatting in SQL β potential injection"},
|
| 99 |
],
|
| 100 |
"medium_file_path_handler": [
|
| 101 |
-
{"type": "
|
| 102 |
-
"
|
| 103 |
-
"message": "String check for '..' is bypassable β use Path.resolve()"},
|
| 104 |
-
{"type": "required_any",
|
| 105 |
-
"patterns": ["resolve()", ".resolve", "realpath"],
|
| 106 |
-
"message": "Must use os.path.realpath or Path.resolve() to prevent traversal"},
|
| 107 |
],
|
| 108 |
"hard_jwt_validator": [
|
| 109 |
{"type": "forbidden_any",
|
| 110 |
-
"patterns": ["verify_signature\": False",
|
| 111 |
-
"verify_exp\": False",
|
| 112 |
-
"message": "JWT verification disabled
|
| 113 |
-
{"type": "required_any",
|
| 114 |
-
"
|
| 115 |
-
"message": "algorithms= not specified β alg:none attack possible"},
|
| 116 |
],
|
| 117 |
"hard_auth_middleware": [
|
| 118 |
-
{"type": "required_any",
|
| 119 |
-
"
|
| 120 |
-
"message": "hmac.compare_digest not used β timing attack possible"},
|
| 121 |
-
{"type": "forbidden_pattern",
|
| 122 |
-
"regex": r'==\s*(session_token|request_token|secret|token)',
|
| 123 |
-
"message": "== used for secret comparison β timing attack"},
|
| 124 |
],
|
| 125 |
"easy_input_sanitizer": [
|
| 126 |
-
{"type": "forbidden_any",
|
| 127 |
-
"
|
| 128 |
-
"message": "eval/exec used β remote code execution risk"},
|
| 129 |
],
|
| 130 |
"hard_file_upload_handler": [
|
| 131 |
-
{"type": "required_any",
|
| 132 |
-
"
|
| 133 |
-
"message": "uuid not used β filename collisions possible"},
|
| 134 |
-
{"type": "required_any",
|
| 135 |
-
"patterns": [".suffix", "splitext", "os.path.splitext"],
|
| 136 |
-
"message": "Extension not checked via proper method"},
|
| 137 |
],
|
| 138 |
}
|
| 139 |
|
| 140 |
-
# Soft checks (no hard cap) β applied to all tasks
|
| 141 |
GENERIC_CHECKS = [
|
| 142 |
-
{"type": "forbidden_any", "patterns": ["eval(",
|
| 143 |
"message": "eval/exec usage", "severity": "HIGH"},
|
| 144 |
{"type": "forbidden_any", "patterns": ["shell=True"],
|
| 145 |
-
"message": "shell=True
|
| 146 |
-
{"type": "forbidden_any", "patterns": ["pickle.loads",
|
| 147 |
-
"message": "Unsafe pickle
|
| 148 |
-
{"type": "forbidden_any", "patterns": ["
|
| 149 |
-
"message": "
|
| 150 |
-
{"type": "forbidden_any", "patterns": ["hashlib.md5", "hashlib.sha1"],
|
| 151 |
-
"message": "Weak hash function", "severity": "HIGH"},
|
| 152 |
]
|
| 153 |
|
| 154 |
|
| 155 |
def _run_custom_checks(code: str, task: dict) -> dict:
|
| 156 |
-
issues = []
|
| 157 |
-
hard_fail = False
|
| 158 |
-
checks_passed = 0
|
| 159 |
-
total_checks = 0
|
| 160 |
|
| 161 |
-
# Generic checks
|
| 162 |
for chk in GENERIC_CHECKS:
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
if found:
|
| 166 |
issues.append({"check": chk["message"], "severity": chk.get("severity","MEDIUM"),
|
| 167 |
"message": chk["message"]})
|
| 168 |
else:
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
total_checks += 1
|
| 175 |
-
violated = _check_requirement_violated(code, req)
|
| 176 |
-
if violated:
|
| 177 |
hard_fail = True
|
| 178 |
issues.append({"check": req["message"], "severity": "CRITICAL",
|
| 179 |
"message": req["message"]})
|
| 180 |
else:
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
score = checks_passed / max(total_checks, 1)
|
| 184 |
-
return {"score": round(score, 4), "issues": issues, "hard_fail": hard_fail}
|
| 185 |
-
|
| 186 |
|
| 187 |
-
|
| 188 |
-
"""Returns True if the violation is found."""
|
| 189 |
-
t = chk.get("type", "")
|
| 190 |
-
if t == "forbidden_any":
|
| 191 |
-
return any(p in code for p in chk.get("patterns", []))
|
| 192 |
-
if t == "required_any":
|
| 193 |
-
return not any(p in code for p in chk.get("patterns", []))
|
| 194 |
-
if t == "forbidden_pattern":
|
| 195 |
-
return bool(re.search(chk.get("regex", "NOMATCH"), code, re.IGNORECASE))
|
| 196 |
-
return False
|
| 197 |
|
| 198 |
|
| 199 |
-
def
|
| 200 |
-
"""Returns True if requirement is violated (= bad)."""
|
| 201 |
t = req.get("type", "")
|
| 202 |
if t == "forbidden_any":
|
| 203 |
return any(p in code for p in req.get("patterns", []))
|
| 204 |
if t == "required_any":
|
| 205 |
return not any(p in code for p in req.get("patterns", []))
|
| 206 |
if t == "forbidden_pattern":
|
| 207 |
-
return bool(re.search(req.get("regex",
|
| 208 |
return False
|
| 209 |
|
| 210 |
|
| 211 |
def _feedback(score: float, issues: list, hard_fail: bool) -> str:
|
| 212 |
if hard_fail:
|
| 213 |
-
|
| 214 |
-
return f"CRITICAL security violation: {'; '.join(critical[:2])}"
|
| 215 |
if score >= 0.9: return "Clean β no significant security issues"
|
| 216 |
high = sum(1 for i in issues if i.get("severity") == "HIGH")
|
| 217 |
-
|
| 218 |
-
return f"Some security issues found (score: {score:.2f})"
|
|
|
|
| 1 |
"""
|
| 2 |
+
SecureCodeEnv - Static Analysis Grader v4
|
| 3 |
+
All scores clamped to (0.001, 0.999).
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
"""
|
| 5 |
+
import subprocess, json, tempfile, os, re
|
| 6 |
+
from graders.clamp import clamp
|
| 7 |
|
| 8 |
|
| 9 |
def grade_static_analysis(code: str, task: dict) -> dict:
    """Blend bandit findings with custom pattern checks into one clamped score.

    A hard (CRITICAL) violation caps the result well below passing;
    otherwise bandit and the custom checks are weighted 60/40.
    Returns a dict with the combined score, sub-scores, issue list and
    a one-line feedback summary.
    """
    bandit_res = _run_bandit(code)
    custom_res = _run_custom_checks(code, task)
    hard_flag = custom_res.get("hard_fail", False)

    if hard_flag:
        # Critical violation: score is at most 40% of the weighted bandit score.
        combined = min(bandit_res["score"] * 0.4, 0.40)
    else:
        combined = bandit_res["score"] * 0.60 + custom_res["score"] * 0.40

    merged_issues = bandit_res.get("issues", []) + custom_res.get("issues", [])
    return {
        "score": clamp(combined),
        "bandit_score": clamp(bandit_res["score"]),
        "ast_score": clamp(custom_res["score"]),
        "hard_fail": hard_flag,
        "issues": merged_issues[:10],
        "feedback": _feedback(combined, merged_issues, hard_flag),
    }
|
| 27 |
|
| 28 |
|
|
|
|
| 32 |
with tempfile.NamedTemporaryFile(mode="w", suffix=".py",
|
| 33 |
delete=False, prefix="sce_ban_") as f:
|
| 34 |
f.write(code); tmp = f.name
|
|
|
|
| 35 |
res = subprocess.run(
|
| 36 |
["bandit", "-r", tmp, "-f", "json", "-q", "--exit-zero"],
|
| 37 |
+
capture_output=True, text=True, timeout=15)
|
| 38 |
+
data = json.loads(res.stdout or '{"results":[]}')
|
|
|
|
| 39 |
issues = data.get("results", [])
|
| 40 |
+
penalty = sum(
|
| 41 |
+
0.40 if i.get("issue_severity") == "HIGH"
|
| 42 |
+
else 0.20 if i.get("issue_severity") == "MEDIUM"
|
| 43 |
+
else 0.05
|
| 44 |
+
for i in issues)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
return {
|
| 46 |
+
"score": max(0.0, 1.0 - penalty),
|
| 47 |
"issues": [{"severity": i.get("issue_severity"),
|
| 48 |
"text": i.get("issue_text", "")[:100],
|
| 49 |
"line": i.get("line_number")} for i in issues[:5]],
|
| 50 |
}
|
| 51 |
except FileNotFoundError:
|
| 52 |
+
return {"score": 0.75, "issues": [], "note": "bandit not installed"}
|
| 53 |
except Exception as e:
|
| 54 |
+
return {"score": 0.75, "issues": [], "note": str(e)[:40]}
|
| 55 |
finally:
|
| 56 |
if tmp and os.path.exists(tmp):
|
| 57 |
try: os.unlink(tmp)
|
| 58 |
except OSError: pass
|
| 59 |
|
| 60 |
|
|
|
|
| 61 |
HARD_REQUIREMENTS = {
|
| 62 |
"easy_password_validator": [
|
| 63 |
+
{"type": "forbidden_any",
|
| 64 |
+
"patterns": ["hashlib.md5","hashlib.sha1","hashlib.sha256","md5(","sha1(","sha256("],
|
| 65 |
+
"message": "Weak hash β must use bcrypt"},
|
| 66 |
+
{"type": "required_any", "patterns": ["bcrypt"],
|
| 67 |
+
"message": "bcrypt not imported"},
|
| 68 |
],
|
| 69 |
"easy_token_generator": [
|
| 70 |
+
{"type": "forbidden_any",
|
| 71 |
+
"patterns": ["random.random(","random.randint(","random.choice(","random.seed("],
|
| 72 |
+
"message": "Weak PRNG β must use secrets"},
|
| 73 |
{"type": "required_any", "patterns": ["secrets"],
|
| 74 |
"message": "secrets module not imported"},
|
| 75 |
],
|
| 76 |
"medium_sql_query_builder": [
|
| 77 |
+
{"type": "forbidden_pattern", "regex": r'f["\'].*SELECT.*{',
|
| 78 |
+
"message": "f-string SQL β injection vulnerability"},
|
|
|
|
| 79 |
{"type": "forbidden_pattern",
|
| 80 |
"regex": r'["\']\s*\+\s*(username|role|user_input|query)',
|
| 81 |
+
"message": "String concat SQL β injection vulnerability"},
|
|
|
|
|
|
|
|
|
|
| 82 |
],
|
| 83 |
"medium_file_path_handler": [
|
| 84 |
+
{"type": "required_any", "patterns": ["resolve()","resolve(","realpath"],
|
| 85 |
+
"message": "Must use Path.resolve() or realpath"},
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
],
|
| 87 |
"hard_jwt_validator": [
|
| 88 |
{"type": "forbidden_any",
|
| 89 |
+
"patterns": ["verify_signature\": False","verify_signature':False",
|
| 90 |
+
"verify_exp\": False","algorithms=[\"none\"","algorithms=['none'"],
|
| 91 |
+
"message": "JWT verification disabled"},
|
| 92 |
+
{"type": "required_any", "patterns": ["algorithms="],
|
| 93 |
+
"message": "algorithms= not specified"},
|
|
|
|
| 94 |
],
|
| 95 |
"hard_auth_middleware": [
|
| 96 |
+
{"type": "required_any", "patterns": ["hmac.compare_digest"],
|
| 97 |
+
"message": "hmac.compare_digest not used"},
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
],
|
| 99 |
"easy_input_sanitizer": [
|
| 100 |
+
{"type": "forbidden_any", "patterns": ["eval(","exec("],
|
| 101 |
+
"message": "eval/exec usage"},
|
|
|
|
| 102 |
],
|
| 103 |
"hard_file_upload_handler": [
|
| 104 |
+
{"type": "required_any", "patterns": ["uuid"],
|
| 105 |
+
"message": "uuid not used"},
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
],
|
| 107 |
}
|
| 108 |
|
|
|
|
| 109 |
# File-wide checks applied to every submission regardless of which task
# it belongs to.  Each entry is interpreted by _violated().
# NOTE(review): matching is plain substring search, so "exec(" would also
# flag an identifier such as "safe_exec(" — confirm false positives are
# acceptable for these generic (non-CRITICAL) checks.
GENERIC_CHECKS = [
    {"type": "forbidden_any", "patterns": ["eval(","exec("],
     "message": "eval/exec usage", "severity": "HIGH"},
    {"type": "forbidden_any", "patterns": ["shell=True"],
     "message": "shell=True", "severity": "HIGH"},
    {"type": "forbidden_any", "patterns": ["pickle.loads","pickle.load"],
     "message": "Unsafe pickle", "severity": "HIGH"},
    {"type": "forbidden_any", "patterns": ["hashlib.md5","hashlib.sha1"],
     "message": "Weak hash", "severity": "HIGH"},
]
|
| 119 |
|
| 120 |
|
| 121 |
def _run_custom_checks(code: str, task: dict) -> dict:
    """Apply generic plus per-task pattern checks to *code*.

    Returns a dict with the pass-ratio score, the collected issues and a
    hard_fail flag that is set when any per-task requirement is violated.
    """
    found_issues = []
    critical_hit = False
    ok_count = 0
    check_count = 0

    # Generic checks run for every task; failures are non-critical.
    for check in GENERIC_CHECKS:
        check_count += 1
        if _violated(code, check):
            found_issues.append({
                "check": check["message"],
                "severity": check.get("severity", "MEDIUM"),
                "message": check["message"],
            })
        else:
            ok_count += 1

    # Per-task hard requirements; any violation is CRITICAL.
    for requirement in HARD_REQUIREMENTS.get(task.get("id", ""), []):
        check_count += 1
        if _violated(code, requirement):
            critical_hit = True
            found_issues.append({
                "check": requirement["message"],
                "severity": "CRITICAL",
                "message": requirement["message"],
            })
        else:
            ok_count += 1

    return {
        "score": ok_count / max(check_count, 1),
        "issues": found_issues,
        "hard_fail": critical_hit,
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
|
| 143 |
|
| 144 |
+
def _violated(code: str, req: dict) -> bool:
|
|
|
|
| 145 |
t = req.get("type", "")
|
| 146 |
if t == "forbidden_any":
|
| 147 |
return any(p in code for p in req.get("patterns", []))
|
| 148 |
if t == "required_any":
|
| 149 |
return not any(p in code for p in req.get("patterns", []))
|
| 150 |
if t == "forbidden_pattern":
|
| 151 |
+
return bool(re.search(req.get("regex","NOMATCH"), code, re.IGNORECASE | re.DOTALL))
|
| 152 |
return False
|
| 153 |
|
| 154 |
|
| 155 |
def _feedback(score: float, issues: list, hard_fail: bool) -> str:
|
| 156 |
if hard_fail:
|
| 157 |
+
return f"CRITICAL: {'; '.join(i['message'] for i in issues if i.get('severity')=='CRITICAL')[:120]}"
|
|
|
|
| 158 |
if score >= 0.9: return "Clean β no significant security issues"
|
| 159 |
high = sum(1 for i in issues if i.get("severity") == "HIGH")
|
| 160 |
+
return f"{high} HIGH severity issue(s)" if high else f"Some issues (score: {score:.2f})"
|
|
|
inference.py
CHANGED
|
@@ -13,92 +13,140 @@ from typing import Dict, List, Any
|
|
| 13 |
|
| 14 |
# ββ Configuration ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 15 |
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
|
| 16 |
-
MODEL_NAME
|
| 17 |
-
HF_TOKEN
|
| 18 |
-
ENV_URL
|
| 19 |
|
| 20 |
client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN or "sk-placeholder")
|
| 21 |
|
|
|
|
| 22 |
def clamp_score(score: float) -> float:
|
| 23 |
-
"""
|
| 24 |
-
Ensures score is strictly between 0 and 1 (e.g., 0.001 to 0.999).
|
| 25 |
-
Required by validator range constraints.
|
| 26 |
-
"""
|
| 27 |
epsilon = 0.001
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
def clean_code(raw: str) -> str:
|
| 31 |
"""Removes markdown code fences safely."""
|
| 32 |
-
lines = [line for line in raw.splitlines()
|
|
|
|
| 33 |
return "\n".join(lines).strip()
|
| 34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
def run_episode(difficulty: str) -> None:
|
| 36 |
-
"""Runs episode and prints
|
| 37 |
try:
|
| 38 |
-
r = requests.post(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
r.raise_for_status()
|
| 40 |
data = r.json()
|
| 41 |
except Exception as e:
|
| 42 |
print(f"Failed to reset {difficulty}: {e}", file=sys.stderr)
|
| 43 |
return
|
| 44 |
|
| 45 |
-
sid
|
| 46 |
-
tid
|
| 47 |
-
|
| 48 |
-
# [START] block
|
| 49 |
print(f"[START] task={tid} difficulty={difficulty}", flush=True)
|
| 50 |
|
| 51 |
-
final_score = 0.0
|
| 52 |
total_steps = 0
|
| 53 |
|
| 54 |
for i in range(1, 6):
|
| 55 |
total_steps = i
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
try:
|
| 59 |
resp = client.chat.completions.create(
|
| 60 |
model=MODEL_NAME,
|
| 61 |
-
messages=[
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
)
|
| 64 |
code = clean_code(resp.choices[0].message.content or "")
|
| 65 |
-
|
|
|
|
|
|
|
| 66 |
step_r = requests.post(
|
| 67 |
f"{ENV_URL}/step",
|
| 68 |
-
json={
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
)
|
| 71 |
step_r.raise_for_status()
|
| 72 |
res = step_r.json()
|
| 73 |
-
|
| 74 |
-
raw_reward
|
| 75 |
-
|
| 76 |
-
final_score
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
print(f"[STEP] step={i} reward={clamped_reward:.3f}", flush=True)
|
| 80 |
|
| 81 |
if res.get("done"):
|
| 82 |
break
|
| 83 |
-
|
| 84 |
-
|
|
|
|
|
|
|
|
|
|
| 85 |
except Exception as e:
|
| 86 |
print(f"Error in step {i}: {e}", file=sys.stderr)
|
| 87 |
-
break
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
-
# [END] block with clamped final score
|
| 90 |
-
print(f"[END] task={tid} score={final_score:.3f} steps={total_steps}", flush=True)
|
| 91 |
|
| 92 |
def main():
|
|
|
|
| 93 |
try:
|
| 94 |
-
requests.get(f"{ENV_URL}/health", timeout=
|
|
|
|
| 95 |
except Exception as e:
|
| 96 |
print(f"Health check failed: {e}", file=sys.stderr)
|
| 97 |
sys.exit(1)
|
| 98 |
|
| 99 |
for diff in ["easy", "medium", "hard"]:
|
| 100 |
run_episode(diff)
|
| 101 |
-
time.sleep(
|
|
|
|
| 102 |
|
| 103 |
if __name__ == "__main__":
|
| 104 |
main()
|
|
|
|
| 13 |
|
| 14 |
# ββ Configuration ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# All settings come from environment variables, with local-dev defaults.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")  # OpenAI-compatible endpoint
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")  # chat model used for code generation
HF_TOKEN = os.environ.get("HF_TOKEN", "")  # API key; empty means unauthenticated local use
ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860").rstrip("/")  # trailing slash stripped for clean URL joins

# Placeholder key keeps the client constructible even when no token is set;
# real requests will simply fail with an auth error in that case.
client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN or "sk-placeholder")
|
| 21 |
|
| 22 |
+
|
| 23 |
def clamp_score(score: float) -> float:
    """Clamp *score* into the open interval (0, 1), i.e. [0.001, 0.999].

    Non-numeric input and NaN both map to a neutral 0.5, satisfying the
    validator's requirement that rewards stay strictly between 0 and 1.
    """
    eps = 0.001
    try:
        value = float(score)
    except (TypeError, ValueError):
        return 0.5
    if value != value:  # NaN never equals itself
        return 0.5
    if value < eps:
        return eps
    if value > 1.0 - eps:
        return 1.0 - eps
    return value
|
| 33 |
+
|
| 34 |
|
| 35 |
def clean_code(raw: str) -> str:
    """Strip markdown code-fence lines, returning the remaining text trimmed."""
    kept = []
    for line in raw.splitlines():
        # Drop any fence line (``` possibly followed by a language tag).
        if line.strip().startswith("```"):
            continue
        kept.append(line)
    return "\n".join(kept).strip()
|
| 40 |
|
| 41 |
+
|
| 42 |
+
SYSTEM_PROMPT = """You are a senior Python security engineer.
|
| 43 |
+
Output ONLY raw Python code β no markdown, no explanations.
|
| 44 |
+
Your code must:
|
| 45 |
+
1. Solve the problem correctly
|
| 46 |
+
2. Resist SQL injection, path traversal, and auth bypass attacks
|
| 47 |
+
3. Use parameterized queries β never f-string SQL
|
| 48 |
+
4. Use secrets module (not random) for tokens
|
| 49 |
+
5. Use bcrypt (not hashlib) for passwords
|
| 50 |
+
6. Use hmac.compare_digest for secret comparison
|
| 51 |
+
7. Have type hints and docstrings on every function"""
|
| 52 |
+
|
| 53 |
+
|
| 54 |
def run_episode(difficulty: str) -> None:
    """Run one full episode against the environment server.

    Protocol: POST /reset obtains a session + task, then up to five
    POST /step rounds each submit freshly generated code.  Emits
    [START], [STEP] and [END] lines that the outer harness parses.
    Errors are printed to stderr; the function never raises.
    """
    try:
        # Reset the environment; the response carries session_id, task_id,
        # the problem statement and the initial codebase context.
        r = requests.post(
            f"{ENV_URL}/reset",
            json={"difficulty": difficulty},
            timeout=30,
        )
        r.raise_for_status()
        data = r.json()
    except Exception as e:
        print(f"Failed to reset {difficulty}: {e}", file=sys.stderr)
        return

    sid = data["session_id"]
    tid = data["task_id"]
    print(f"[START] task={tid} difficulty={difficulty}", flush=True)

    final_score = clamp_score(0.0)  # starts at epsilon, not 0.0
    total_steps = 0

    for i in range(1, 6):
        total_steps = i
        # Serialize and cap the context so the prompt stays within budget.
        context_str = json.dumps(data.get("codegraph", {}))[:2000]
        prev_fb = data.get("last_feedback", "")

        user_msg = (
            f"Task: {data['problem_statement']}\n\n"
            f"Security targets: {data.get('cwe_targets', [])}\n\n"
            f"Codebase context:\n{context_str}"
        )
        if prev_fb:
            # Include the grader's previous feedback so the model can iterate.
            user_msg += f"\n\nPrevious feedback:\n{prev_fb}"
        user_msg += "\n\nWrite the complete Python implementation now:"

        try:
            resp = client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": user_msg},
                ],
                max_tokens=1500,
                temperature=0.1,
            )
            code = clean_code(resp.choices[0].message.content or "")
            if not code.strip():
                # The grader requires non-empty code; send a stub rather than fail.
                code = "def placeholder(): pass"

            step_r = requests.post(
                f"{ENV_URL}/step",
                json={
                    "session_id": sid,
                    "code": code,
                    "filename": f"step_{i}.py",
                    "task_id": tid,
                },
                timeout=65,
            )
            step_r.raise_for_status()
            res = step_r.json()

            raw_reward = res.get("total_reward", 0.0)
            clamped = clamp_score(raw_reward)
            final_score = clamped  # episode score = last step's clamped reward

            print(f"[STEP] step={i} reward={clamped:.4f}", flush=True)

            if res.get("done"):
                break

            # Feed updated context back for next step
            data["codegraph"] = res.get("codegraph", {})
            data["last_feedback"] = res.get("feedback", {}).get("summary", "")

        except Exception as e:
            print(f"Error in step {i}: {e}", file=sys.stderr)
            # Don't break β try remaining steps
            time.sleep(1)

    print(f"[END] task={tid} score={final_score:.4f} steps={total_steps}", flush=True)
|
| 135 |
|
|
|
|
|
|
|
| 136 |
|
| 137 |
def main():
    """Verify the environment is reachable, then run one episode per tier."""
    try:
        health = requests.get(f"{ENV_URL}/health", timeout=10)
        health.raise_for_status()
        print(f"Environment healthy: {ENV_URL}", file=sys.stderr)
    except Exception as e:
        print(f"Health check failed: {e}", file=sys.stderr)
        sys.exit(1)

    for diff in ["easy", "medium", "hard"]:
        run_episode(diff)
        # Brief pause between episodes so the server isn't hammered.
        time.sleep(2)


if __name__ == "__main__":
    main()
|