"""Benchmark runtime for the Python code-review environment."""
from __future__ import annotations
import random
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import Dict, List, Optional
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
try:
    from ..models import (
        ActionType,
        CodeReviewSnippet,
        EpisodeMetrics,
        HealthResponse,
        IssueType,
        MetricsResponse,
        PythonAction,
        PythonEnvConfig,
        PythonObservation,
        PythonState,
        ReviewComment,
        RewardSummary,
        TaskListResponse,
    )
    from .grading import GradeResult, grade_review
    from .task_bank import get_task_metadata, load_task_bank, load_task_catalog
except ImportError:
    from models import (  # type: ignore
        ActionType,
        CodeReviewSnippet,
        EpisodeMetrics,
        HealthResponse,
        IssueType,
        MetricsResponse,
        PythonAction,
        PythonEnvConfig,
        PythonObservation,
        PythonState,
        ReviewComment,
        RewardSummary,
        TaskListResponse,
    )
    from server.grading import GradeResult, grade_review  # type: ignore
    from server.task_bank import get_task_metadata, load_task_bank, load_task_catalog  # type: ignore


def _utc_now() -> str:
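    """Return the current UTC time as an ISO-8601 string."""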
    return datetime.now(UTC).isoformat()


def _severity_reward(issue_severity: str, bonus_issue: bool) -> float:
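    """Reward for a matched issue: bonus (non-required) issues earn a flat 0.03; otherwise scale by severity."""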
    if bonus_issue:
        return 0.03
    if issue_severity in {"CRITICAL", "HIGH"}:
        return 0.15
    if issue_severity == "MEDIUM":
        return 0.10
    return 0.05


def _false_positive_penalty(action_severity: Optional[str]) -> float:
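    """Penalty for a comment that matched no gold issue; harsher when the agent claimed higher severity."""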
    if action_severity == "CRITICAL":
        return -0.12
    if action_severity == "HIGH":
        return -0.08
    return -0.04


def _line_window_for_task(task_id: str) -> int:
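    """Line-distance tolerance when matching a comment to a gold issue; tasks other than easy/medium require an exact line."""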
if task_id == "task_easy":
return 3
if task_id == "task_medium":
return 5
return 0
@dataclass
class EpisodeRuntime:
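    """Mutable per-episode state: the snippet under review, progress counters, and running grade/reward totals."""
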
    episode_id: str
    task_id: str
    snippet: CodeReviewSnippet
    current_step: int
    max_steps: int
    created_at: str
    review_history: List[ReviewComment] = field(default_factory=list)
    cumulative_reward: float = 0.0
    done: bool = False
    last_feedback: str = ""
    found_issue_ids: set[str] = field(default_factory=set)
    duplicate_comments: int = 0
    context_requests: int = 0
    skipped_clean_lines: int = 0
    skipped_issue_lines: int = 0
    commented_lines: set[int] = field(default_factory=set)
    grade: GradeResult = field(
        default_factory=lambda: GradeResult(
            score=0.0,
            precision=0.0,
            recall=0.0,
            f1=0.0,
            true_positives=0,
            false_positives=0,
            missed_issues=0,
            required_found=0,
            required_total=0,
            bonus_found=0,
            matched_issue_ids=[],
            breakdown={},
        )
    )
    reward_summary: RewardSummary = field(default_factory=RewardSummary)


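# Module-level session state shared across runtime instances and read by the response helpers below.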
_ACTIVE_EPISODE: Optional[EpisodeRuntime] = None
_TASK_CURSOR = -1
_SNIPPET_CURSORS: Dict[str, int] = {task.task_id: -1 for task in load_task_catalog()}


def _set_active_episode(episode: Optional[EpisodeRuntime]) -> None:
    global _ACTIVE_EPISODE
    _ACTIVE_EPISODE = episode


def _current_episode() -> Optional[EpisodeRuntime]:
    return _ACTIVE_EPISODE


def _match_issue_for_action(task_id: str, snippet: CodeReviewSnippet, action: PythonAction, found_issue_ids: set[str]) -> Optional[str]:
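    """Return the ID of the nearest unmatched gold issue of the same type within the task's line window, or None."""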
    if action.action_type != ActionType.ADD_COMMENT or action.line_number is None or action.issue_type is None:
        return None
    max_distance = _line_window_for_task(task_id)
    best_issue_id: Optional[str] = None
    best_distance = max_distance + 1
    for issue in snippet.gold_issues:
        if issue.issue_id in found_issue_ids or issue.issue_type != action.issue_type:
            continue
        distance = abs(action.line_number - issue.line)
        if distance <= max_distance and distance < best_distance:
            best_issue_id = issue.issue_id
            best_distance = distance
    return best_issue_id


def build_metrics(episode: EpisodeRuntime) -> EpisodeMetrics:
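    """Flatten the episode's grade and counters into an EpisodeMetrics payload."""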
    return EpisodeMetrics(
        precision=episode.grade.precision,
        recall=episode.grade.recall,
        f1=episode.grade.f1,
        true_positives=episode.grade.true_positives,
        false_positives=episode.grade.false_positives,
        missed_issues=episode.grade.missed_issues,
        required_found=episode.grade.required_found,
        required_total=episode.grade.required_total,
        bonus_found=episode.grade.bonus_found,
        duplicate_comments=episode.duplicate_comments,
        context_requests=episode.context_requests,
        skipped_clean_lines=episode.skipped_clean_lines,
        skipped_issue_lines=episode.skipped_issue_lines,
        current_score=episode.grade.score,
        cumulative_reward=episode.cumulative_reward,
        breakdown=episode.grade.breakdown,
    )


def build_state(episode: EpisodeRuntime) -> PythonState:
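    """Serialize the episode into a PythonState snapshot."""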
    return PythonState(
        episode_id=episode.episode_id,
        step_count=episode.current_step,
        task_id=episode.task_id,
        difficulty=get_task_metadata(episode.task_id).difficulty,
        snippet_id=episode.snippet.snippet_id,
        current_step=episode.current_step,
        max_steps=episode.max_steps,
        done=episode.done,
        filename=episode.snippet.filename,
        review_history=list(episode.review_history),
        metrics=build_metrics(episode),
        last_feedback=episode.last_feedback,
    )


def get_tasks_response() -> TaskListResponse:
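    """List the tasks available in the task bank."""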
    return TaskListResponse(tasks=load_task_catalog())


def get_metrics_response() -> MetricsResponse:
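    """Metrics for the active episode, or an empty response when none is active."""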
    episode = _current_episode()
    if episode is None:
        return MetricsResponse()
    return MetricsResponse(
        task_id=episode.task_id,
        snippet_id=episode.snippet.snippet_id,
        done=episode.done,
        metrics=build_metrics(episode),
    )


def get_health_response() -> HealthResponse:
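    """Service health plus identifiers for the active episode, if any."""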
    episode = _current_episode()
    return HealthResponse(
        status="ok",
        environment="python_code_review_env",
        task_count=sum(len(items) for items in load_task_bank().values()),
        active_task_id=episode.task_id if episode else None,
        active_snippet_id=episode.snippet.snippet_id if episode else None,
        active_episode_id=episode.episode_id if episode else None,
    )


def get_current_state() -> PythonState:
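    """State snapshot of the active episode, or an empty PythonState."""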
    episode = _current_episode()
    return PythonState() if episode is None else build_state(episode)


class PythonReviewRuntime(Environment[PythonAction, PythonObservation, PythonState]):
"""Deterministic code-review benchmark environment with dense rewards."""
SUPPORTS_CONCURRENT_SESSIONS: bool = True
    def __init__(self, config: Optional[PythonEnvConfig] = None):
        super().__init__()
        self._config = config or PythonEnvConfig()
        self._episode: Optional[EpisodeRuntime] = None

    def _restore_episode(self) -> Optional[EpisodeRuntime]:
        """Return this instance's episode, falling back to the module-level active episode."""
        if self._episode is not None:
            return self._episode
        self._episode = _current_episode()
        return self._episode

    def _select_task_id(self, seed: Optional[int]) -> str:
        """Pick a task: seeded random choice when a seed is given, otherwise round-robin (or the first task)."""
        task_order = list(self._config.task_order)
        if seed is not None:
            return random.Random(seed).choice(task_order)
        if not self._config.rotate_tasks:
            return task_order[0]
        global _TASK_CURSOR
        _TASK_CURSOR = (_TASK_CURSOR + 1) % len(task_order)
        return task_order[_TASK_CURSOR]

    def _select_snippet(self, task_id: str, seed: Optional[int]) -> CodeReviewSnippet:
        """Pick a snippet for the task: seeded random choice when a seed is given, otherwise round-robin."""
        snippets = load_task_bank()[task_id]
        if seed is not None:
            return random.Random(seed).choice(snippets)
        _SNIPPET_CURSORS[task_id] = (_SNIPPET_CURSORS[task_id] + 1) % len(snippets)
        return snippets[_SNIPPET_CURSORS[task_id]]

    def _terminal_reward(self, episode: EpisodeRuntime, action_type: ActionType) -> float:
        """Bonus applied when the episode ends: completeness, precision, correct verdict, and step efficiency."""
        reward = 0.0
        if episode.grade.required_found == episode.grade.required_total and episode.grade.required_total:
            reward += 0.20
        if episode.grade.false_positives == 0:
            reward += 0.10
        if action_type == ActionType.REQUEST_CHANGES and episode.snippet.must_reject:
            reward += 0.10
        if action_type == ActionType.APPROVE and episode.snippet.must_approve:
            reward += 0.15
        if action_type == ActionType.APPROVE and episode.snippet.must_reject:
            reward -= 0.25
        # Efficiency bonus: up to 0.05, shrinking linearly with the number of steps used.
        reward += 0.05 * (1 - (episode.current_step / max(episode.max_steps, 1)))
        return reward

    def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None, task_id: Optional[str] = None, **kwargs) -> PythonObservation:
        del kwargs
        selected_task_id = task_id or self._select_task_id(seed)
        snippet = self._select_snippet(selected_task_id, seed)
        metadata = get_task_metadata(selected_task_id)
        episode = EpisodeRuntime(
            episode_id=episode_id or str(uuid4()),
            task_id=selected_task_id,
            snippet=snippet,
            current_step=0,
            max_steps=min(metadata.max_steps, self._config.max_steps_per_task),
            created_at=_utc_now(),
        )
        episode.grade = grade_review(selected_task_id, snippet, episode.review_history, episode.duplicate_comments)
        episode.last_feedback = f"Loaded {metadata.name}. Review the code and submit comments line by line."
        self._episode = episode
        _set_active_episode(episode)
        return self._build_observation(episode, 0.0)

    def step(self, action: PythonAction, timeout_s: Optional[float] = None, **kwargs) -> PythonObservation:
        del timeout_s, kwargs
        episode = self._restore_episode()
        if episode is None:
            return self.reset()
        if episode.done:
            return self._build_observation(episode, 0.0)
        episode.current_step += 1
        step_reward = 0.0
        breakdown: Dict[str, float] = {}
        feedback = ""
        matched_issue_ids: List[str] = []
        if action.action_type == ActionType.ADD_COMMENT:
            # Commenting a line twice is penalized on top of whatever the comment scores.
            if action.line_number in episode.commented_lines:
                episode.duplicate_comments += 1
                step_reward -= 0.08
                breakdown["duplicate_comment_penalty"] = -0.08
            issue_id = _match_issue_for_action(episode.task_id, episode.snippet, action, episode.found_issue_ids)
            if issue_id is not None:
                issue = next(item for item in episode.snippet.gold_issues if item.issue_id == issue_id)
                hit_reward = _severity_reward(issue.severity.value, not issue.required)
                step_reward += hit_reward
                breakdown["issue_hit"] = hit_reward
                episode.found_issue_ids.add(issue_id)
                matched_issue_ids = [issue_id]
                feedback = f"Recorded issue on line {action.line_number}."
            else:
                penalty = _false_positive_penalty(action.severity.value if action.severity else None)
                step_reward += penalty
                breakdown["false_positive_penalty"] = penalty
                feedback = "Comment did not match a benchmark issue."
            if action.line_number is not None:
                episode.commented_lines.add(action.line_number)
        elif action.action_type == ActionType.SKIP_LINE:
            assert action.line_number is not None
            required_issue_on_line = any(
                issue.required and issue.line == action.line_number
                for issue in episode.snippet.gold_issues
            )
            if required_issue_on_line:
                step_reward -= 0.10
                episode.skipped_issue_lines += 1
                breakdown["skip_issue_penalty"] = -0.10
                feedback = "Skipped a line with a required issue."
            else:
                step_reward += 0.02
                episode.skipped_clean_lines += 1
                breakdown["skip_clean_reward"] = 0.02
                feedback = "Marked the line as clean."
        elif action.action_type == ActionType.ASK_CONTEXT:
            episode.context_requests += 1
            step_reward -= 0.03
            breakdown["ask_context_penalty"] = -0.03
            feedback = episode.snippet.context or episode.snippet.diff or "No additional context available."
        elif action.action_type in {ActionType.APPROVE, ActionType.REQUEST_CHANGES}:
            feedback = "Final review decision recorded."
        episode.review_history.append(
            ReviewComment(
                step_index=episode.current_step,
                action_type=action.action_type,
                line_number=action.line_number,
                issue_type=action.issue_type,
                severity=action.severity,
                comment=action.comment,
                suggestion=action.suggestion,
                question=action.question,
                matched_issue_ids=matched_issue_ids,
                reward_delta=step_reward,
            )
        )
        if len(episode.review_history) > self._config.max_history_entries:
            episode.review_history = episode.review_history[-self._config.max_history_entries :]
        # APPROVE/REQUEST_CHANGES end the episode; exhausting the step budget also terminates it.
        done = action.action_type in {ActionType.APPROVE, ActionType.REQUEST_CHANGES}
        if episode.current_step >= episode.max_steps:
            done = True
            feedback = f"{feedback} Maximum steps reached.".strip()
        # Re-grade the full review after every action so metrics stay current.
        episode.grade = grade_review(episode.task_id, episode.snippet, episode.review_history, episode.duplicate_comments)
        if done:
            terminal_bonus = self._terminal_reward(episode, action.action_type)
            step_reward += terminal_bonus
            breakdown["terminal_bonus"] = terminal_bonus
            episode.done = True
            feedback = f"{feedback} Final score {episode.grade.score:.2f}.".strip()
        episode.cumulative_reward += step_reward
        episode.reward_summary = RewardSummary(
            step_reward=step_reward,
            cumulative_reward=episode.cumulative_reward,
            breakdown=breakdown,
            false_positives=episode.grade.false_positives,
            true_positives=episode.grade.true_positives,
            missed_issues=episode.grade.missed_issues,
        )
        episode.last_feedback = feedback or "Step complete."
        self._episode = episode
        _set_active_episode(episode)
        return self._build_observation(episode, step_reward)

    def _build_observation(self, episode: EpisodeRuntime, reward: float) -> PythonObservation:
        lines = episode.snippet.code.splitlines()
        return PythonObservation(
            snippet_id=episode.snippet.snippet_id,
            code=episode.snippet.code,
            filename=episode.snippet.filename,
            language="python",
            context=episode.snippet.context,
            diff=episode.snippet.diff,
            line_count=len(lines),
            current_step=episode.current_step,
            max_steps=episode.max_steps,
            task_id=episode.task_id,
            review_history=list(episode.review_history),
            lines=lines,
            reward_summary=episode.reward_summary,
            metrics=build_metrics(episode),
            feedback=episode.last_feedback,
            done=episode.done,
            reward=reward,
            metadata={
                "episode_id": episode.episode_id,
                "created_at": episode.created_at,
                "updated_at": _utc_now(),
                "task_name": get_task_metadata(episode.task_id).name,
            },
        )

    @property
    def state(self) -> PythonState:
        episode = self._restore_episode()
        return PythonState() if episode is None else build_state(episode)


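# A minimal smoke-test sketch (not part of the benchmark API): drives one episode with a
# seeded reset and an immediate APPROVE decision. It assumes PythonAction's optional
# fields (line_number, issue_type, severity, ...) default to None when omitted.
if __name__ == "__main__":
    runtime = PythonReviewRuntime()
    obs = runtime.reset(seed=0)
    print(f"Reviewing {obs.filename} ({obs.line_count} lines), task {obs.task_id}")
    obs = runtime.step(PythonAction(action_type=ActionType.APPROVE))
    print(f"done={obs.done} reward={obs.reward:.2f} feedback={obs.feedback!r}")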