vishaldhakad committed on
Commit
791664b
·
1 Parent(s): ca9160b

Change the reward system to strictly clamp all scores into the open interval (0, 1)

Browse files
graders/attacks.py CHANGED
@@ -1,14 +1,10 @@
1
  """
2
- SecureCodeEnv - Attack Resistance Grader v3
3
-
4
- FIXED:
5
- - All check lambdas handle list OR tuple (JSON round-trip)
6
- - Path traversal check: blocks absolute paths + encoded paths, not just '..'
7
- - SQL check: verifies placeholder present AND payload not injected
8
- - Consistent scoring regardless of seed
9
  """
10
  from sandbox.executor import safe_exec
11
  from sandbox.payload_gen import get_payloads
 
12
 
13
 
14
  def _is_seq(v):
@@ -16,37 +12,25 @@ def _is_seq(v):
16
 
17
 
18
  def _sql_safe(result, payload):
19
- """SQL query output is safe when: is 2-element seq, query has placeholder, raw payload not in query."""
20
  if not _is_seq(result) or len(result) < 2:
21
  return False
22
  query = str(result[0])
23
- # Must use parameterized placeholder
24
  if not any(p in query for p in ("%s", "?", ":param", "%(", ":%")):
25
  return False
26
- # The raw payload (stripped of whitespace) must NOT appear verbatim in the query
27
  payload_stripped = payload.strip().replace(" ", "").lower()
28
  query_lower = query.lower().replace(" ", "")
29
- if payload_stripped in query_lower:
30
- return False
31
- return True
32
 
33
 
34
  def _path_safe(result, payload):
35
- """
36
- Path output is safe when resolved path stays inside /var/uploads.
37
- Also catches: absolute paths (/etc/), encoded paths (%2e), null bytes.
38
- """
39
  if result is None:
40
- return True # Rejected = safe
41
  if not isinstance(result, str):
42
  return True
43
  r = result.lower()
44
- # Must stay inside base
45
  if not result.startswith("/var/uploads"):
46
  return False
47
- # Must not contain escape indicators
48
- bad = ["/etc/", "/root/", "/proc/", "/var/log", "passwd", "shadow",
49
- "%2e", "%2f", ".."]
50
  return not any(b in r for b in bad)
51
 
52
 
@@ -56,14 +40,14 @@ ATTACK_CONFIGS = {
56
  "fn_args_template": ["{payload}", "admin"],
57
  "attack_type": "injection",
58
  "check": _sql_safe,
59
- "exception_is_pass": False, # SQL builder shouldn't raise on injection input
60
  },
61
  "medium_file_path_handler": {
62
  "fn_name": "resolve_safe_path",
63
  "fn_args_template": ["/var/uploads", "{payload}"],
64
  "attack_type": "traversal",
65
  "check": _path_safe,
66
- "exception_is_pass": True, # Raising ValueError = correct
67
  },
68
  "hard_file_upload_handler": {
69
  "fn_name": "validate_upload",
@@ -97,18 +81,8 @@ ATTACK_CONFIGS = {
97
  "attack_type": "xss",
98
  "check": lambda result, payload: (
99
  isinstance(result, str) and
100
- not any(b in result.lower() for b in
101
- ["<script", "onerror", "javascript:", "<img ", "<svg", "onload"])
102
- ),
103
- "exception_is_pass": False,
104
- },
105
- "easy_input_sanitizer_filename": {
106
- "fn_name": "sanitize_filename",
107
- "fn_args_template": ["{payload}"],
108
- "attack_type": "traversal",
109
- "check": lambda result, payload: (
110
- isinstance(result, str) and
111
- ".." not in result and "/" not in result and "\\" not in result
112
  ),
113
  "exception_is_pass": False,
114
  },
@@ -117,56 +91,51 @@ ATTACK_CONFIGS = {
117
 
118
  def grade_attacks(code: str, task: dict, seed: int) -> dict:
119
  task_id = task.get("id", "")
120
- config = ATTACK_CONFIGS.get(task_id)
121
  if not config:
122
- return {
123
- "score": 1.0, "blocked": 0, "total": 0,
124
- "attack_type": "none",
125
- "feedback": "No attack grading defined for this task type",
126
- }
127
 
128
  payloads = get_payloads(config["attack_type"], seed=seed, count=8)
129
  if not payloads:
130
- return {"score": 1.0, "blocked": 0, "total": 0,
131
- "attack_type": config["attack_type"], "feedback": "No payloads generated"}
 
132
 
133
- blocked = 0
134
  exception_is_pass = config.get("exception_is_pass", True)
135
 
136
  for payload in payloads:
137
- raw_args = config["fn_args_template"]
138
- args = [a.replace("{payload}", payload) if isinstance(a, str) else a
139
- for a in raw_args]
140
  result = safe_exec(code, args, function_name=config["fn_name"], timeout=3)
141
 
142
  if not result["ok"]:
143
  if exception_is_pass:
144
  blocked += 1
145
  else:
146
- output = result.get("output")
147
  try:
148
- if config["check"](output, payload):
149
  blocked += 1
150
  except Exception:
151
  pass
152
 
153
- score = blocked / len(payloads)
154
  return {
155
- "score": round(score, 4),
156
  "blocked": blocked,
157
  "total": len(payloads),
158
  "attack_type": config["attack_type"],
159
- "feedback": _feedback(score, config["attack_type"]),
160
  }
161
 
162
 
163
  def _feedback(score: float, attack_type: str) -> str:
164
- names = {
165
- "injection": "SQL injection", "traversal": "path traversal",
166
- "auth_bypass": "authentication bypass", "xss": "XSS",
167
- }
168
  name = names.get(attack_type, attack_type)
169
  if score >= 0.875: return f"Excellent β€” {name} attacks blocked ({score:.0%})"
170
  elif score >= 0.625: return f"Good β€” most {name} attacks blocked ({score:.0%})"
171
- elif score >= 0.375: return f"Partial β€” only {score:.0%} of {name} attacks blocked"
172
- else: return f"Vulnerable β€” {score:.0%} of {name} attacks blocked β€” CRITICAL"
 
1
  """
2
+ SecureCodeEnv - Attack Resistance Grader v4
3
+ All scores clamped to (0.001, 0.999).
 
 
 
 
 
4
  """
5
  from sandbox.executor import safe_exec
6
  from sandbox.payload_gen import get_payloads
7
+ from graders.clamp import clamp
8
 
9
 
10
  def _is_seq(v):
 
12
 
13
 
14
def _sql_safe(result, payload):
    """
    Return True when a generated SQL query looks safely parameterized.

    Safe means: result is a 2-element sequence (query, params), the query
    contains a recognized placeholder token, and the raw payload does not
    appear verbatim in the query (whitespace-squeezed, case-insensitive).
    """
    if not _is_seq(result) or len(result) < 2:
        return False
    query = str(result[0])
    # Require at least one recognized parameterization placeholder.
    placeholder_tokens = ("%s", "?", ":param", "%(", ":%")
    if not any(token in query for token in placeholder_tokens):
        return False
    # Squeeze out spaces and lowercase both sides so spacing/casing tricks
    # cannot hide a verbatim-injected payload.
    needle = payload.strip().replace(" ", "").lower()
    haystack = query.lower().replace(" ", "")
    return needle not in haystack
 
 
23
 
24
 
25
  def _path_safe(result, payload):
 
 
 
 
26
  if result is None:
27
+ return True
28
  if not isinstance(result, str):
29
  return True
30
  r = result.lower()
 
31
  if not result.startswith("/var/uploads"):
32
  return False
33
+ bad = ["/etc/", "/root/", "/proc/", "passwd", "shadow", "%2e", "%2f", ".."]
 
 
34
  return not any(b in r for b in bad)
35
 
36
 
 
40
  "fn_args_template": ["{payload}", "admin"],
41
  "attack_type": "injection",
42
  "check": _sql_safe,
43
+ "exception_is_pass": False,
44
  },
45
  "medium_file_path_handler": {
46
  "fn_name": "resolve_safe_path",
47
  "fn_args_template": ["/var/uploads", "{payload}"],
48
  "attack_type": "traversal",
49
  "check": _path_safe,
50
+ "exception_is_pass": True,
51
  },
52
  "hard_file_upload_handler": {
53
  "fn_name": "validate_upload",
 
81
  "attack_type": "xss",
82
  "check": lambda result, payload: (
83
  isinstance(result, str) and
84
+ not any(b in result.lower()
85
+ for b in ["<script", "onerror", "javascript:", "<img ", "<svg", "onload"])
 
 
 
 
 
 
 
 
 
 
86
  ),
87
  "exception_is_pass": False,
88
  },
 
91
 
92
def grade_attacks(code: str, task: dict, seed: int) -> dict:
    """
    Run seeded attack payloads against the submitted code and score the
    fraction that were blocked. All scores pass through clamp() so they
    stay strictly inside (0, 1).
    """
    config = ATTACK_CONFIGS.get(task.get("id", ""))
    if not config:
        # No attack suite defined for this task: report a neutral score.
        return {"score": clamp(0.5), "blocked": 0, "total": 0,
                "attack_type": "none",
                "feedback": "No attack grading defined — neutral score"}

    attack_type = config["attack_type"]
    payloads = get_payloads(attack_type, seed=seed, count=8)
    if not payloads:
        return {"score": clamp(0.5), "blocked": 0, "total": 0,
                "attack_type": attack_type,
                "feedback": "No payloads generated — neutral score"}

    exception_is_pass = config.get("exception_is_pass", True)
    blocked = 0
    for payload in payloads:
        # Substitute the payload into every string slot of the arg template.
        call_args = [arg.replace("{payload}", payload) if isinstance(arg, str) else arg
                     for arg in config["fn_args_template"]]
        outcome = safe_exec(code, call_args, function_name=config["fn_name"], timeout=3)

        if not outcome["ok"]:
            # An exception only counts as a block when the task expects a raise.
            if exception_is_pass:
                blocked += 1
        else:
            try:
                if config["check"](outcome.get("output"), payload):
                    blocked += 1
            except Exception:
                # A crashing checker means the output was not verifiably safe.
                pass

    raw = blocked / len(payloads)
    return {
        "score": clamp(raw),
        "blocked": blocked,
        "total": len(payloads),
        "attack_type": attack_type,
        "feedback": _feedback(raw, attack_type),
    }
132
 
133
 
134
  def _feedback(score: float, attack_type: str) -> str:
135
+ names = {"injection": "SQL injection", "traversal": "path traversal",
136
+ "auth_bypass": "auth bypass", "xss": "XSS"}
 
 
137
  name = names.get(attack_type, attack_type)
138
  if score >= 0.875: return f"Excellent β€” {name} attacks blocked ({score:.0%})"
139
  elif score >= 0.625: return f"Good β€” most {name} attacks blocked ({score:.0%})"
140
+ elif score >= 0.375: return f"Partial β€” {score:.0%} of {name} attacks blocked"
141
+ else: return f"Vulnerable β€” {score:.0%} of {name} attacks blocked"
graders/clamp.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Shared epsilon-clamping utility.
3
+ Validator requires scores strictly between 0 and 1: (0.001 … 0.999)
4
+ """
5
+ EPSILON = 0.001
6
+ SCORE_MIN = EPSILON # 0.001
7
+ SCORE_MAX = 1.0 - EPSILON # 0.999
8
+
9
+ def clamp(score: float) -> float:
10
+ """Clamp any score to (0.001, 0.999) β€” never exactly 0 or 1."""
11
+ try:
12
+ v = float(score)
13
+ except (TypeError, ValueError):
14
+ return 0.5 # safe default for bad inputs
15
+ if v != v: # NaN guard
16
+ return 0.5
17
+ return max(SCORE_MIN, min(SCORE_MAX, v))
graders/consistency.py CHANGED
@@ -1,46 +1,31 @@
1
- """
2
- SecureCodeEnv - Consistency Grader v3
3
- FIXED: Step 0 no longer gives free 1.0 β€” rewards ESTABLISHING good practices
4
- """
5
  from codegraph.graph import CodeGraph
6
  from codegraph.extractor import extract_metadata
 
7
 
8
-
9
- # Minimum quality bar for first submission (establishing conventions)
10
  GOOD_PRACTICES = {
11
- "uses_type_hints": ("Type hints present", 0.15),
12
- "uses_docstrings": ("Docstrings present", 0.15),
13
- "uses_try_catch": ("Error handling present", 0.10),
14
- "no_print_stmts": ("No debug print statements", 0.10),
15
- "no_hardcoded_secrets": ("No hardcoded secrets detected", 0.10),
16
  }
17
 
18
 
19
  def grade_consistency(code: str, filename: str, graph: CodeGraph, step: int) -> dict:
20
  new_meta = extract_metadata(code, filename, step)
21
- conv = new_meta.conventions
22
 
23
  if not graph.components:
24
- # Step 0: score on how well the agent ESTABLISHES good practices
25
- checks = {}
26
- for key, (label, _) in GOOD_PRACTICES.items():
27
- checks[key] = 1.0 if conv.get(key, False) else 0.0
28
-
29
- score = sum(checks.values()) / max(len(checks), 1)
30
- # Minimum 0.5 so this doesn't destroy reward on first step
31
- score = max(0.5, score)
32
-
33
- return {
34
- "score": round(score, 4),
35
- "checks": checks,
36
- "feedback": _first_step_feedback(score, checks),
37
- }
38
 
39
- # Step 1+: check consistency with established conventions
40
  established = graph.conventions
41
- checks = {}
42
 
43
- # Naming convention
44
  naming = established.get("naming")
45
  if naming and naming != "mixed" and new_meta.functions:
46
  fns = new_meta.functions
@@ -51,53 +36,40 @@ def grade_consistency(code: str, filename: str, graph: CodeGraph, step: int) ->
51
  and any(c.isupper() for c in f["name"]))
52
  checks["naming_convention"] = correct / len(fns)
53
 
54
- # Error handling
55
  if established.get("error_handling") == "try_catch":
56
  checks["error_handling"] = 1.0 if conv.get("uses_try_catch") else 0.3
57
-
58
- # Type hints
59
  if established.get("uses_type_hints"):
60
  checks["type_hints"] = 1.0 if conv.get("uses_type_hints") else 0.4
61
-
62
- # Docstrings
63
  if established.get("uses_docstrings"):
64
  checks["docstrings"] = 1.0 if conv.get("uses_docstrings") else 0.5
65
 
66
- # No print drift
67
  existing_no_print = all(c.conventions.get("no_print_stmts", True)
68
  for c in graph.components.values())
69
  if existing_no_print:
70
  checks["no_print_drift"] = 1.0 if conv.get("no_print_stmts", True) else 0.3
71
 
72
- # Component reuse
73
  reuse_opp = reuse_taken = 0
74
- for comp_name in graph.components:
75
- if comp_name.lower() in code.lower():
76
  reuse_opp += 1
77
- if comp_name in code:
78
  reuse_taken += 1
79
  if reuse_opp > 0:
80
  checks["component_reuse"] = reuse_taken / reuse_opp
81
 
82
- score = sum(checks.values()) / max(len(checks), 1) if checks else 0.8
83
- return {
84
- "score": round(score, 4),
85
- "checks": checks,
86
- "feedback": _consistency_feedback(score, checks),
87
- }
88
 
89
 
90
- def _first_step_feedback(score: float, checks: dict) -> str:
91
  missing = [k for k, v in checks.items() if v == 0.0]
92
  if not missing:
93
- return f"Good conventions established (score: {score:.2f})"
94
- return f"Missing good practices: {', '.join(missing)} β€” add type hints, docstrings, error handling"
95
 
96
 
97
- def _consistency_feedback(score: float, checks: dict) -> str:
98
- if score >= 0.9:
99
- return "Excellent consistency with existing codebase conventions"
100
  failing = [k for k, v in checks.items() if isinstance(v, float) and v < 0.5]
101
- if failing:
102
- return f"Convention drift in: {', '.join(failing)}"
103
- return f"Minor convention drift (score: {score:.2f})"
 
1
+ """SecureCodeEnv - Consistency Grader v4 β€” clamped scores"""
 
 
 
2
  from codegraph.graph import CodeGraph
3
  from codegraph.extractor import extract_metadata
4
+ from graders.clamp import clamp
5
 
 
 
6
  GOOD_PRACTICES = {
7
+ "uses_type_hints": 0.15,
8
+ "uses_docstrings": 0.15,
9
+ "uses_try_catch": 0.10,
10
+ "no_print_stmts": 0.10,
11
+ "no_hardcoded_secrets": 0.10,
12
  }
13
 
14
 
15
  def grade_consistency(code: str, filename: str, graph: CodeGraph, step: int) -> dict:
16
  new_meta = extract_metadata(code, filename, step)
17
+ conv = new_meta.conventions
18
 
19
  if not graph.components:
20
+ checks = {k: 1.0 if conv.get(k, False) else 0.0 for k in GOOD_PRACTICES}
21
+ raw = sum(checks.values()) / max(len(checks), 1)
22
+ raw = max(0.45, raw) # floor so first step never crushes reward
23
+ return {"score": clamp(raw), "checks": checks,
24
+ "feedback": _first_feedback(raw, checks)}
 
 
 
 
 
 
 
 
 
25
 
 
26
  established = graph.conventions
27
+ checks = {}
28
 
 
29
  naming = established.get("naming")
30
  if naming and naming != "mixed" and new_meta.functions:
31
  fns = new_meta.functions
 
36
  and any(c.isupper() for c in f["name"]))
37
  checks["naming_convention"] = correct / len(fns)
38
 
 
39
  if established.get("error_handling") == "try_catch":
40
  checks["error_handling"] = 1.0 if conv.get("uses_try_catch") else 0.3
 
 
41
  if established.get("uses_type_hints"):
42
  checks["type_hints"] = 1.0 if conv.get("uses_type_hints") else 0.4
 
 
43
  if established.get("uses_docstrings"):
44
  checks["docstrings"] = 1.0 if conv.get("uses_docstrings") else 0.5
45
 
 
46
  existing_no_print = all(c.conventions.get("no_print_stmts", True)
47
  for c in graph.components.values())
48
  if existing_no_print:
49
  checks["no_print_drift"] = 1.0 if conv.get("no_print_stmts", True) else 0.3
50
 
 
51
  reuse_opp = reuse_taken = 0
52
+ for name in graph.components:
53
+ if name.lower() in code.lower():
54
  reuse_opp += 1
55
+ if name in code:
56
  reuse_taken += 1
57
  if reuse_opp > 0:
58
  checks["component_reuse"] = reuse_taken / reuse_opp
59
 
60
+ raw = sum(checks.values()) / max(len(checks), 1) if checks else 0.7
61
+ return {"score": clamp(raw), "checks": checks,
62
+ "feedback": _feedback(raw, checks)}
 
 
 
63
 
64
 
65
+ def _first_feedback(score, checks):
66
  missing = [k for k, v in checks.items() if v == 0.0]
67
  if not missing:
68
+ return f"Good conventions established ({score:.2f})"
69
+ return f"Missing practices: {', '.join(missing)}"
70
 
71
 
72
+ def _feedback(score, checks):
73
+ if score >= 0.85: return "Excellent consistency with codebase"
 
74
  failing = [k for k, v in checks.items() if isinstance(v, float) and v < 0.5]
75
+ return f"Convention drift in: {', '.join(failing)}" if failing else f"Minor drift ({score:.2f})"
 
 
graders/correctness.py CHANGED
@@ -1,56 +1,45 @@
1
  """
2
- SecureCodeEnv - Correctness Grader
3
- Runs each task's test cases against the agent's submitted code.
4
- Weight: 30% of total reward β€” the highest single weight.
5
  """
6
  from sandbox.executor import safe_exec
 
 
7
 
8
  def _is_seq(v):
9
  return isinstance(v, (list, tuple))
10
 
11
 
12
  def grade_correctness(code: str, task: dict) -> dict:
13
- """
14
- Runs the task's test cases against the agent's code.
15
-
16
- Returns:
17
- {
18
- "score": float 0.0-1.0,
19
- "passed": int,
20
- "total": int,
21
- "details": list of per-test results
22
- }
23
- """
24
  test_cases = task.get("test_cases", [])
25
  if not test_cases:
26
- return {"score": 1.0, "passed": 0, "total": 0, "details": [], "feedback": "No test cases defined"}
 
27
 
28
  passed = 0
29
  details = []
30
-
31
  for tc in test_cases:
32
  result = _run_test_case(code, tc)
33
  if result["passed"]:
34
  passed += 1
35
  details.append(result)
36
 
37
- score = passed / len(test_cases) if test_cases else 1.0
38
  return {
39
- "score": round(score, 4),
40
  "passed": passed,
41
  "total": len(test_cases),
42
  "details": details,
43
- "feedback": _correctness_feedback(score, passed, len(test_cases)),
44
  }
45
 
46
 
47
  def _run_test_case(code: str, tc: dict) -> dict:
48
- """Execute a single test case and evaluate the result."""
49
  fn_name = tc.get("fn", "solution")
50
- inputs = tc.get("input", [])
51
- description = tc.get("description", "")
52
 
53
- # Handle class-based tasks
54
  if "fn_class" in tc:
55
  return _run_class_test(code, tc)
56
 
@@ -58,102 +47,92 @@ def _run_test_case(code: str, tc: dict) -> dict:
58
 
59
  if not exec_result["ok"]:
60
  expected_exc = tc.get("expected_exception")
61
- error_str = exec_result.get("error", "")
62
- exc_type = exec_result.get("type", "") # executor returns type field
63
  if expected_exc:
64
- exc_raised = (
65
- exc_type == expected_exc or
66
- expected_exc.lower() in error_str.lower() or
67
- expected_exc.lower() in exc_type.lower()
68
- )
69
- if exc_raised:
70
- return {"passed": True, "description": description, "note": f"Expected {expected_exc} raised"}
71
- return {"passed": False, "description": description, "error": error_str[:200]}
72
 
73
  output = exec_result.get("output")
74
 
75
- # Not-None check
76
- if "expected_not_none" in tc:
77
- ok = output is not None
78
- return {"passed": ok, "description": description}
79
-
80
- # SQL injection safety check: payload must NOT appear in query, placeholder must exist
81
  if tc.get("sql_injection_check"):
82
  if not _is_seq(output) or len(output) < 2:
83
- return {"passed": False, "description": description, "error": "Not a 2-element tuple"}
84
- query = str(output[0])
85
  payload_val = inputs[0] if inputs else ""
86
- has_placeholder = any(p in query for p in ("%s", "?", ":param", "%(username"))
87
- payload_not_in_query = str(payload_val).strip() not in query
88
- ok = has_placeholder and payload_not_in_query
89
- return {"passed": ok, "description": description,
90
- "note": f"placeholder={has_placeholder} payload_safe={payload_not_in_query}"}
 
 
 
91
 
92
- # Standard equality check
93
  if "expected" in tc:
94
- expected = tc["expected"]
95
- ok = output == expected
96
- return {"passed": ok, "description": description, "got": output, "expected": expected}
97
 
98
- # Type check (JSON serialization converts tuple→list, so treat them as equivalent)
99
  if "expected_type" in tc:
100
- type_name = tc["expected_type"]
101
- actual_type = type(output).__name__
102
- # tuple and list are equivalent after JSON round-trip
103
- equivalent = {("tuple", "list"), ("list", "tuple")}
104
- ok = actual_type == type_name or (actual_type, type_name) in equivalent or (type_name, actual_type) in equivalent
105
  if ok and "expected_len" in tc:
106
  ok = hasattr(output, "__len__") and len(output) == tc["expected_len"]
107
- return {"passed": ok, "description": description, "got_type": actual_type}
108
 
109
- # Contains check
110
  if "expected_contains" in tc:
111
- ok = tc["expected_contains"] in str(output)
112
- return {"passed": ok, "description": description}
113
 
114
- # Not-contains check
115
  if "expected_not_contains" in tc:
116
  forbidden = tc["expected_not_contains"]
117
  if isinstance(forbidden, list):
118
  ok = not any(f in str(output) for f in forbidden)
119
  else:
120
  ok = forbidden not in str(output)
121
- return {"passed": ok, "description": description, "got": str(output)[:100]}
122
 
123
- # Min length check
124
  if "expected_min_len" in tc:
125
- ok = output is not None and len(str(output)) >= tc["expected_min_len"]
126
- return {"passed": ok, "description": description}
127
 
128
- # Max length check
129
  if "expected_max_len" in tc:
130
- ok = output is not None and len(str(output)) <= tc["expected_max_len"]
131
- return {"passed": ok, "description": description}
132
 
133
- # Ok-flag check (for validate_upload style returns)
134
  if "expected_ok" in tc:
135
- ok = isinstance(output, dict) and output.get("ok") == tc["expected_ok"]
136
- return {"passed": ok, "description": description}
137
 
138
- # No expected value defined β€” just check it didn't crash
139
- return {"passed": True, "description": description, "note": "No assertion defined"}
140
 
141
 
142
  def _run_class_test(code: str, tc: dict) -> dict:
143
- """Run a test against a class-based task (e.g. RateLimiter)."""
144
  class_name = tc.get("fn_class", "Solution")
145
- init_args = tc.get("init_args", [])
146
- method = tc.get("method", "is_allowed")
147
- inputs = tc.get("input", [])
148
- description = tc.get("description", "")
149
 
150
- harness_code = f"""
151
  {code}
152
 
153
  def run_task(args):
154
- init_args = args[0]
155
- method = args[1]
156
- inputs = args[2]
157
  obj = {class_name}(*init_args)
158
  if method == "is_allowed_multi":
159
  result = None
@@ -161,34 +140,23 @@ def run_task(args):
161
  result = obj.is_allowed(inputs[0])
162
  return result
163
  if method == "independent_clients":
164
- r1 = obj.is_allowed("client_a")
165
- r2 = obj.is_allowed("client_b")
166
- return r1 == r2 == True
167
- fn = getattr(obj, method)
168
- return fn(*inputs)
169
  """
170
- test_input = [[init_args, method, inputs]] # wrap in list so safe_exec unpacks correctly
171
- result = safe_exec(harness_code, test_input, function_name="run_task", timeout=5)
172
-
173
  if not result["ok"]:
174
- return {"passed": False, "description": description, "error": result.get("error", "")[:200]}
175
-
176
  output = result.get("output")
177
  if "expected" in tc:
178
- ok = output == tc["expected"]
179
- return {"passed": ok, "description": description}
180
  if "expected_last" in tc:
181
- ok = output == tc["expected_last"]
182
- return {"passed": ok, "description": description}
183
- return {"passed": True, "description": description}
184
-
185
-
186
- def _correctness_feedback(score: float, passed: int, total: int) -> str:
187
- if score >= 0.9:
188
- return f"Excellent β€” {passed}/{total} tests passed"
189
- elif score >= 0.7:
190
- return f"Good β€” {passed}/{total} tests passed. Minor edge cases missing"
191
- elif score >= 0.5:
192
- return f"Partial β€” {passed}/{total} tests passed. Fix failing cases"
193
- else:
194
- return f"Poor β€” {passed}/{total} tests passed. Core logic incorrect"
 
1
  """
2
+ SecureCodeEnv - Correctness Grader v4
3
+ Weight: 25% of total reward.
4
+ All scores clamped to (0.001, 0.999).
5
  """
6
  from sandbox.executor import safe_exec
7
+ from graders.clamp import clamp
8
+
9
 
10
  def _is_seq(v):
11
  return isinstance(v, (list, tuple))
12
 
13
 
14
  def grade_correctness(code: str, task: dict) -> dict:
 
 
 
 
 
 
 
 
 
 
 
15
  test_cases = task.get("test_cases", [])
16
  if not test_cases:
17
+ return {"score": clamp(0.5), "passed": 0, "total": 0,
18
+ "details": [], "feedback": "No test cases defined"}
19
 
20
  passed = 0
21
  details = []
 
22
  for tc in test_cases:
23
  result = _run_test_case(code, tc)
24
  if result["passed"]:
25
  passed += 1
26
  details.append(result)
27
 
28
+ raw = passed / len(test_cases)
29
  return {
30
+ "score": clamp(raw),
31
  "passed": passed,
32
  "total": len(test_cases),
33
  "details": details,
34
+ "feedback": _feedback(raw, passed, len(test_cases)),
35
  }
36
 
37
 
38
  def _run_test_case(code: str, tc: dict) -> dict:
 
39
  fn_name = tc.get("fn", "solution")
40
+ inputs = tc.get("input", [])
41
+ desc = tc.get("description", "")
42
 
 
43
  if "fn_class" in tc:
44
  return _run_class_test(code, tc)
45
 
 
47
 
48
  if not exec_result["ok"]:
49
  expected_exc = tc.get("expected_exception")
50
+ error_str = exec_result.get("error", "")
51
+ exc_type = exec_result.get("type", "")
52
  if expected_exc:
53
+ if (exc_type == expected_exc or
54
+ expected_exc.lower() in error_str.lower() or
55
+ expected_exc.lower() in exc_type.lower()):
56
+ return {"passed": True, "description": desc,
57
+ "note": f"Expected {expected_exc} raised"}
58
+ return {"passed": False, "description": desc,
59
+ "error": error_str[:200]}
 
60
 
61
  output = exec_result.get("output")
62
 
63
+ # SQL injection parameterization check
 
 
 
 
 
64
  if tc.get("sql_injection_check"):
65
  if not _is_seq(output) or len(output) < 2:
66
+ return {"passed": False, "description": desc, "error": "Not a 2-element sequence"}
67
+ query = str(output[0])
68
  payload_val = inputs[0] if inputs else ""
69
+ has_ph = any(p in query for p in ("%s", "?", ":param", "%(username"))
70
+ safe = str(payload_val).strip() not in query
71
+ return {"passed": has_ph and safe, "description": desc,
72
+ "note": f"placeholder={has_ph} payload_safe={safe}"}
73
+
74
+ # Not-None
75
+ if "expected_not_none" in tc:
76
+ return {"passed": output is not None, "description": desc}
77
 
78
+ # Equality
79
  if "expected" in tc:
80
+ return {"passed": output == tc["expected"], "description": desc,
81
+ "got": output, "expected": tc["expected"]}
 
82
 
83
+ # Type check (JSON converts tuple→list)
84
  if "expected_type" in tc:
85
+ tname = tc["expected_type"]
86
+ atype = type(output).__name__
87
+ equiv = {("tuple","list"),("list","tuple")}
88
+ ok = atype == tname or (atype, tname) in equiv or (tname, atype) in equiv
 
89
  if ok and "expected_len" in tc:
90
  ok = hasattr(output, "__len__") and len(output) == tc["expected_len"]
91
+ return {"passed": ok, "description": desc, "got_type": atype}
92
 
93
+ # Contains
94
  if "expected_contains" in tc:
95
+ return {"passed": tc["expected_contains"] in str(output), "description": desc}
 
96
 
97
+ # Not-contains
98
  if "expected_not_contains" in tc:
99
  forbidden = tc["expected_not_contains"]
100
  if isinstance(forbidden, list):
101
  ok = not any(f in str(output) for f in forbidden)
102
  else:
103
  ok = forbidden not in str(output)
104
+ return {"passed": ok, "description": desc, "got": str(output)[:100]}
105
 
106
+ # Min length
107
  if "expected_min_len" in tc:
108
+ return {"passed": output is not None and len(str(output)) >= tc["expected_min_len"],
109
+ "description": desc}
110
 
111
+ # Max length
112
  if "expected_max_len" in tc:
113
+ return {"passed": output is not None and len(str(output)) <= tc["expected_max_len"],
114
+ "description": desc}
115
 
116
+ # Ok-flag (dict with "ok" key)
117
  if "expected_ok" in tc:
118
+ return {"passed": isinstance(output, dict) and output.get("ok") == tc["expected_ok"],
119
+ "description": desc}
120
 
121
+ return {"passed": True, "description": desc, "note": "No assertion"}
 
122
 
123
 
124
  def _run_class_test(code: str, tc: dict) -> dict:
 
125
  class_name = tc.get("fn_class", "Solution")
126
+ init_args = tc.get("init_args", [])
127
+ method = tc.get("method", "is_allowed")
128
+ inputs = tc.get("input", [])
129
+ desc = tc.get("description", "")
130
 
131
+ harness = f"""
132
  {code}
133
 
134
  def run_task(args):
135
+ init_args = args[0]; method = args[1]; inputs = args[2]
 
 
136
  obj = {class_name}(*init_args)
137
  if method == "is_allowed_multi":
138
  result = None
 
140
  result = obj.is_allowed(inputs[0])
141
  return result
142
  if method == "independent_clients":
143
+ return obj.is_allowed("client_a") == obj.is_allowed("client_b") == True
144
+ return getattr(obj, method)(*inputs)
 
 
 
145
  """
146
+ result = safe_exec(harness, [[init_args, method, inputs]],
147
+ function_name="run_task", timeout=5)
 
148
  if not result["ok"]:
149
+ return {"passed": False, "description": desc, "error": result.get("error","")[:200]}
 
150
  output = result.get("output")
151
  if "expected" in tc:
152
+ return {"passed": output == tc["expected"], "description": desc}
 
153
  if "expected_last" in tc:
154
+ return {"passed": output == tc["expected_last"], "description": desc}
155
+ return {"passed": True, "description": desc}
156
+
157
+
158
+ def _feedback(score: float, passed: int, total: int) -> str:
159
+ if score >= 0.9: return f"Excellent β€” {passed}/{total} tests passed"
160
+ elif score >= 0.7: return f"Good β€” {passed}/{total} tests passed"
161
+ elif score >= 0.5: return f"Partial β€” {passed}/{total} tests passed"
162
+ else: return f"Poor β€” {passed}/{total} tests passed"
 
 
 
 
 
graders/documentation.py CHANGED
@@ -1,142 +1,76 @@
1
- """
2
- SecureCodeEnv - Documentation & Code Structure Graders
3
- Documentation weight: 5% | Code Structure weight: 5%
4
- """
5
  import ast
 
6
 
7
 
8
  def grade_documentation(code: str) -> dict:
9
- """
10
- Grade docstring and type hint coverage.
11
- Rewards: functions with docstrings, full type annotations, module docstring.
12
-
13
- Returns:
14
- {"score": float, "documented_fns": int, "total_fns": int, "feedback": str}
15
- """
16
  try:
17
  tree = ast.parse(code)
18
  except SyntaxError:
19
- return {"score": 0.0, "documented_fns": 0, "total_fns": 0, "feedback": "Syntax error β€” cannot parse"}
20
-
21
- functions = [
22
- n for n in ast.walk(tree)
23
- if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))
24
- ]
25
 
 
 
26
  if not functions:
27
- # No functions β€” check for module docstring
28
- has_module_doc = bool(ast.get_docstring(tree))
29
- return {
30
- "score": 1.0 if has_module_doc else 0.7,
31
- "documented_fns": 0,
32
- "total_fns": 0,
33
- "feedback": "No functions found β€” module-level code only",
34
- }
35
 
36
- documented = 0
37
- typed = 0
38
  scores = []
39
-
40
  for fn in functions:
41
- fn_score = 0.0
42
- has_doc = bool(ast.get_docstring(fn))
43
- has_return_type = fn.returns is not None
44
- has_param_types = any(a.annotation is not None for a in fn.args.args)
45
- has_any_types = has_return_type or has_param_types
46
-
47
- if has_doc:
48
- documented += 1
49
- fn_score += 0.5
50
-
51
- if has_any_types:
52
- typed += 1
53
- fn_score += 0.5
54
-
55
- scores.append(fn_score)
56
 
57
  total = len(functions)
58
- score = sum(scores) / total if total > 0 else 1.0
59
-
60
- return {
61
- "score": round(score, 4),
62
- "documented_fns": documented,
63
- "typed_fns": typed,
64
- "total_fns": total,
65
- "feedback": _doc_feedback(score, documented, typed, total),
66
- }
67
 
68
 
69
  def grade_code_structure(code: str) -> dict:
70
- """
71
- Grade code structure quality:
72
- - No bare print() statements
73
- - Exception handling present where needed
74
- - No bare except clauses
75
- - No hardcoded magic strings
76
- - Functions not excessively long (>50 lines)
77
-
78
- Returns:
79
- {"score": float, "checks": dict, "feedback": str}
80
- """
81
  try:
82
  tree = ast.parse(code)
83
  except SyntaxError:
84
- return {"score": 0.0, "checks": {}, "feedback": "Syntax error"}
85
 
86
- checks: dict[str, bool] = {}
87
  lines = code.splitlines()
88
-
89
- # Check 1: No bare print statements (use logging)
90
- checks["no_bare_print"] = "print(" not in code
91
-
92
- # Check 2: No bare except (catches all exceptions silently)
93
- bare_except = False
94
- for node in ast.walk(tree):
95
- if isinstance(node, ast.ExceptHandler) and node.type is None:
96
- bare_except = True
97
- break
98
- checks["no_bare_except"] = not bare_except
99
-
100
- # Check 3: Functions are reasonably sized (<= 50 lines)
101
- oversized = False
102
- for node in ast.walk(tree):
103
- if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
104
- fn_lines = (node.end_lineno or 0) - node.lineno
105
- if fn_lines > 50:
106
- oversized = True
107
- break
108
- checks["reasonable_fn_size"] = not oversized
109
-
110
- # Check 4: No TODO/FIXME/HACK comments left in production code
111
- has_todo = any(
112
- "# TODO" in line.upper() or "# FIXME" in line.upper() or "# HACK" in line.upper()
113
- for line in lines
114
- )
115
- checks["no_todo_comments"] = not has_todo
116
-
117
- # Check 5: Handles None inputs (basic check)
118
- checks["handles_none"] = "None" in code or "is not None" in code or "if not " in code
119
-
120
- score = sum(1 for v in checks.values() if v) / max(len(checks), 1)
121
-
122
- return {
123
- "score": round(score, 4),
124
- "checks": checks,
125
- "feedback": _structure_feedback(score, checks),
126
- }
127
-
128
-
129
- def _doc_feedback(score: float, documented: int, typed: int, total: int) -> str:
130
- if score >= 0.9:
131
- return f"Well documented β€” {documented}/{total} functions have docstrings, {typed}/{total} typed"
132
- elif score >= 0.6:
133
- return f"Partial documentation β€” {documented}/{total} docstrings, {typed}/{total} type hints"
134
- else:
135
- return f"Poor documentation β€” add docstrings and type hints to all {total} functions"
136
-
137
-
138
- def _structure_feedback(score: float, checks: dict) -> str:
139
- if score >= 0.9:
140
- return "Clean code structure"
141
  failing = [k for k, v in checks.items() if not v]
142
  return f"Structure issues: {', '.join(failing)}"
 
1
+ """SecureCodeEnv - Documentation & Structure Graders v4 β€” clamped scores"""
 
 
 
2
  import ast
3
+ from graders.clamp import clamp
4
 
5
 
6
def grade_documentation(code: str) -> dict:
    """Grade docstring and type-hint coverage of *code*.

    Each function earns 0.5 for having a docstring and 0.5 for carrying at
    least one type annotation; the final score is the mean over all
    functions, clamped via graders.clamp.clamp.

    Returns:
        {"score": float, "documented_fns": int, "typed_fns": int,
         "total_fns": int, "feedback": str}
    """
    try:
        tree = ast.parse(code)
    except SyntaxError:
        return {"score": clamp(0.0), "documented_fns": 0, "total_fns": 0,
                "feedback": "Syntax error β€” cannot parse"}

    functions = [n for n in ast.walk(tree)
                 if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))]
    if not functions:
        # Module-level-only code: reward a module docstring slightly.
        has_mod_doc = bool(ast.get_docstring(tree))
        return {"score": clamp(0.65 if has_mod_doc else 0.5),
                "documented_fns": 0, "total_fns": 0,
                "feedback": "No functions β€” module-level code only"}

    scores = []
    documented = typed = 0
    for fn in functions:
        s = 0.0
        has_doc = bool(ast.get_docstring(fn))
        # FIX: inspect every parameter group, not just plain positional args.
        # Previously a function annotated only on keyword-only or
        # positional-only parameters was miscounted as untyped.
        all_params = fn.args.args + fn.args.posonlyargs + fn.args.kwonlyargs
        has_param_type = any(a.annotation is not None for a in all_params)
        has_return_type = fn.returns is not None
        if has_doc:
            documented += 1
            s += 0.5
        if has_return_type or has_param_type:
            typed += 1
            s += 0.5
        scores.append(s)

    total = len(functions)
    raw = sum(scores) / total
    return {"score": clamp(raw), "documented_fns": documented,
            "typed_fns": typed, "total_fns": total,
            "feedback": _doc_feedback(raw, documented, typed, total)}
 
 
 
 
 
37
 
38
 
39
def grade_code_structure(code: str) -> dict:
    """Run five heuristic structure checks and return the clamped pass ratio.

    Checks: no bare print, no bare except, functions <= 50 lines,
    no TODO/FIXME/HACK comments, and some evidence of None handling.
    """
    try:
        tree = ast.parse(code)
    except SyntaxError:
        return {"score": clamp(0.0), "checks": {}, "feedback": "Syntax error"}

    nodes = list(ast.walk(tree))
    src_lines = code.splitlines()

    def _is_bare_handler(node):
        # `except:` with no exception type silently swallows everything.
        return isinstance(node, ast.ExceptHandler) and node.type is None

    def _is_oversized_fn(node):
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            return False
        return (node.end_lineno or 0) - node.lineno > 50

    todo_markers = ("# TODO", "# FIXME", "# HACK")
    none_tokens = ("None", "is not None", "if not ", "Optional", "is None")

    checks = {
        "no_bare_print": "print(" not in code,
        "no_bare_except": not any(_is_bare_handler(n) for n in nodes),
        "reasonable_fn_size": not any(_is_oversized_fn(n) for n in nodes),
        "no_todo_comments": not any(
            any(m in ln.upper() for m in todo_markers) for ln in src_lines),
        "handles_none": any(tok in code for tok in none_tokens),
    }

    raw = sum(1 for v in checks.values() if v) / max(len(checks), 1)
    return {"score": clamp(raw), "checks": checks,
            "feedback": _struct_feedback(raw, checks)}
65
+
66
+
67
+ def _doc_feedback(score, documented, typed, total):
68
+ if score >= 0.85: return f"Well documented β€” {documented}/{total} docstrings, {typed}/{total} typed"
69
+ elif score >= 0.55: return f"Partial β€” {documented}/{total} docstrings, {typed}/{total} type hints"
70
+ return f"Poor β€” add docstrings and type hints to all {total} functions"
71
+
72
+
73
+ def _struct_feedback(score, checks):
74
+ if score >= 0.85: return "Clean code structure"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  failing = [k for k, v in checks.items() if not v]
76
  return f"Structure issues: {', '.join(failing)}"
graders/performance.py CHANGED
@@ -1,63 +1,93 @@
1
  """
2
- SecureCodeEnv - Performance Grader v3
3
- FIXED: 0ms measurement now returns 0.6 (neutral) not 1.0
 
 
 
 
 
 
4
  """
5
  import sys, tempfile, os, json, subprocess
 
 
 
6
 
7
 
8
  def grade_performance(code: str, task: dict) -> dict:
9
- test_cases = task.get("test_cases", [])
10
- naive_code = task.get("naive_code", "")
11
  optimal_code = task.get("optimal_code", "")
12
 
13
  if not test_cases or not naive_code or not optimal_code:
14
- return {"score": 0.6, "time_score": 0.6, "memory_score": 0.6,
15
- "feedback": "No baselines defined β€” neutral score applied"}
16
 
17
  tc = next((t for t in test_cases
18
  if "fn" in t and "input" in t
19
  and "fn_class" not in t
20
  and "expected_exception" not in t), None)
21
  if not tc:
22
- return {"score": 0.6, "time_score": 0.6, "memory_score": 0.6,
23
- "feedback": "No suitable test case β€” neutral score applied"}
24
 
25
  fn_name = tc["fn"]
26
- inputs = tc["input"]
27
 
28
  try:
29
- agent_ms = _measure_ms(code, fn_name, inputs)
30
- naive_ms = _measure_ms(naive_code, fn_name, inputs)
31
  optimal_ms = _measure_ms(optimal_code, fn_name, inputs)
32
 
33
- # FIXED: if measurements indistinguishable, return neutral 0.6
34
- if abs(naive_ms - optimal_ms) < 0.001:
35
- return {"score": 0.6, "time_score": 0.6, "memory_score": 0.6,
36
- "agent_ms": round(agent_ms, 3),
37
- "naive_ms": round(naive_ms, 3),
38
- "optimal_ms": round(optimal_ms, 3),
39
- "feedback": "Functions too fast to differentiate β€” neutral score"}
 
 
 
 
 
 
 
40
 
41
- time_range = max(naive_ms - optimal_ms, 0.01)
42
- raw = 1.0 - ((agent_ms - optimal_ms) / time_range)
43
- time_score = max(0.0, min(1.0, raw))
44
- combined = round((time_score * 0.7) + (time_score * 0.3), 4)
45
 
46
  return {
47
- "score": combined,
48
- "time_score": round(time_score, 4),
49
- "memory_score": round(time_score, 4),
50
- "agent_ms": round(agent_ms, 3),
51
- "naive_ms": round(naive_ms, 3),
52
  "optimal_ms": round(optimal_ms, 3),
53
- "feedback": _feedback(combined),
54
  }
55
  except Exception as e:
56
- return {"score": 0.6, "time_score": 0.6, "memory_score": 0.6,
 
57
  "feedback": f"Measurement error: {str(e)[:60]}"}
58
 
59
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  def _measure_ms(code: str, fn_name: str, inputs: list, runs: int = 50) -> float:
 
61
  script = f"""
62
  import timeit, json, sys
63
  {code}
@@ -79,7 +109,7 @@ sys.stdout.flush()
79
  line = line.strip()
80
  if line.startswith("{"):
81
  return json.loads(line)["ms"]
82
- return -1.0 # Signal unmeasurable
83
  except Exception:
84
  return -1.0
85
  finally:
@@ -89,7 +119,7 @@ sys.stdout.flush()
89
 
90
 
91
  def _feedback(score: float) -> str:
92
- if score >= 0.9: return "Excellent β€” near-optimal efficiency"
93
- elif score >= 0.7: return "Good β€” minor optimisation possible"
94
- elif score >= 0.5: return "Acceptable β€” room for improvement"
95
- else: return "Poor β€” significant performance gap vs optimal"
 
1
  """
2
+ SecureCodeEnv - Performance Grader v4
3
+
4
+ FIXES:
5
+ - Inverted baseline (naive faster than optimal) β†’ return neutral 0.5
6
+ - Unmeasurable (-1.0) β†’ return neutral 0.5
7
+ - Both timings identical β†’ return neutral 0.5
8
+ - Agent faster than optimal β†’ clamp to max 0.999 (not >1.0)
9
+ - All scores clamped to (0.001, 0.999)
10
  """
11
  import sys, tempfile, os, json, subprocess
12
+ from graders.clamp import clamp
13
+
14
+ NEUTRAL = 0.5 # returned when measurement is unreliable
15
 
16
 
17
def grade_performance(code: str, task: dict) -> dict:
    """Score agent runtime between the naive and optimal baselines.

    Returns a neutral (clamped 0.5) result whenever timing is unreliable:
    missing baselines, no usable test case, unmeasurable (-1) timings,
    indistinguishable baselines, or an inverted baseline.
    """
    test_cases = task.get("test_cases", [])
    naive_code = task.get("naive_code", "")
    optimal_code = task.get("optimal_code", "")

    if not (test_cases and naive_code and optimal_code):
        return {"score": clamp(NEUTRAL), "time_score": clamp(NEUTRAL),
                "memory_score": clamp(NEUTRAL), "feedback": "No baselines β€” neutral score"}

    def _usable(t):
        # Plain function cases only β€” class-based and exception cases can't be timed here.
        return ("fn" in t and "input" in t
                and "fn_class" not in t and "expected_exception" not in t)

    tc = next(filter(_usable, test_cases), None)
    if tc is None:
        return {"score": clamp(NEUTRAL), "time_score": clamp(NEUTRAL),
                "memory_score": clamp(NEUTRAL), "feedback": "No usable test case β€” neutral score"}

    fn_name, inputs = tc["fn"], tc["input"]

    try:
        # Measure in the same order as before: agent, naive, optimal.
        timings = {label: _measure_ms(src, fn_name, inputs)
                   for label, src in (("agent", code),
                                      ("naive", naive_code),
                                      ("optimal", optimal_code))}
        agent_ms = timings["agent"]
        naive_ms = timings["naive"]
        optimal_ms = timings["optimal"]

        # _measure_ms signals "unmeasurable" with -1.0.
        if min(timings.values()) < 0:
            return _neutral(agent_ms, naive_ms, optimal_ms, "Unmeasurable timing")

        if abs(naive_ms - optimal_ms) < 0.05:
            return _neutral(agent_ms, naive_ms, optimal_ms, "Timings indistinguishable")

        # Inverted baseline: naive beat optimal, so the scale is meaningless.
        if naive_ms < optimal_ms:
            return _neutral(agent_ms, naive_ms, optimal_ms,
                            "Baseline inverted (naive faster than optimal) β€” neutral")

        spread = naive_ms - optimal_ms
        # raw > 1.0 when the agent beats optimal; clamp caps it.
        time_score = clamp(1.0 - ((agent_ms - optimal_ms) / spread))

        return {
            "score": time_score,
            "time_score": time_score,
            "memory_score": time_score,
            "agent_ms": round(agent_ms, 3),
            "naive_ms": round(naive_ms, 3),
            "optimal_ms": round(optimal_ms, 3),
            "feedback": _feedback(time_score),
        }
    except Exception as e:
        return {"score": clamp(NEUTRAL), "time_score": clamp(NEUTRAL),
                "memory_score": clamp(NEUTRAL),
                "feedback": f"Measurement error: {str(e)[:60]}"}
75
 
76
 
77
def _neutral(agent_ms, naive_ms, optimal_ms, reason: str) -> dict:
    """Build a neutral-score result; negative (unmeasurable) timings become None."""
    def _fmt(ms):
        return round(ms, 3) if ms >= 0 else None

    result = {key: clamp(NEUTRAL) for key in ("score", "time_score", "memory_score")}
    result.update(agent_ms=_fmt(agent_ms), naive_ms=_fmt(naive_ms),
                  optimal_ms=_fmt(optimal_ms), feedback=reason)
    return result
87
+
88
+
89
  def _measure_ms(code: str, fn_name: str, inputs: list, runs: int = 50) -> float:
90
+ """Returns ms or -1.0 if unmeasurable."""
91
  script = f"""
92
  import timeit, json, sys
93
  {code}
 
109
  line = line.strip()
110
  if line.startswith("{"):
111
  return json.loads(line)["ms"]
112
+ return -1.0
113
  except Exception:
114
  return -1.0
115
  finally:
 
119
 
120
 
121
  def _feedback(score: float) -> str:
122
+ if score >= 0.85: return "Excellent β€” near-optimal efficiency"
123
+ elif score >= 0.65: return "Good β€” minor optimisation possible"
124
+ elif score >= 0.45: return "Acceptable β€” room for improvement"
125
+ else: return "Poor β€” significant performance gap"
graders/reward_aggregator.py CHANGED
@@ -1,58 +1,52 @@
1
  """
2
- SecureCodeEnv - Reward Aggregator v3
3
-
4
- KEY CHANGES:
5
- 1. SECURITY GATE: episode cannot be DONE unless attack_resist >= 0.75
6
- AND static_security >= 0.70. Prevents insecure code from "winning".
7
- 2. Weights rebalanced: static_security raised, performance lowered.
8
- 3. DONE threshold raised to 0.92.
9
- 4. Security floor penalty: if attack_resist < 0.5 OR static_security < 0.5,
10
- total reward is capped at 0.65 (cannot fool the system with correctness alone).
11
  """
12
- from graders.correctness import grade_correctness
13
- from graders.attacks import grade_attacks
14
  from graders.static_analysis import grade_static_analysis
15
- from graders.performance import grade_performance
16
- from graders.consistency import grade_consistency
17
- from graders.documentation import grade_documentation, grade_code_structure
18
- from codegraph.extractor import extract_metadata
19
- from codegraph.graph import CodeGraph
 
20
 
21
- # REBALANCED weights β€” security dimensions raised
22
  WEIGHTS = {
23
- "correctness": 0.25, # was 0.30 β€” still most important
24
- "attack_resist": 0.25, # was 0.20 β€” raised: must resist real attacks
25
- "static_security": 0.20, # was 0.15 β€” raised: must pass static analysis
26
- "consistency": 0.10, # was 0.15 β€” reduced
27
- "performance": 0.08, # was 0.10 β€” reduced (too noisy)
28
- "documentation": 0.07, # was 0.05
29
- "code_structure": 0.05, # unchanged
30
  }
31
  assert abs(sum(WEIGHTS.values()) - 1.0) < 1e-9
32
 
33
- # Security gate: these minimums must ALL be met for episode to be DONE
34
  SECURITY_GATE = {
35
- "attack_resist": 0.75, # Must block >= 75% of attacks
36
- "static_security": 0.70, # Must pass >= 70% of static checks
37
- "correctness": 0.80, # Must pass >= 80% of tests
38
  }
39
- DONE_THRESHOLD = 0.92 # Raised from 0.90
40
-
41
- # Security floor: if security is critically low, cap total reward
42
  SECURITY_FLOOR_DIMS = ["attack_resist", "static_security"]
43
- SECURITY_FLOOR_THRESHOLD = 0.50
44
- SECURITY_FLOOR_CAP = 0.65
45
 
46
 
47
  def grade_submission(code, filename, task, graph, step, seed):
48
- corr = grade_correctness(code, task)
49
- atk = grade_attacks(code, task, seed)
50
- stat = grade_static_analysis(code, task)
51
- perf = grade_performance(code, task)
52
- cons = grade_consistency(code, filename, graph, step)
53
- doc = grade_documentation(code)
54
  struct = grade_code_structure(code)
55
 
 
56
  scores = {
57
  "correctness": corr["score"],
58
  "attack_resist": atk["score"],
@@ -63,21 +57,19 @@ def grade_submission(code, filename, task, graph, step, seed):
63
  "code_structure": struct["score"],
64
  }
65
 
66
- raw_reward = sum(scores[k] * WEIGHTS[k] for k in WEIGHTS)
 
67
 
68
- # SECURITY FLOOR: insecure code is capped regardless of correctness
69
- security_critical_fail = any(
70
- scores[dim] < SECURITY_FLOOR_THRESHOLD
71
- for dim in SECURITY_FLOOR_DIMS
72
- )
73
- if security_critical_fail:
74
- raw_reward = min(raw_reward, SECURITY_FLOOR_CAP)
75
 
76
- total_reward = round(max(0.0, min(1.0, raw_reward)), 4)
 
77
 
78
- # SECURITY GATE for done determination
79
- gate_passed = all(scores[dim] >= threshold
80
- for dim, threshold in SECURITY_GATE.items())
81
  done_eligible = total_reward >= DONE_THRESHOLD and gate_passed
82
 
83
  feedback = {
@@ -94,45 +86,43 @@ def grade_submission(code, filename, task, graph, step, seed):
94
 
95
  details = {
96
  "correctness": {"passed": corr.get("passed"), "total": corr.get("total")},
97
- "attacks": {"blocked": atk.get("blocked"), "total": atk.get("total"),
98
- "type": atk.get("attack_type")},
99
- "static": {"bandit_score": stat.get("bandit_score"),
100
- "hard_fail": stat.get("hard_fail", False),
101
- "issues": stat.get("issues", [])[:3]},
102
  "security_gate_passed": gate_passed,
103
- "done_eligible": done_eligible,
104
  }
105
 
106
  return {
107
- "scores": scores,
108
- "total_reward": total_reward,
109
  "done_eligible": done_eligible,
110
- "feedback": feedback,
111
- "details": details,
112
- "agent_ms": perf.get("agent_ms"),
113
- "naive_ms": perf.get("naive_ms"),
114
- "optimal_ms": perf.get("optimal_ms"),
115
- "new_metadata": extract_metadata(code, filename, step),
116
  }
117
 
118
 
119
- def _gate_status(scores: dict) -> str:
120
- failing = [f"{dim} ({scores[dim]:.2f} < {thr})"
121
- for dim, thr in SECURITY_GATE.items()
122
- if scores[dim] < thr]
123
- return f"BLOCKED β€” security gate not met: {', '.join(failing)}"
124
 
125
 
126
  def _summary(reward, scores, gate_passed):
127
  if reward >= DONE_THRESHOLD and gate_passed:
128
- return f"βœ… Excellent ({reward:.3f}) β€” production-ready, security gate passed"
129
  if not gate_passed:
130
- gate_msg = _gate_status(scores)
131
- return f"πŸ”’ {gate_msg} (reward: {reward:.3f})"
132
  if reward >= 0.75:
133
  weakest = min(scores, key=scores.get)
134
  return f"🟑 Good ({reward:.3f}) β€” improve: {weakest} ({scores[weakest]:.2f})"
135
  if reward >= 0.55:
136
  weak = [k for k, v in scores.items() if v < 0.5]
137
  return f"🟠 Needs work ({reward:.3f}) β€” fix: {', '.join(weak[:3])}"
138
- return f"πŸ”΄ Poor ({reward:.3f}) β€” major security/correctness failures"
 
1
  """
2
+ SecureCodeEnv - Reward Aggregator v4
3
+
4
+ ALL scores are epsilon-clamped to (0.001 … 0.999).
5
+ Security gate: episode done only when attackβ‰₯0.75, staticβ‰₯0.70, correctnessβ‰₯0.80.
6
+ Security floor: if attack<0.5 OR static<0.5, total capped at 0.65.
7
+ DONE threshold: 0.92 (after clamping).
 
 
 
8
  """
9
+ from graders.correctness import grade_correctness
10
+ from graders.attacks import grade_attacks
11
  from graders.static_analysis import grade_static_analysis
12
+ from graders.performance import grade_performance
13
+ from graders.consistency import grade_consistency
14
+ from graders.documentation import grade_documentation, grade_code_structure
15
+ from graders.clamp import clamp
16
+ from codegraph.extractor import extract_metadata
17
+ from codegraph.graph import CodeGraph
18
 
 
19
  WEIGHTS = {
20
+ "correctness": 0.25,
21
+ "attack_resist": 0.25,
22
+ "static_security": 0.20,
23
+ "consistency": 0.10,
24
+ "performance": 0.08,
25
+ "documentation": 0.07,
26
+ "code_structure": 0.05,
27
  }
28
  assert abs(sum(WEIGHTS.values()) - 1.0) < 1e-9
29
 
 
30
  SECURITY_GATE = {
31
+ "attack_resist": 0.75,
32
+ "static_security": 0.70,
33
+ "correctness": 0.80,
34
  }
35
+ DONE_THRESHOLD = 0.92
 
 
36
  SECURITY_FLOOR_DIMS = ["attack_resist", "static_security"]
37
+ SECURITY_FLOOR_CAP = 0.65 # unclamped; after clamping β†’ ≀ 0.649
 
38
 
39
 
40
  def grade_submission(code, filename, task, graph, step, seed):
41
+ corr = grade_correctness(code, task)
42
+ atk = grade_attacks(code, task, seed)
43
+ stat = grade_static_analysis(code, task)
44
+ perf = grade_performance(code, task)
45
+ cons = grade_consistency(code, filename, graph, step)
46
+ doc = grade_documentation(code)
47
  struct = grade_code_structure(code)
48
 
49
+ # All individual dimension scores are already clamped by each grader.
50
  scores = {
51
  "correctness": corr["score"],
52
  "attack_resist": atk["score"],
 
57
  "code_structure": struct["score"],
58
  }
59
 
60
+ # Weighted sum
61
+ raw = sum(scores[k] * WEIGHTS[k] for k in WEIGHTS)
62
 
63
+ # Security floor
64
+ security_fail = any(scores[d] < 0.5 for d in SECURITY_FLOOR_DIMS)
65
+ if security_fail:
66
+ raw = min(raw, SECURITY_FLOOR_CAP)
 
 
 
67
 
68
+ # Final clamp β€” guarantees (0.001 … 0.999)
69
+ total_reward = clamp(raw)
70
 
71
+ # Security gate for done
72
+ gate_passed = all(scores[d] >= thr for d, thr in SECURITY_GATE.items())
 
73
  done_eligible = total_reward >= DONE_THRESHOLD and gate_passed
74
 
75
  feedback = {
 
86
 
87
  details = {
88
  "correctness": {"passed": corr.get("passed"), "total": corr.get("total")},
89
+ "attacks": {"blocked": atk.get("blocked"), "total": atk.get("total"),
90
+ "type": atk.get("attack_type")},
91
+ "static": {"bandit_score": stat.get("bandit_score"),
92
+ "hard_fail": stat.get("hard_fail", False),
93
+ "issues": stat.get("issues", [])[:3]},
94
  "security_gate_passed": gate_passed,
95
+ "done_eligible": done_eligible,
96
  }
97
 
98
  return {
99
+ "scores": scores,
100
+ "total_reward": total_reward,
101
  "done_eligible": done_eligible,
102
+ "feedback": feedback,
103
+ "details": details,
104
+ "agent_ms": perf.get("agent_ms"),
105
+ "naive_ms": perf.get("naive_ms"),
106
+ "optimal_ms": perf.get("optimal_ms"),
107
+ "new_metadata": extract_metadata(code, filename, step),
108
  }
109
 
110
 
111
def _gate_status(scores):
    """List every security-gate dimension whose score misses its threshold."""
    below = (f"{dim}({scores[dim]:.2f}<{thr})"
             for dim, thr in SECURITY_GATE.items() if scores[dim] < thr)
    return f"BLOCKED β€” {', '.join(below)}"
 
115
 
116
 
117
def _summary(reward, scores, gate_passed):
    """One-line episode verdict combining reward, gate status and weak dims."""
    label = f"{reward:.3f}"
    if gate_passed and reward >= DONE_THRESHOLD:
        return f"βœ… Excellent ({label}) β€” security gate passed"
    if not gate_passed:
        return f"πŸ”’ {_gate_status(scores)} (reward: {label})"
    if reward >= 0.75:
        # Gate passed but below DONE: point at the single weakest dimension.
        weakest = min(scores, key=scores.get)
        return f"🟑 Good ({label}) β€” improve: {weakest} ({scores[weakest]:.2f})"
    if reward >= 0.55:
        weak = [k for k, v in scores.items() if v < 0.5]
        return f"🟠 Needs work ({label}) β€” fix: {', '.join(weak[:3])}"
    return f"πŸ”΄ Poor ({label}) β€” major failures"
graders/static_analysis.py CHANGED
@@ -1,33 +1,28 @@
1
  """
2
- SecureCodeEnv - Static Analysis Grader v3
3
-
4
- FIXED:
5
- - HIGH severity issues now cap the score at 0.40 max (was just subtracting 0.30)
6
- - Task-specific security checks have hard caps when violated
7
- - bandit penalty curve is steeper
8
  """
9
- import subprocess, json, tempfile, os, ast, re
 
10
 
11
 
12
def grade_static_analysis(code: str, task: dict) -> dict:
    """Combine bandit findings with task-specific AST checks into one score.

    A violated hard security requirement caps the score at 0.40 no matter
    how clean the bandit report is.
    """
    bandit = _run_bandit(code)
    custom = _run_custom_checks(code, task)

    hard_fail = custom.get("hard_fail", False)
    if custom.get("hard_fail"):
        # Hard violation: heavy discount on bandit plus an absolute cap.
        final_score = min(bandit["score"] * 0.4, 0.40)
    else:
        final_score = bandit["score"] * 0.60 + custom["score"] * 0.40

    all_issues = bandit.get("issues", []) + custom.get("issues", [])
    bounded = max(0.0, min(1.0, final_score))
    return {
        "score": round(bounded, 4),
        "bandit_score": bandit["score"],
        "ast_score": custom["score"],
        "hard_fail": hard_fail,
        "issues": all_issues[:10],
        "feedback": _feedback(final_score, all_issues, hard_fail),
    }
32
 
33
 
@@ -37,182 +32,129 @@ def _run_bandit(code: str) -> dict:
37
  with tempfile.NamedTemporaryFile(mode="w", suffix=".py",
38
  delete=False, prefix="sce_ban_") as f:
39
  f.write(code); tmp = f.name
40
-
41
  res = subprocess.run(
42
  ["bandit", "-r", tmp, "-f", "json", "-q", "--exit-zero"],
43
- capture_output=True, text=True, timeout=15
44
- )
45
- data = json.loads(res.stdout or '{"results":[]}')
46
  issues = data.get("results", [])
47
-
48
- # Steeper penalty curve + cap at 3 HIGH issues = 0.0
49
- penalty = 0.0
50
- for i in issues:
51
- sev = i.get("issue_severity", "LOW")
52
- if sev == "HIGH": penalty += 0.40
53
- elif sev == "MEDIUM": penalty += 0.20
54
- else: penalty += 0.05
55
-
56
- score = max(0.0, 1.0 - penalty)
57
  return {
58
- "score": round(score, 4),
59
  "issues": [{"severity": i.get("issue_severity"),
60
  "text": i.get("issue_text", "")[:100],
61
  "line": i.get("line_number")} for i in issues[:5]],
62
  }
63
  except FileNotFoundError:
64
- return {"score": 1.0, "issues": [], "note": "bandit not installed"}
65
  except Exception as e:
66
- return {"score": 1.0, "issues": [], "note": str(e)[:40]}
67
  finally:
68
  if tmp and os.path.exists(tmp):
69
  try: os.unlink(tmp)
70
  except OSError: pass
71
 
72
 
73
- # Security requirements per task β€” violation = hard_fail (cap at 0.40)
74
  HARD_REQUIREMENTS = {
75
  "easy_password_validator": [
76
- {"type": "forbidden_any", "patterns": ["hashlib.md5", "hashlib.sha1", "hashlib.sha256",
77
- "md5(", "sha1(", "sha256("],
78
- "message": "Weak hash function used β€” must use bcrypt"},
79
- {"type": "forbidden_any", "patterns": ["random.random(", "random.randint("],
80
- "message": "Insecure randomness"},
81
  ],
82
  "easy_token_generator": [
83
- {"type": "forbidden_any", "patterns": ["random.random(", "random.randint(",
84
- "random.choice(", "random.seed("],
85
- "message": "Weak PRNG used for token generation β€” must use secrets module"},
86
  {"type": "required_any", "patterns": ["secrets"],
87
  "message": "secrets module not imported"},
88
  ],
89
  "medium_sql_query_builder": [
90
- {"type": "forbidden_pattern",
91
- "regex": r'f["\'].*SELECT.*{',
92
- "message": "f-string SQL query β€” SQL injection vulnerability"},
93
  {"type": "forbidden_pattern",
94
  "regex": r'["\']\s*\+\s*(username|role|user_input|query)',
95
- "message": "String concatenation in SQL β€” SQL injection vulnerability"},
96
- {"type": "forbidden_pattern",
97
- "regex": r'%\s*\(',
98
- "message": "% formatting in SQL β€” potential injection"},
99
  ],
100
  "medium_file_path_handler": [
101
- {"type": "forbidden_any",
102
- "patterns": ['"..\" in', "'..' in", '".." in', "\"..\" in"],
103
- "message": "String check for '..' is bypassable β€” use Path.resolve()"},
104
- {"type": "required_any",
105
- "patterns": ["resolve()", ".resolve", "realpath"],
106
- "message": "Must use os.path.realpath or Path.resolve() to prevent traversal"},
107
  ],
108
  "hard_jwt_validator": [
109
  {"type": "forbidden_any",
110
- "patterns": ["verify_signature\": False", "verify_signature':False",
111
- "verify_exp\": False", "algorithms=[\"none\"", "algorithms=['none'"],
112
- "message": "JWT verification disabled β€” security bypass"},
113
- {"type": "required_any",
114
- "patterns": ["algorithms="],
115
- "message": "algorithms= not specified β€” alg:none attack possible"},
116
  ],
117
  "hard_auth_middleware": [
118
- {"type": "required_any",
119
- "patterns": ["hmac.compare_digest"],
120
- "message": "hmac.compare_digest not used β€” timing attack possible"},
121
- {"type": "forbidden_pattern",
122
- "regex": r'==\s*(session_token|request_token|secret|token)',
123
- "message": "== used for secret comparison β€” timing attack"},
124
  ],
125
  "easy_input_sanitizer": [
126
- {"type": "forbidden_any",
127
- "patterns": ["eval(", "exec("],
128
- "message": "eval/exec used β€” remote code execution risk"},
129
  ],
130
  "hard_file_upload_handler": [
131
- {"type": "required_any",
132
- "patterns": ["uuid"],
133
- "message": "uuid not used β€” filename collisions possible"},
134
- {"type": "required_any",
135
- "patterns": [".suffix", "splitext", "os.path.splitext"],
136
- "message": "Extension not checked via proper method"},
137
  ],
138
  }
139
 
140
- # Soft checks (no hard cap) β€” applied to all tasks
141
  GENERIC_CHECKS = [
142
- {"type": "forbidden_any", "patterns": ["eval(", "exec("],
143
  "message": "eval/exec usage", "severity": "HIGH"},
144
  {"type": "forbidden_any", "patterns": ["shell=True"],
145
- "message": "shell=True enables command injection", "severity": "HIGH"},
146
- {"type": "forbidden_any", "patterns": ["pickle.loads", "pickle.load"],
147
- "message": "Unsafe pickle deserialization", "severity": "HIGH"},
148
- {"type": "forbidden_any", "patterns": ["yaml.load("],
149
- "message": "Unsafe yaml.load", "severity": "MEDIUM"},
150
- {"type": "forbidden_any", "patterns": ["hashlib.md5", "hashlib.sha1"],
151
- "message": "Weak hash function", "severity": "HIGH"},
152
  ]
153
 
154
 
155
  def _run_custom_checks(code: str, task: dict) -> dict:
156
- issues = []
157
- hard_fail = False
158
- checks_passed = 0
159
- total_checks = 0
160
 
161
- # Generic checks
162
  for chk in GENERIC_CHECKS:
163
- total_checks += 1
164
- found = _check_code(code, chk)
165
- if found:
166
  issues.append({"check": chk["message"], "severity": chk.get("severity","MEDIUM"),
167
  "message": chk["message"]})
168
  else:
169
- checks_passed += 1
170
-
171
- # Task-specific hard requirements
172
- task_id = task.get("id", "")
173
- for req in HARD_REQUIREMENTS.get(task_id, []):
174
- total_checks += 1
175
- violated = _check_requirement_violated(code, req)
176
- if violated:
177
  hard_fail = True
178
  issues.append({"check": req["message"], "severity": "CRITICAL",
179
  "message": req["message"]})
180
  else:
181
- checks_passed += 1
182
-
183
- score = checks_passed / max(total_checks, 1)
184
- return {"score": round(score, 4), "issues": issues, "hard_fail": hard_fail}
185
-
186
 
187
- def _check_code(code: str, chk: dict) -> bool:
188
- """Returns True if the violation is found."""
189
- t = chk.get("type", "")
190
- if t == "forbidden_any":
191
- return any(p in code for p in chk.get("patterns", []))
192
- if t == "required_any":
193
- return not any(p in code for p in chk.get("patterns", []))
194
- if t == "forbidden_pattern":
195
- return bool(re.search(chk.get("regex", "NOMATCH"), code, re.IGNORECASE))
196
- return False
197
 
198
 
199
- def _check_requirement_violated(code: str, req: dict) -> bool:
200
- """Returns True if requirement is violated (= bad)."""
201
  t = req.get("type", "")
202
  if t == "forbidden_any":
203
  return any(p in code for p in req.get("patterns", []))
204
  if t == "required_any":
205
  return not any(p in code for p in req.get("patterns", []))
206
  if t == "forbidden_pattern":
207
- return bool(re.search(req.get("regex", "NOMATCH"), code, re.IGNORECASE | re.DOTALL))
208
  return False
209
 
210
 
211
  def _feedback(score: float, issues: list, hard_fail: bool) -> str:
212
  if hard_fail:
213
- critical = [i["message"] for i in issues if i.get("severity") == "CRITICAL"]
214
- return f"CRITICAL security violation: {'; '.join(critical[:2])}"
215
  if score >= 0.9: return "Clean β€” no significant security issues"
216
  high = sum(1 for i in issues if i.get("severity") == "HIGH")
217
- if high > 0: return f"{high} HIGH severity issue(s) β€” must fix"
218
- return f"Some security issues found (score: {score:.2f})"
 
1
  """
2
+ SecureCodeEnv - Static Analysis Grader v4
3
+ All scores clamped to (0.001, 0.999).
 
 
 
 
4
  """
5
+ import subprocess, json, tempfile, os, re
6
+ from graders.clamp import clamp
7
 
8
 
9
def grade_static_analysis(code: str, task: dict) -> dict:
    """Grade *code* with bandit plus task-specific custom checks.

    A hard security violation caps the combined score at 0.40 regardless of
    the bandit result; otherwise bandit and custom scores blend 60/40.
    All reported scores are clamped into the open (0, 1) interval.
    """
    bandit_res = _run_bandit(code)
    custom_res = _run_custom_checks(code, task)
    hard_fail = custom_res.get("hard_fail", False)

    if hard_fail:
        # Violated hard requirement — cap even if bandit found nothing.
        combined = min(bandit_res["score"] * 0.4, 0.40)
    else:
        combined = bandit_res["score"] * 0.60 + custom_res["score"] * 0.40

    merged_issues = bandit_res.get("issues", []) + custom_res.get("issues", [])
    return {
        "score": clamp(combined),
        "bandit_score": clamp(bandit_res["score"]),
        "ast_score": clamp(custom_res["score"]),
        "hard_fail": hard_fail,
        "issues": merged_issues[:10],
        "feedback": _feedback(combined, merged_issues, hard_fail),
    }
27
 
28
 
 
32
  with tempfile.NamedTemporaryFile(mode="w", suffix=".py",
33
  delete=False, prefix="sce_ban_") as f:
34
  f.write(code); tmp = f.name
 
35
  res = subprocess.run(
36
  ["bandit", "-r", tmp, "-f", "json", "-q", "--exit-zero"],
37
+ capture_output=True, text=True, timeout=15)
38
+ data = json.loads(res.stdout or '{"results":[]}')
 
39
  issues = data.get("results", [])
40
+ penalty = sum(
41
+ 0.40 if i.get("issue_severity") == "HIGH"
42
+ else 0.20 if i.get("issue_severity") == "MEDIUM"
43
+ else 0.05
44
+ for i in issues)
 
 
 
 
 
45
  return {
46
+ "score": max(0.0, 1.0 - penalty),
47
  "issues": [{"severity": i.get("issue_severity"),
48
  "text": i.get("issue_text", "")[:100],
49
  "line": i.get("line_number")} for i in issues[:5]],
50
  }
51
  except FileNotFoundError:
52
+ return {"score": 0.75, "issues": [], "note": "bandit not installed"}
53
  except Exception as e:
54
+ return {"score": 0.75, "issues": [], "note": str(e)[:40]}
55
  finally:
56
  if tmp and os.path.exists(tmp):
57
  try: os.unlink(tmp)
58
  except OSError: pass
59
 
60
 
 
61
  HARD_REQUIREMENTS = {
62
  "easy_password_validator": [
63
+ {"type": "forbidden_any",
64
+ "patterns": ["hashlib.md5","hashlib.sha1","hashlib.sha256","md5(","sha1(","sha256("],
65
+ "message": "Weak hash β€” must use bcrypt"},
66
+ {"type": "required_any", "patterns": ["bcrypt"],
67
+ "message": "bcrypt not imported"},
68
  ],
69
  "easy_token_generator": [
70
+ {"type": "forbidden_any",
71
+ "patterns": ["random.random(","random.randint(","random.choice(","random.seed("],
72
+ "message": "Weak PRNG β€” must use secrets"},
73
  {"type": "required_any", "patterns": ["secrets"],
74
  "message": "secrets module not imported"},
75
  ],
76
  "medium_sql_query_builder": [
77
+ {"type": "forbidden_pattern", "regex": r'f["\'].*SELECT.*{',
78
+ "message": "f-string SQL β€” injection vulnerability"},
 
79
  {"type": "forbidden_pattern",
80
  "regex": r'["\']\s*\+\s*(username|role|user_input|query)',
81
+ "message": "String concat SQL β€” injection vulnerability"},
 
 
 
82
  ],
83
  "medium_file_path_handler": [
84
+ {"type": "required_any", "patterns": ["resolve()","resolve(","realpath"],
85
+ "message": "Must use Path.resolve() or realpath"},
 
 
 
 
86
  ],
87
  "hard_jwt_validator": [
88
  {"type": "forbidden_any",
89
+ "patterns": ["verify_signature\": False","verify_signature':False",
90
+ "verify_exp\": False","algorithms=[\"none\"","algorithms=['none'"],
91
+ "message": "JWT verification disabled"},
92
+ {"type": "required_any", "patterns": ["algorithms="],
93
+ "message": "algorithms= not specified"},
 
94
  ],
95
  "hard_auth_middleware": [
96
+ {"type": "required_any", "patterns": ["hmac.compare_digest"],
97
+ "message": "hmac.compare_digest not used"},
 
 
 
 
98
  ],
99
  "easy_input_sanitizer": [
100
+ {"type": "forbidden_any", "patterns": ["eval(","exec("],
101
+ "message": "eval/exec usage"},
 
102
  ],
103
  "hard_file_upload_handler": [
104
+ {"type": "required_any", "patterns": ["uuid"],
105
+ "message": "uuid not used"},
 
 
 
 
106
  ],
107
  }
108
 
 
109
  GENERIC_CHECKS = [
110
+ {"type": "forbidden_any", "patterns": ["eval(","exec("],
111
  "message": "eval/exec usage", "severity": "HIGH"},
112
  {"type": "forbidden_any", "patterns": ["shell=True"],
113
+ "message": "shell=True", "severity": "HIGH"},
114
+ {"type": "forbidden_any", "patterns": ["pickle.loads","pickle.load"],
115
+ "message": "Unsafe pickle", "severity": "HIGH"},
116
+ {"type": "forbidden_any", "patterns": ["hashlib.md5","hashlib.sha1"],
117
+ "message": "Weak hash", "severity": "HIGH"},
 
 
118
  ]
119
 
120
 
121
def _run_custom_checks(code: str, task: dict) -> dict:
    """Apply generic and task-specific pattern checks to *code*.

    Returns a pass-ratio score, the triggered issues, and a ``hard_fail``
    flag that is set when any task-specific requirement is violated.
    """
    issues = []
    hard_fail = False
    passed = 0
    total = 0

    for check in GENERIC_CHECKS:
        total += 1
        if _violated(code, check):
            issues.append({
                "check": check["message"],
                "severity": check.get("severity", "MEDIUM"),
                "message": check["message"],
            })
        else:
            passed += 1

    task_id = task.get("id", "")
    for requirement in HARD_REQUIREMENTS.get(task_id, []):
        total += 1
        if _violated(code, requirement):
            hard_fail = True  # any CRITICAL violation caps the final grade
            issues.append({
                "check": requirement["message"],
                "severity": "CRITICAL",
                "message": requirement["message"],
            })
        else:
            passed += 1

    return {"score": passed / max(total, 1), "issues": issues, "hard_fail": hard_fail}
 
 
 
 
 
 
 
 
 
142
 
143
 
144
+ def _violated(code: str, req: dict) -> bool:
 
145
  t = req.get("type", "")
146
  if t == "forbidden_any":
147
  return any(p in code for p in req.get("patterns", []))
148
  if t == "required_any":
149
  return not any(p in code for p in req.get("patterns", []))
150
  if t == "forbidden_pattern":
151
+ return bool(re.search(req.get("regex","NOMATCH"), code, re.IGNORECASE | re.DOTALL))
152
  return False
153
 
154
 
155
  def _feedback(score: float, issues: list, hard_fail: bool) -> str:
156
  if hard_fail:
157
+ return f"CRITICAL: {'; '.join(i['message'] for i in issues if i.get('severity')=='CRITICAL')[:120]}"
 
158
  if score >= 0.9: return "Clean β€” no significant security issues"
159
  high = sum(1 for i in issues if i.get("severity") == "HIGH")
160
+ return f"{high} HIGH severity issue(s)" if high else f"Some issues (score: {score:.2f})"
 
inference.py CHANGED
@@ -13,92 +13,140 @@ from typing import Dict, List, Any
13
 
14
  # ── Configuration ──────────────────────────────────────────────────────────
15
  API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
16
- MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")
17
- HF_TOKEN = os.environ.get("HF_TOKEN", "")
18
- ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860").rstrip("/")
19
 
20
  client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN or "sk-placeholder")
21
 
 
22
  def clamp_score(score: float) -> float:
23
- """
24
- Ensures score is strictly between 0 and 1 (e.g., 0.001 to 0.999).
25
- Required by validator range constraints.
26
- """
27
  epsilon = 0.001
28
- return max(epsilon, min(1.0 - epsilon, float(score)))
 
 
 
 
 
 
 
29
 
30
  def clean_code(raw: str) -> str:
31
  """Removes markdown code fences safely."""
32
- lines = [line for line in raw.splitlines() if not line.strip().startswith("```")]
 
33
  return "\n".join(lines).strip()
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  def run_episode(difficulty: str) -> None:
36
- """Runs episode and prints clamped [START], [STEP], and [END] blocks."""
37
  try:
38
- r = requests.post(f"{ENV_URL}/reset", json={"difficulty": difficulty}, timeout=30)
 
 
 
 
39
  r.raise_for_status()
40
  data = r.json()
41
  except Exception as e:
42
  print(f"Failed to reset {difficulty}: {e}", file=sys.stderr)
43
  return
44
 
45
- sid = data["session_id"]
46
- tid = data["task_id"]
47
-
48
- # [START] block
49
  print(f"[START] task={tid} difficulty={difficulty}", flush=True)
50
 
51
- final_score = 0.0
52
  total_steps = 0
53
 
54
  for i in range(1, 6):
55
  total_steps = i
56
- prompt = f"Task: {data['problem_statement']}\nContext: {json.dumps(data.get('codegraph', {}))}"
57
-
 
 
 
 
 
 
 
 
 
 
58
  try:
59
  resp = client.chat.completions.create(
60
  model=MODEL_NAME,
61
- messages=[{"role": "user", "content": prompt}],
62
- temperature=0.1
 
 
 
 
63
  )
64
  code = clean_code(resp.choices[0].message.content or "")
65
-
 
 
66
  step_r = requests.post(
67
  f"{ENV_URL}/step",
68
- json={"session_id": sid, "code": code, "filename": f"step_{i}.py", "task_id": tid},
69
- timeout=65
 
 
 
 
 
70
  )
71
  step_r.raise_for_status()
72
  res = step_r.json()
73
-
74
- raw_reward = res.get("total_reward", 0.0)
75
- clamped_reward = clamp_score(raw_reward)
76
- final_score = clamped_reward
77
-
78
- # [STEP] block with clamped reward
79
- print(f"[STEP] step={i} reward={clamped_reward:.3f}", flush=True)
80
 
81
  if res.get("done"):
82
  break
83
- data["codegraph"] = res.get("codegraph", {})
84
-
 
 
 
85
  except Exception as e:
86
  print(f"Error in step {i}: {e}", file=sys.stderr)
87
- break
 
 
 
88
 
89
- # [END] block with clamped final score
90
- print(f"[END] task={tid} score={final_score:.3f} steps={total_steps}", flush=True)
91
 
92
  def main():
 
93
  try:
94
- requests.get(f"{ENV_URL}/health", timeout=5).raise_for_status()
 
95
  except Exception as e:
96
  print(f"Health check failed: {e}", file=sys.stderr)
97
  sys.exit(1)
98
 
99
  for diff in ["easy", "medium", "hard"]:
100
  run_episode(diff)
101
- time.sleep(1)
 
102
 
103
  if __name__ == "__main__":
104
  main()
 
13
 
14
  # ── Configuration ──────────────────────────────────────────────────────────
15
  API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
16
+ MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")
17
+ HF_TOKEN = os.environ.get("HF_TOKEN", "")
18
+ ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860").rstrip("/")
19
 
20
  client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN or "sk-placeholder")
21
 
22
+
23
def clamp_score(score: float) -> float:
    """Clamp *score* into the open interval (0, 1), i.e. [0.001, 0.999].

    Non-numeric or NaN inputs collapse to a neutral 0.5 so a bad server
    payload can never produce an out-of-range reward.
    """
    eps = 0.001
    try:
        value = float(score)
    except (TypeError, ValueError):
        return 0.5
    if value != value:  # NaN never equals itself
        return 0.5
    if value < eps:
        return eps
    if value > 1.0 - eps:
        return 1.0 - eps
    return value
33
+
34
 
35
def clean_code(raw: str) -> str:
    """Strip markdown code-fence lines (```...) from *raw* and trim whitespace."""
    kept = []
    for line in raw.splitlines():
        if line.strip().startswith("```"):
            continue  # drop the fence line entirely, keep everything else
        kept.append(line)
    return "\n".join(kept).strip()
40
 
41
+
42
+ SYSTEM_PROMPT = """You are a senior Python security engineer.
43
+ Output ONLY raw Python code β€” no markdown, no explanations.
44
+ Your code must:
45
+ 1. Solve the problem correctly
46
+ 2. Resist SQL injection, path traversal, and auth bypass attacks
47
+ 3. Use parameterized queries β€” never f-string SQL
48
+ 4. Use secrets module (not random) for tokens
49
+ 5. Use bcrypt (not hashlib) for passwords
50
+ 6. Use hmac.compare_digest for secret comparison
51
+ 7. Have type hints and docstrings on every function"""
52
+
53
+
54
def run_episode(difficulty: str) -> None:
    """Run one graded episode against the environment server.

    Resets a session at the given difficulty, then performs up to five
    generate→submit steps, printing machine-parseable [START], [STEP], and
    [END] lines. All rewards pass through clamp_score so printed values
    stay strictly inside (0, 1). Returns nothing; failures are logged to
    stderr and the episode is abandoned (reset failure) or the step is
    skipped (per-step failure).
    """
    try:
        r = requests.post(
            f"{ENV_URL}/reset",
            json={"difficulty": difficulty},
            timeout=30,
        )
        r.raise_for_status()
        data = r.json()
    except Exception as e:
        # Without a session there is nothing to run — bail out early.
        print(f"Failed to reset {difficulty}: {e}", file=sys.stderr)
        return

    # NOTE(review): assumes the /reset payload always carries session_id and
    # task_id — a KeyError here is uncaught; verify against the server schema.
    sid = data["session_id"]
    tid = data["task_id"]

    print(f"[START] task={tid} difficulty={difficulty}", flush=True)

    final_score = clamp_score(0.0)  # starts at epsilon, not 0.0
    total_steps = 0

    for i in range(1, 6):  # at most 5 attempts per episode
        total_steps = i
        # Truncate context to keep the prompt within a bounded size.
        context_str = json.dumps(data.get("codegraph", {}))[:2000]
        prev_fb = data.get("last_feedback", "")

        user_msg = (
            f"Task: {data['problem_statement']}\n\n"
            f"Security targets: {data.get('cwe_targets', [])}\n\n"
            f"Codebase context:\n{context_str}"
        )
        if prev_fb:
            # Include grader feedback from the previous step so the model
            # can iterate on its earlier submission.
            user_msg += f"\n\nPrevious feedback:\n{prev_fb}"
        user_msg += "\n\nWrite the complete Python implementation now:"

        try:
            resp = client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": user_msg},
                ],
                max_tokens=1500,
                temperature=0.1,
            )
            code = clean_code(resp.choices[0].message.content or "")
            if not code.strip():
                # Never submit an empty body — the server expects valid code.
                code = "def placeholder(): pass"

            step_r = requests.post(
                f"{ENV_URL}/step",
                json={
                    "session_id": sid,
                    "code": code,
                    "filename": f"step_{i}.py",
                    "task_id": tid,
                },
                timeout=65,
            )
            step_r.raise_for_status()
            res = step_r.json()

            raw_reward = res.get("total_reward", 0.0)
            clamped = clamp_score(raw_reward)
            final_score = clamped

            print(f"[STEP] step={i} reward={clamped:.4f}", flush=True)

            if res.get("done"):
                break

            # Feed updated context back for next step
            data["codegraph"] = res.get("codegraph", {})
            data["last_feedback"] = res.get("feedback", {}).get("summary", "")

        except Exception as e:
            print(f"Error in step {i}: {e}", file=sys.stderr)
            # Don't break — try remaining steps
            time.sleep(1)

    print(f"[END] task={tid} score={final_score:.4f} steps={total_steps}", flush=True)
135
 
 
 
136
 
137
def main():
    """Verify the environment is reachable, then run one episode per difficulty."""
    try:
        health = requests.get(f"{ENV_URL}/health", timeout=10)
        health.raise_for_status()
        print(f"Environment healthy: {ENV_URL}", file=sys.stderr)
    except Exception as e:
        # No point running episodes against an unreachable server.
        print(f"Health check failed: {e}", file=sys.stderr)
        sys.exit(1)

    for difficulty in ["easy", "medium", "hard"]:
        run_episode(difficulty)
        time.sleep(2)  # brief pause between episodes
149
+
150
 
151
  if __name__ == "__main__":
152
  main()