uvpatel7271's picture
Upload folder using huggingface_hub
c29f1fd verified
"""Shared deterministic grading helpers."""
from __future__ import annotations
import ast
import difflib
import math
import multiprocessing as mp
import time
import traceback
from typing import Any, Callable, Dict, List
try:
from ..Models import TaskGrade
from ..tasks.catalog import CallCase, ReviewTask
except ImportError:
from Models import TaskGrade
from tasks.catalog import CallCase, ReviewTask
# Hard bounds keeping every emitted score strictly inside the open interval
# (0, 1) — the "OpenEnv-safe" band that strict_score() enforces.
STRICT_SCORE_MIN = 0.01
STRICT_SCORE_MAX = 0.99
# Shaping band used by shaped_score(): zero progress floors at 0.1 and full
# progress caps at 0.95, so discrete check results never hit the hard bounds.
POOR_SCORE = 0.1
NEAR_PERFECT_SCORE = 0.95
def finite_float(value: Any, fallback: float = STRICT_SCORE_MIN) -> float:
    """Convert a value into a finite float with a deterministic fallback.

    Non-convertible values, NaN, and infinities all map to *fallback*.
    """
    try:
        converted = float(value)
    except (TypeError, ValueError):
        return fallback
    # Reject NaN and +/-inf so downstream arithmetic stays well-defined.
    return converted if math.isfinite(converted) else fallback
def clamp(value: float, lower: float = 0.0, upper: float = 1.0) -> float:
    """Clamp a floating-point value to the closed interval [lower, upper]."""
    # Non-finite or non-numeric input collapses to the lower bound first.
    return max(lower, min(upper, finite_float(value, fallback=lower)))
def strict_score(value: Any, lower: float = STRICT_SCORE_MIN, upper: float = STRICT_SCORE_MAX) -> float:
    """Clamp a score to the OpenEnv-safe open interval (0, 1)."""
    bounded = max(lower, min(upper, finite_float(value, fallback=lower)))
    result = round(bounded, 3)
    # Sanity guard: with the default bounds the rounded value stays in (0, 1).
    assert 0 < result < 1, f"Invalid score: {result}"
    return result
def shaped_score(progress: Any, floor: float = POOR_SCORE, ceiling: float = NEAR_PERFECT_SCORE) -> float:
    """Map progress in [0, 1] to a shaped score band within (0, 1)."""
    fraction = clamp(finite_float(progress, fallback=0.0))
    # Linear interpolation between floor and ceiling, then hard-bounded.
    raw = floor + (ceiling - floor) * fraction
    shaped = round(max(STRICT_SCORE_MIN, min(raw, STRICT_SCORE_MAX)), 3)
    assert 0 < shaped < 1, f"Invalid score: {shaped}"
    return shaped
def score_from_checks(passed: int, total: int, floor: float = POOR_SCORE, ceiling: float = NEAR_PERFECT_SCORE) -> float:
    """Convert discrete checks into a smoothly shaped score."""
    ratio = safe_ratio(passed, total)
    return shaped_score(ratio, floor=floor, ceiling=ceiling)
def safe_ratio(numerator: Any, denominator: Any) -> float:
    """Return a stable ratio in [0, 1] that never raises or produces NaN."""
    total = int(finite_float(denominator, fallback=0.0))
    # A non-positive denominator (including garbage input) yields 0.0.
    if total <= 0:
        return 0.0
    return clamp(finite_float(numerator, fallback=0.0) / total)
def component_score(value: Any) -> float:
    """Normalize component scores such as syntax, quality, and runtime.

    Components share the same strict (0, 1) clamping as the overall score.
    """
    return strict_score(value)
def compile_code(code: str) -> tuple[bool, str]:
    """Return whether code compiles and the syntax error, if any."""
    try:
        compile(code, "<candidate>", "exec")
        return True, ""
    except SyntaxError as exc:
        # Include position info so the grader's notes point at the problem.
        return False, f"SyntaxError: {exc.msg} (line {exc.lineno}, column {exc.offset})"
    except Exception as exc:  # pragma: no cover
        return False, f"{type(exc).__name__}: {exc}"
def similarity_score(candidate: str, reference: str) -> float:
    """Compute a stable text similarity score in [0, 1]."""
    # Strip outer whitespace so formatting noise does not affect the ratio.
    matcher = difflib.SequenceMatcher(a=candidate.strip(), b=reference.strip())
    return matcher.ratio()
def _queue_worker(
worker: Callable[[Dict[str, Any]], Dict[str, Any]],
payload: Dict[str, Any],
queue: Any,
) -> None:
try:
queue.put({"ok": True, "data": worker(payload)})
except Exception as exc: # pragma: no cover
queue.put(
{
"ok": False,
"error": f"{type(exc).__name__}: {exc}",
"traceback": traceback.format_exc(limit=5),
}
)
def run_with_timeout(
    worker: Callable[[Dict[str, Any]], Dict[str, Any]],
    payload: Dict[str, Any],
    timeout_s: float,
) -> Dict[str, Any]:
    """Execute a worker in a subprocess and terminate on timeout.

    Returns one of:
      - ``{"timed_out": True, "error": ...}`` when the wall clock is exceeded,
      - ``{"timed_out": False, "error": ...}`` when the worker failed or died,
      - ``{"timed_out": False, "data": ...}`` with the worker's result.
    """
    from queue import Empty  # local import: avoids shadowing concerns at module level

    ctx = mp.get_context("spawn")
    result_queue = ctx.Queue()
    process = ctx.Process(target=_queue_worker, args=(worker, payload, result_queue))
    process.start()
    process.join(timeout_s)
    if process.is_alive():
        process.terminate()
        process.join()
        return {"timed_out": True, "error": f"Execution exceeded {timeout_s:.1f}s timeout."}
    try:
        # Bug fix: Queue.empty() right after join() is unreliable — the child's
        # feeder thread may not have flushed the result yet, so a valid result
        # could be reported as "no result". A short blocking get() is robust.
        message = result_queue.get(timeout=1.0)
    except Empty:
        return {"timed_out": False, "error": "Worker exited without returning a result."}
    if not message["ok"]:
        return {
            "timed_out": False,
            "error": f"{message['error']}\n{message['traceback']}",
        }
    return {"timed_out": False, "data": message["data"]}
def _execute_cases_worker(payload: Dict[str, Any]) -> Dict[str, Any]:
namespace: Dict[str, Any] = {}
exec(payload["code"], namespace)
func = namespace[payload["function_name"]]
results: List[Dict[str, Any]] = []
for case in payload["cases"]:
try:
actual = func(*case["args"], **case["kwargs"])
passed = actual == case["expected"]
actual_repr = repr(actual)
except Exception as exc:
passed = False
actual_repr = f"{type(exc).__name__}: {exc}"
results.append(
{
"label": case["label"],
"passed": passed,
"expected": repr(case["expected"]),
"actual": actual_repr,
}
)
passed_total = sum(1 for item in results if item["passed"])
return {"passed": passed_total, "total": len(results), "results": results}
def execute_cases(code: str, function_name: str, cases: List[CallCase], timeout_s: float) -> Dict[str, Any]:
    """Run function test cases in a subprocess with a wall-clock timeout."""
    # Serialize the CallCase objects into plain dicts so they pickle cleanly
    # across the process boundary.
    serialized = [
        {"label": case.label, "args": case.args, "kwargs": case.kwargs, "expected": case.expected}
        for case in cases
    ]
    payload = {"code": code, "function_name": function_name, "cases": serialized}
    return run_with_timeout(_execute_cases_worker, payload, timeout_s=timeout_s)
class _LoopDepthVisitor(ast.NodeVisitor):
def __init__(self) -> None:
self.depth = 0
self.max_depth = 0
def _visit_loop(self, node: ast.AST) -> None:
self.depth += 1
self.max_depth = max(self.max_depth, self.depth)
self.generic_visit(node)
self.depth -= 1
def visit_For(self, node: ast.For) -> None: # noqa: N802
self._visit_loop(node)
def visit_While(self, node: ast.While) -> None: # noqa: N802
self._visit_loop(node)
def visit_comprehension(self, node: ast.comprehension) -> None: # noqa: N802
self._visit_loop(node)
def quality_metrics(code: str, function_name: str) -> Dict[str, Any]:
    """Compute deterministic AST/style quality metrics.

    Returns a dict with:
      - ``score``: overall quality, clamped into the strict (0, 1) band,
      - ``style_score``: whitespace/line-length hygiene only,
      - ``quality_notes``: human-readable improvement hints,
      - ``max_loop_depth``: deepest loop nesting inside the target function
        (99 when the code does not compile, 0 when the function is missing).
    """
    compiled, error = compile_code(code)
    if not compiled:
        # Uncompilable code bottoms out every metric; depth 99 marks "unusable".
        return {
            "score": component_score(STRICT_SCORE_MIN),
            "style_score": component_score(STRICT_SCORE_MIN),
            "quality_notes": [error],
            "max_loop_depth": 99,
        }
    tree = ast.parse(code)
    # Only module-top-level functions are considered (tree.body, not ast.walk).
    function_node = next(
        (
            node
            for node in tree.body
            if isinstance(node, ast.FunctionDef) and node.name == function_name
        ),
        None,
    )
    notes: List[str] = []
    score = 0.0
    # +0.2: the expected function exists at module top level.
    if function_node is not None:
        score += 0.2
    else:
        notes.append(f"Expected function {function_name!r} is missing.")
    lines = [line.rstrip("\n") for line in code.splitlines()]
    long_lines = [index + 1 for index, line in enumerate(lines) if len(line) > 88]
    trailing_whitespace = [index + 1 for index, line in enumerate(lines) if line.rstrip() != line]
    uses_tabs = any("\t" in line for line in lines)
    style_score = 0.0
    # +0.15 (and half the style score): no lines over 88 characters.
    if not long_lines:
        score += 0.15
        style_score += 0.5
    else:
        notes.append(f"Lines longer than 88 characters: {long_lines[:3]}")
    # +0.15 (and the other half of style): no tabs or trailing whitespace.
    if not trailing_whitespace and not uses_tabs:
        score += 0.15
        style_score += 0.5
    else:
        notes.append("Remove tabs or trailing whitespace for cleaner style.")
    if function_node is not None:
        # +0.1: the function carries a docstring.
        if ast.get_docstring(function_node):
            score += 0.1
        else:
            notes.append("Add a short docstring to explain the function contract.")
        # Up to +0.15: shallow loop nesting (<=1 full credit, ==2 partial).
        visitor = _LoopDepthVisitor()
        visitor.visit(function_node)
        if visitor.max_depth <= 1:
            score += 0.15
        elif visitor.max_depth == 2:
            score += 0.08
            notes.append("Loop nesting is still higher than necessary.")
        else:
            notes.append("Refactor nested loops to improve readability and runtime.")
        # Up to +0.1: fraction of assigned names that are >= 3 characters long.
        names = [node.id for node in ast.walk(function_node) if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Store)]
        meaningful_names = [name for name in names if len(name) >= 3]
        if names:
            score += 0.1 * (len(meaningful_names) / len(names))
        # Up to +0.1: shorter functions score higher (<=25 full, <=40 partial).
        function_length = (function_node.end_lineno or function_node.lineno) - function_node.lineno + 1
        if function_length <= 25:
            score += 0.1
        elif function_length <= 40:
            score += 0.05
            notes.append("The function can be shortened or decomposed further.")
        else:
            notes.append("The function is long enough to justify refactoring.")
        max_loop_depth = visitor.max_depth
    else:
        max_loop_depth = 0
    # +0.15: textual hint that idiomatic builtins/collections appear anywhere
    # in the source (simple substring match, not AST-based).
    source_hints = ("Counter(", "defaultdict(", "set(", "dict(", "sorted(", "sum(", " any(", " all(", " for ")
    if any(hint in code for hint in source_hints):
        score += 0.15
    return {
        "score": component_score(clamp(score)),
        "style_score": component_score(clamp(style_score)),
        "quality_notes": notes,
        "max_loop_depth": max_loop_depth,
    }
def build_benchmark_events(config: Dict[str, int]) -> List[Dict[str, Any]]:
    """Generate deterministic benchmark data without randomness.

    Expects ``config`` to carry ``user_pool`` and ``events_per_user`` counts.
    """
    events: List[Dict[str, Any]] = []
    for user_index in range(config["user_pool"]):
        user_id = f"user-{user_index:03d}"
        for minute in range(config["events_per_user"]):
            # Every third (user, minute) combination is "inactive".
            status = "inactive" if (user_index + minute) % 3 == 0 else "active"
            record = {"user_id": user_id, "status": status, "minute": minute}
            events.append(record)
            # Every sixth minute emits a duplicate copy of the same event.
            if minute % 6 == 0:
                events.append(dict(record))
    return events
def _benchmark_worker(payload: Dict[str, Any]) -> Dict[str, Any]:
candidate_ns: Dict[str, Any] = {}
baseline_ns: Dict[str, Any] = {}
exec(payload["candidate_code"], candidate_ns)
exec(payload["baseline_code"], baseline_ns)
candidate = candidate_ns[payload["function_name"]]
baseline = baseline_ns[payload["function_name"]]
benchmark_events = payload["events"]
iterations = payload["iterations"]
baseline_output = baseline(benchmark_events)
candidate_output = candidate(benchmark_events)
if candidate_output != baseline_output:
raise AssertionError("Candidate output diverges from baseline on benchmark data.")
def _timed(fn: Callable[[Any], Any]) -> float:
start = time.perf_counter()
for _ in range(iterations):
fn(benchmark_events)
return time.perf_counter() - start
baseline_seconds = _timed(baseline)
candidate_seconds = _timed(candidate)
return {"baseline_seconds": baseline_seconds, "candidate_seconds": candidate_seconds}
def benchmark_candidate(task: ReviewTask, code: str, timeout_s: float) -> Dict[str, Any]:
    """Benchmark a candidate solution against the starter implementation.

    Returns a dict with ``runtime_score`` plus either failure ``details`` or
    a timing breakdown (baseline/candidate seconds and improvement ratio).
    """
    floor_score = component_score(STRICT_SCORE_MIN)
    config = task.benchmark_config
    if not config:
        return {"runtime_score": floor_score, "details": "No benchmark configured."}
    payload = {
        "candidate_code": code,
        "baseline_code": task.starter_code,
        "function_name": task.function_name,
        "events": build_benchmark_events(config),
        "iterations": config.get("iterations", 5),
    }
    outcome = run_with_timeout(_benchmark_worker, payload, timeout_s=timeout_s)
    if outcome.get("timed_out"):
        return {"runtime_score": floor_score, "timed_out": True, "details": outcome["error"]}
    if "error" in outcome:
        return {"runtime_score": floor_score, "timed_out": False, "details": outcome["error"]}
    timing = outcome["data"]
    baseline_seconds = float(timing["baseline_seconds"])
    candidate_seconds = float(timing["candidate_seconds"])
    # Guard the division against a zero-length candidate run.
    improvement_ratio = baseline_seconds / max(candidate_seconds, 1e-9)
    # (ratio - 1) / 1.5 scales the speedup: full credit needs ~2.5x or better.
    runtime_score = component_score(clamp((improvement_ratio - 1.0) / 1.5))
    return {
        "runtime_score": runtime_score,
        "timed_out": False,
        "details": {
            "baseline_seconds": round(baseline_seconds, 6),
            "candidate_seconds": round(candidate_seconds, 6),
            "improvement_ratio": round(improvement_ratio, 3),
        },
    }
def summarize_results(prefix: str, results: List[Dict[str, Any]]) -> str:
    """Render concise test output, one PASS/FAIL line per result row."""
    if not results:
        return f"{prefix}: no tests were executed."
    rendered = [prefix]
    rendered.extend(
        f"- {'PASS' if item['passed'] else 'FAIL'} {item['label']}: "
        f"expected {item['expected']}, got {item['actual']}"
        for item in results
    )
    return "\n".join(rendered)
def base_grade(
    *,
    score: float,
    syntax_score: float,
    tests_passed: int,
    tests_total: int,
    quality_score: float,
    runtime_score: float,
    timed_out: bool,
    details: Dict[str, Any],
) -> TaskGrade:
    """Create a normalized TaskGrade payload.

    All float components are clamped into the strict (0, 1) band before being
    handed to TaskGrade; integer counts and flags pass through untouched.
    """
    return TaskGrade(
        score=strict_score(score),
        syntax_score=component_score(syntax_score),
        tests_passed=tests_passed,
        tests_total=tests_total,
        quality_score=component_score(quality_score),
        runtime_score=component_score(runtime_score),
        timed_out=timed_out,
        details=details,
    )