"""Benchmark runtime for the Python code-review environment.""" from __future__ import annotations import random from dataclasses import dataclass, field from datetime import UTC, datetime from typing import Dict, List, Optional from uuid import uuid4 from openenv.core.env_server.interfaces import Environment try: from ..models import ( ActionType, CodeReviewSnippet, EpisodeMetrics, HealthResponse, IssueType, MetricsResponse, PythonAction, PythonEnvConfig, PythonObservation, PythonState, ReviewComment, RewardSummary, TaskListResponse, ) from .grading import GradeResult, grade_review from .task_bank import get_task_metadata, load_task_bank, load_task_catalog except ImportError: from models import ( # type: ignore ActionType, CodeReviewSnippet, EpisodeMetrics, HealthResponse, IssueType, MetricsResponse, PythonAction, PythonEnvConfig, PythonObservation, PythonState, ReviewComment, RewardSummary, TaskListResponse, ) from server.grading import GradeResult, grade_review # type: ignore from server.task_bank import get_task_metadata, load_task_bank, load_task_catalog # type: ignore def _utc_now() -> str: return datetime.now(UTC).isoformat() def _severity_reward(issue_severity: str, bonus_issue: bool) -> float: if bonus_issue: return 0.03 if issue_severity in {"CRITICAL", "HIGH"}: return 0.15 if issue_severity == "MEDIUM": return 0.10 return 0.05 def _false_positive_penalty(action_severity: Optional[str]) -> float: if action_severity == "CRITICAL": return -0.12 if action_severity == "HIGH": return -0.08 return -0.04 def _line_window_for_task(task_id: str) -> int: if task_id == "task_easy": return 3 if task_id == "task_medium": return 5 return 0 @dataclass class EpisodeRuntime: episode_id: str task_id: str snippet: CodeReviewSnippet current_step: int max_steps: int created_at: str review_history: List[ReviewComment] = field(default_factory=list) cumulative_reward: float = 0.0 done: bool = False last_feedback: str = "" found_issue_ids: set[str] = field(default_factory=set) duplicate_comments: int = 0 context_requests: int = 0 skipped_clean_lines: int = 0 skipped_issue_lines: int = 0 commented_lines: set[int] = field(default_factory=set) grade: GradeResult = field( default_factory=lambda: GradeResult( score=0.0, precision=0.0, recall=0.0, f1=0.0, true_positives=0, false_positives=0, missed_issues=0, required_found=0, required_total=0, bonus_found=0, matched_issue_ids=[], breakdown={}, ) ) reward_summary: RewardSummary = field(default_factory=RewardSummary) _ACTIVE_EPISODE: Optional[EpisodeRuntime] = None _TASK_CURSOR = -1 _SNIPPET_CURSORS: Dict[str, int] = {task.task_id: -1 for task in load_task_catalog()} def _set_active_episode(episode: Optional[EpisodeRuntime]) -> None: global _ACTIVE_EPISODE _ACTIVE_EPISODE = episode def _current_episode() -> Optional[EpisodeRuntime]: return _ACTIVE_EPISODE def _match_issue_for_action(task_id: str, snippet: CodeReviewSnippet, action: PythonAction, found_issue_ids: set[str]) -> Optional[str]: if action.action_type != ActionType.ADD_COMMENT or action.line_number is None or action.issue_type is None: return None max_distance = _line_window_for_task(task_id) best_issue_id: Optional[str] = None best_distance = max_distance + 1 for issue in snippet.gold_issues: if issue.issue_id in found_issue_ids or issue.issue_type != action.issue_type: continue distance = abs(action.line_number - issue.line) if distance <= max_distance and distance < best_distance: best_issue_id = issue.issue_id best_distance = distance return best_issue_id def build_metrics(episode: EpisodeRuntime) 


def build_metrics(episode: EpisodeRuntime) -> EpisodeMetrics:
    return EpisodeMetrics(
        precision=episode.grade.precision,
        recall=episode.grade.recall,
        f1=episode.grade.f1,
        true_positives=episode.grade.true_positives,
        false_positives=episode.grade.false_positives,
        missed_issues=episode.grade.missed_issues,
        required_found=episode.grade.required_found,
        required_total=episode.grade.required_total,
        bonus_found=episode.grade.bonus_found,
        duplicate_comments=episode.duplicate_comments,
        context_requests=episode.context_requests,
        skipped_clean_lines=episode.skipped_clean_lines,
        skipped_issue_lines=episode.skipped_issue_lines,
        current_score=episode.grade.score,
        cumulative_reward=episode.cumulative_reward,
        breakdown=episode.grade.breakdown,
    )


def build_state(episode: EpisodeRuntime) -> PythonState:
    return PythonState(
        episode_id=episode.episode_id,
        step_count=episode.current_step,
        task_id=episode.task_id,
        difficulty=get_task_metadata(episode.task_id).difficulty,
        snippet_id=episode.snippet.snippet_id,
        current_step=episode.current_step,
        max_steps=episode.max_steps,
        done=episode.done,
        filename=episode.snippet.filename,
        review_history=list(episode.review_history),
        metrics=build_metrics(episode),
        last_feedback=episode.last_feedback,
    )


def get_tasks_response() -> TaskListResponse:
    return TaskListResponse(tasks=load_task_catalog())


def get_metrics_response() -> MetricsResponse:
    episode = _current_episode()
    if episode is None:
        return MetricsResponse()
    return MetricsResponse(
        task_id=episode.task_id,
        snippet_id=episode.snippet.snippet_id,
        done=episode.done,
        metrics=build_metrics(episode),
    )


def get_health_response() -> HealthResponse:
    episode = _current_episode()
    return HealthResponse(
        status="ok",
        environment="python_code_review_env",
        task_count=sum(len(items) for items in load_task_bank().values()),
        active_task_id=episode.task_id if episode else None,
        active_snippet_id=episode.snippet.snippet_id if episode else None,
        active_episode_id=episode.episode_id if episode else None,
    )


def get_current_state() -> PythonState:
    episode = _current_episode()
    return PythonState() if episode is None else build_state(episode)


class PythonReviewRuntime(Environment[PythonAction, PythonObservation, PythonState]):
    """Deterministic code-review benchmark environment with dense rewards."""

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self, config: Optional[PythonEnvConfig] = None):
        super().__init__()
        self._config = config or PythonEnvConfig()
        self._episode: Optional[EpisodeRuntime] = None

    def _restore_episode(self) -> Optional[EpisodeRuntime]:
        if self._episode is not None:
            return self._episode
        self._episode = _current_episode()
        return self._episode

    def _select_task_id(self, seed: Optional[int]) -> str:
        task_order = list(self._config.task_order)
        if seed is not None:
            return random.Random(seed).choice(task_order)
        if not self._config.rotate_tasks:
            return task_order[0]
        global _TASK_CURSOR
        _TASK_CURSOR = (_TASK_CURSOR + 1) % len(task_order)
        return task_order[_TASK_CURSOR]

    def _select_snippet(self, task_id: str, seed: Optional[int]) -> CodeReviewSnippet:
        snippets = load_task_bank()[task_id]
        if seed is not None:
            return random.Random(seed).choice(snippets)
        _SNIPPET_CURSORS[task_id] = (_SNIPPET_CURSORS[task_id] + 1) % len(snippets)
        return snippets[_SNIPPET_CURSORS[task_id]]
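
    # Note: when reset() is called without a seed, task and snippet selection
    # advance the module-level cursors (_TASK_CURSOR, _SNIPPET_CURSORS), so the
    # rotation is deterministic per process but shared across every runtime
    # instance in that process. Passing a seed bypasses the cursors entirely.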

    def _terminal_reward(self, episode: EpisodeRuntime, action_type: ActionType) -> float:
        reward = 0.0
        if episode.grade.required_found == episode.grade.required_total and episode.grade.required_total:
            reward += 0.20
        if episode.grade.false_positives == 0:
            reward += 0.10
        if action_type == ActionType.REQUEST_CHANGES and episode.snippet.must_reject:
            reward += 0.10
        if action_type == ActionType.APPROVE and episode.snippet.must_approve:
            reward += 0.15
        if action_type == ActionType.APPROVE and episode.snippet.must_reject:
            reward -= 0.25
        reward += 0.05 * (1 - (episode.current_step / max(episode.max_steps, 1)))
        return reward

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        task_id: Optional[str] = None,
        **kwargs,
    ) -> PythonObservation:
        del kwargs
        selected_task_id = task_id or self._select_task_id(seed)
        snippet = self._select_snippet(selected_task_id, seed)
        metadata = get_task_metadata(selected_task_id)
        episode = EpisodeRuntime(
            episode_id=episode_id or str(uuid4()),
            task_id=selected_task_id,
            snippet=snippet,
            current_step=0,
            max_steps=min(metadata.max_steps, self._config.max_steps_per_task),
            created_at=_utc_now(),
        )
        episode.grade = grade_review(selected_task_id, snippet, episode.review_history, episode.duplicate_comments)
        episode.last_feedback = f"Loaded {metadata.name}. Review the code and submit comments line by line."
        self._episode = episode
        _set_active_episode(episode)
        return self._build_observation(episode, 0.0)

    def step(self, action: PythonAction, timeout_s: Optional[float] = None, **kwargs) -> PythonObservation:
        del timeout_s, kwargs
        episode = self._restore_episode()
        if episode is None:
            return self.reset()
        if episode.done:
            return self._build_observation(episode, 0.0)
        episode.current_step += 1
        step_reward = 0.0
        breakdown: Dict[str, float] = {}
        feedback = ""
        matched_issue_ids: List[str] = []
        if action.action_type == ActionType.ADD_COMMENT:
            if action.line_number in episode.commented_lines:
                episode.duplicate_comments += 1
                step_reward -= 0.08
                breakdown["duplicate_comment_penalty"] = -0.08
            issue_id = _match_issue_for_action(episode.task_id, episode.snippet, action, episode.found_issue_ids)
            if issue_id is not None:
                issue = next(item for item in episode.snippet.gold_issues if item.issue_id == issue_id)
                hit_reward = _severity_reward(issue.severity.value, not issue.required)
                step_reward += hit_reward
                breakdown["issue_hit"] = hit_reward
                episode.found_issue_ids.add(issue_id)
                matched_issue_ids = [issue_id]
                feedback = f"Recorded issue on line {action.line_number}."
            else:
                penalty = _false_positive_penalty(action.severity.value if action.severity else None)
                step_reward += penalty
                breakdown["false_positive_penalty"] = penalty
                feedback = "Comment did not match a benchmark issue."
            if action.line_number is not None:
                episode.commented_lines.add(action.line_number)
        elif action.action_type == ActionType.SKIP_LINE:
            assert action.line_number is not None
            required_issue_on_line = any(
                issue.required and issue.line == action.line_number
                for issue in episode.snippet.gold_issues
            )
            if required_issue_on_line:
                step_reward -= 0.10
                episode.skipped_issue_lines += 1
                breakdown["skip_issue_penalty"] = -0.10
                feedback = "Skipped a line with a required issue."
            else:
                step_reward += 0.02
                episode.skipped_clean_lines += 1
                breakdown["skip_clean_reward"] = 0.02
                feedback = "Marked the line as clean."
        elif action.action_type == ActionType.ASK_CONTEXT:
            episode.context_requests += 1
            step_reward -= 0.03
            breakdown["ask_context_penalty"] = -0.03
            feedback = episode.snippet.context or episode.snippet.diff or "No additional context available."
        elif action.action_type in {ActionType.APPROVE, ActionType.REQUEST_CHANGES}:
            feedback = "Final review decision recorded."
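        # Every action, including the final decision, is appended to the
        # transcript so grading sees the full history. The reward_delta stored
        # on the comment is the pre-terminal step reward; the terminal bonus is
        # computed after re-grading below and only folded into step_reward.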
        episode.review_history.append(
            ReviewComment(
                step_index=episode.current_step,
                action_type=action.action_type,
                line_number=action.line_number,
                issue_type=action.issue_type,
                severity=action.severity,
                comment=action.comment,
                suggestion=action.suggestion,
                question=action.question,
                matched_issue_ids=matched_issue_ids,
                reward_delta=step_reward,
            )
        )
        if len(episode.review_history) > self._config.max_history_entries:
            episode.review_history = episode.review_history[-self._config.max_history_entries :]
        done = action.action_type in {ActionType.APPROVE, ActionType.REQUEST_CHANGES}
        if episode.current_step >= episode.max_steps:
            done = True
            feedback = f"{feedback} Maximum steps reached.".strip()
        episode.grade = grade_review(episode.task_id, episode.snippet, episode.review_history, episode.duplicate_comments)
        if done:
            terminal_bonus = self._terminal_reward(episode, action.action_type)
            step_reward += terminal_bonus
            breakdown["terminal_bonus"] = terminal_bonus
            episode.done = True
            feedback = f"{feedback} Final score {episode.grade.score:.2f}.".strip()
        episode.cumulative_reward += step_reward
        episode.reward_summary = RewardSummary(
            step_reward=step_reward,
            cumulative_reward=episode.cumulative_reward,
            breakdown=breakdown,
            false_positives=episode.grade.false_positives,
            true_positives=episode.grade.true_positives,
            missed_issues=episode.grade.missed_issues,
        )
        episode.last_feedback = feedback or "Step complete."
        self._episode = episode
        _set_active_episode(episode)
        return self._build_observation(episode, step_reward)

    def _build_observation(self, episode: EpisodeRuntime, reward: float) -> PythonObservation:
        lines = episode.snippet.code.splitlines()
        return PythonObservation(
            snippet_id=episode.snippet.snippet_id,
            code=episode.snippet.code,
            filename=episode.snippet.filename,
            language="python",
            context=episode.snippet.context,
            diff=episode.snippet.diff,
            line_count=len(lines),
            current_step=episode.current_step,
            max_steps=episode.max_steps,
            task_id=episode.task_id,
            review_history=list(episode.review_history),
            lines=lines,
            reward_summary=episode.reward_summary,
            metrics=build_metrics(episode),
            feedback=episode.last_feedback,
            done=episode.done,
            reward=reward,
            metadata={
                "episode_id": episode.episode_id,
                "created_at": episode.created_at,
                "updated_at": _utc_now(),
                "task_name": get_task_metadata(episode.task_id).name,
            },
        )

    @property
    def state(self) -> PythonState:
        episode = self._restore_episode()
        return PythonState() if episode is None else build_state(episode)
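

# Usage sketch (illustrative only, not part of the server surface). It assumes
# PythonAction can be constructed with just action_type plus the optional
# fields consumed in step(), and that ActionType exposes the members referenced
# above; adapt the constructor calls if the actual models module differs.
if __name__ == "__main__":
    runtime = PythonReviewRuntime()
    observation = runtime.reset(seed=7)
    print(f"Reviewing {observation.filename} ({observation.line_count} lines)")
    # A real agent would add per-line comments; this just asks for context and
    # then submits a final decision, which terminates the episode.
    observation = runtime.step(PythonAction(action_type=ActionType.ASK_CONTEXT))
    observation = runtime.step(PythonAction(action_type=ActionType.APPROVE))
    print(f"done={observation.done} reward={observation.reward:.2f}")
    print(observation.feedback)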