"""Shared models for the Python code-review OpenEnv benchmark."""
from __future__ import annotations
from enum import Enum
from typing import Dict, List, Optional
from pydantic import BaseModel, Field, model_validator
from openenv.core.env_server.types import Action, Observation, State
class Difficulty(str, Enum):
EASY = "easy"
MEDIUM = "medium"
HARD = "hard"
class ActionType(str, Enum):
ADD_COMMENT = "ADD_COMMENT"
APPROVE = "APPROVE"
REQUEST_CHANGES = "REQUEST_CHANGES"
ASK_CONTEXT = "ASK_CONTEXT"
SKIP_LINE = "SKIP_LINE"
class IssueType(str, Enum):
STYLE = "STYLE"
LOGIC = "LOGIC"
SECURITY = "SECURITY"
PERFORMANCE = "PERFORMANCE"
DOCS = "DOCS"
class Severity(str, Enum):
LOW = "LOW"
MEDIUM = "MEDIUM"
HIGH = "HIGH"
CRITICAL = "CRITICAL"
class GoldIssue(BaseModel):
"""Hidden benchmark annotation for one issue in a snippet."""
issue_id: str
line: int = Field(..., ge=1)
issue_type: IssueType
severity: Severity
description: str
required: bool = True
explanation_keywords: List[str] = Field(default_factory=list)
fix_keywords: List[str] = Field(default_factory=list)
owasp_category: Optional[str] = None
owasp_keywords: List[str] = Field(default_factory=list)
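
# Illustrative only; not loaded by the environment. A GoldIssue annotation for
# a hypothetical SQL-injection finding (every value below is invented):
#
#     GoldIssue(
#         issue_id="sec-001",
#         line=12,
#         issue_type=IssueType.SECURITY,
#         severity=Severity.CRITICAL,
#         description="SQL query built via string interpolation",
#         explanation_keywords=["sql", "injection", "interpolation"],
#         fix_keywords=["parameterized", "placeholder"],
#         owasp_category="A03:2021-Injection",
#     )
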
class ReviewComment(BaseModel):
"""Stored review action visible to the agent in `review_history`."""
step_index: int = Field(..., ge=1)
action_type: ActionType
line_number: Optional[int] = Field(default=None, ge=1)
issue_type: Optional[IssueType] = None
severity: Optional[Severity] = None
comment: Optional[str] = None
suggestion: Optional[str] = None
question: Optional[str] = None
matched_issue_ids: List[str] = Field(default_factory=list)
    reward_delta: float = 0.0


class CodeReviewSnippet(BaseModel):
"""Benchmark sample loaded from JSON."""
snippet_id: str
filename: str
code: str
context: Optional[str] = None
diff: Optional[str] = None
gold_issues: List[GoldIssue]
must_approve: bool = False
must_reject: bool = True
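
# Sketch of the JSON shape a snippet is loaded from (values invented; the real
# samples ship with the benchmark data files):
#
#     CodeReviewSnippet.model_validate({
#         "snippet_id": "easy-001",
#         "filename": "math_utils.py",
#         "code": "def add(a, b):\n    return a - b\n",
#         "gold_issues": [...],  # GoldIssue-shaped dicts, see above
#         "must_reject": True,
#     })
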
class TaskMetadata(BaseModel):
"""Visible task-family metadata."""
task_id: str
name: str
difficulty: Difficulty
description: str
snippet_count: int = Field(..., ge=0)
max_steps: int = Field(..., ge=1)
min_score: float = Field(default=0.0, ge=0.0, le=1.0)
    max_score: float = Field(default=1.0, ge=0.0, le=1.0)


class ReviewFinding(BaseModel):
"""Compatibility shim for earlier template-derived environment code."""
title: str = ""
line: Optional[int] = Field(default=None, ge=1)
category: str = "bug"
severity: str = "warning"
rationale: str = ""
recommendation: Optional[str] = None
    rule_id: Optional[str] = None


class TaskDescriptor(BaseModel):
"""Compatibility shim for earlier template-derived environment code."""
task_id: str
difficulty: str
title: str
objective: str
code: str
max_steps: int = Field(..., ge=1)
    success_threshold: float = Field(default=0.0, ge=0.0, le=1.0)


class TaskEvaluation(BaseModel):
"""Compatibility shim for earlier template-derived environment code."""
matched_reference_ids: List[str] = Field(default_factory=list)
matched_findings: int = 0
total_findings: int = 0
false_positives: int = 0
duplicate_findings: int = 0
weighted_recall: float = 0.0
patch_score: float = 0.0
score: float = 0.0
    passed: bool = False


class PythonEnvConfig(BaseModel):
"""Environment configuration used by the benchmark runtime."""
task_order: List[str] = Field(
default_factory=lambda: ["task_easy", "task_medium", "task_hard"]
)
max_steps_per_task: int = Field(default=25, ge=1, le=100)
max_history_entries: int = Field(default=200, ge=1, le=1000)
rotate_tasks: bool = True
# Evaluation parameters
patch_bonus_multiplier: float = 0.2
false_positive_penalty: float = 0.05
duplicate_penalty: float = 0.02
hint_penalty: float = 0.1
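
# Illustrative override of the defaults above (values invented): shorter
# episodes on a single task family, with a harsher false-positive penalty.
#
#     config = PythonEnvConfig(
#         task_order=["task_easy"],
#         max_steps_per_task=10,
#         false_positive_penalty=0.1,
#     )
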
class EpisodeMetrics(BaseModel):
"""Current episode metrics for UI, evaluation, and RL logging."""
precision: float = Field(default=0.0, ge=0.0, le=1.0)
recall: float = Field(default=0.0, ge=0.0, le=1.0)
f1: float = Field(default=0.0, ge=0.0, le=1.0)
true_positives: int = Field(default=0, ge=0)
false_positives: int = Field(default=0, ge=0)
missed_issues: int = Field(default=0, ge=0)
required_found: int = Field(default=0, ge=0)
required_total: int = Field(default=0, ge=0)
bonus_found: int = Field(default=0, ge=0)
duplicate_comments: int = Field(default=0, ge=0)
context_requests: int = Field(default=0, ge=0)
skipped_clean_lines: int = Field(default=0, ge=0)
skipped_issue_lines: int = Field(default=0, ge=0)
current_score: float = Field(default=0.0, ge=0.0, le=1.0)
cumulative_reward: float = 0.0
breakdown: Dict[str, float] = Field(default_factory=dict)
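
# How the headline numbers relate, under the standard definitions (assuming
# the server treats `missed_issues` as the false-negative count):
#
#     precision = true_positives / (true_positives + false_positives)
#     recall    = true_positives / (true_positives + missed_issues)
#     f1        = 2 * precision * recall / (precision + recall)
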
class RewardSummary(BaseModel):
"""Reward details from the most recent step."""
step_reward: float = 0.0
cumulative_reward: float = 0.0
breakdown: Dict[str, float] = Field(default_factory=dict)
false_positives: int = Field(default=0, ge=0)
true_positives: int = Field(default=0, ge=0)
    missed_issues: int = Field(default=0, ge=0)


class PythonReviewAction(Action):
"""Structured review action emitted by a model or trainer."""
    # Legacy template fields: hidden from the UI schema by the
    # `model_json_schema` override below, but still accepted for compatibility.
    operation: str = Field(default="ADD_COMMENT", description="The operation to perform.")
    findings: List[ReviewFinding] = Field(default_factory=list, description="The findings list.")
    patched_code: Optional[str] = Field(default=None, description="The fixed source code.")
action_type: ActionType = Field(
default=ActionType.ADD_COMMENT,
description="Choose the review action: comment on a line, skip a clean line, ask for context, approve, or request changes.",
)
line_number: Optional[int] = Field(
default=None,
ge=1,
description="Required for ADD_COMMENT and SKIP_LINE. Enter the code line number you are acting on.",
)
issue_type: Optional[IssueType] = Field(
default=None,
description="Required for ADD_COMMENT. Pick the issue category for the selected line.",
)
severity: Optional[Severity] = Field(
default=None,
description="Required for ADD_COMMENT. Pick how serious the issue is.",
)
comment: Optional[str] = Field(
default=None,
description="Required for ADD_COMMENT. Also used for ASK_CONTEXT if you want to ask a question in plain text.",
)
suggestion: Optional[str] = None
    question: Optional[str] = None

    @classmethod
def model_json_schema(cls, *args, **kwargs):
"""Trim legacy fields from the generated UI schema."""
schema = super().model_json_schema(*args, **kwargs)
properties = schema.get("properties", {})
visible_fields = {
"action_type",
"line_number",
"issue_type",
"severity",
"comment",
}
schema["properties"] = {
name: value for name, value in properties.items() if name in visible_fields
}
schema["required"] = [
name for name in schema.get("required", []) if name in visible_fields
]
return schema
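
    # For example, the override above reduces the generated schema to the five
    # visible fields:
    #
    #     sorted(PythonReviewAction.model_json_schema()["properties"])
    #     # ['action_type', 'comment', 'issue_type', 'line_number', 'severity']
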
@model_validator(mode="after")
def validate_action_shape(self) -> "PythonReviewAction":
"""Require the right fields for each action type."""
# Legacy template actions still use string operations like `submit_findings`.
# Benchmark actions should validate against `action_type`.
if self.operation != "ADD_COMMENT":
return self
if self.action_type == ActionType.ADD_COMMENT:
missing = []
if self.line_number is None:
missing.append("line_number")
if self.issue_type is None:
missing.append("issue_type")
if self.severity is None:
missing.append("severity")
if not (self.comment or "").strip():
missing.append("comment")
if missing:
raise ValueError("ADD_COMMENT requires: " + ", ".join(missing))
elif self.action_type == ActionType.SKIP_LINE:
if self.line_number is None:
raise ValueError("SKIP_LINE requires line_number")
elif self.action_type == ActionType.ASK_CONTEXT:
if not (self.question or self.comment or "").strip():
raise ValueError("ASK_CONTEXT requires question or comment")
elif self.action_type in {ActionType.APPROVE, ActionType.REQUEST_CHANGES}:
noisy_fields = {
"line_number": self.line_number,
"issue_type": self.issue_type,
"severity": self.severity,
"comment": self.comment,
"suggestion": self.suggestion,
"question": self.question,
}
populated = [
name for name, value in noisy_fields.items() if value not in (None, "")
]
if populated:
raise ValueError(
f"{self.action_type.value} does not accept extra fields: {', '.join(populated)}"
)
return self
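
# Examples of how the validator above behaves (illustrative values):
#
#     PythonReviewAction(
#         action_type=ActionType.ADD_COMMENT,
#         line_number=3,
#         issue_type=IssueType.LOGIC,
#         severity=Severity.HIGH,
#         comment="Subtracts instead of adding.",
#     )  # validates
#
#     PythonReviewAction(action_type=ActionType.ADD_COMMENT)
#     # ValueError: ADD_COMMENT requires: line_number, issue_type, severity, comment
#
#     PythonReviewAction(action_type=ActionType.APPROVE, comment="LGTM")
#     # ValueError: APPROVE does not accept extra fields: comment
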
class PythonReviewObservation(Observation):
"""Observation returned by reset/step, including trainer-visible metrics."""
snippet_id: str = ""
code: str = ""
filename: str = ""
language: str = "python"
context: Optional[str] = None
diff: Optional[str] = None
line_count: int = Field(default=0, ge=0)
current_step: int = Field(default=0, ge=0)
max_steps: int = Field(default=1, ge=1)
task_id: str = ""
review_history: List[ReviewComment] = Field(default_factory=list)
lines: List[str] = Field(default_factory=list)
reward_summary: RewardSummary = Field(default_factory=RewardSummary)
metrics: EpisodeMetrics = Field(default_factory=EpisodeMetrics)
feedback: str = ""
review_time_ms: float = 0.0
# Template compatibility
task: Optional[TaskDescriptor] = None
instructions: str = ""
submitted_findings: List[ReviewFinding] = Field(default_factory=list)
hints_used: int = 0
attempts_remaining: int = 0
evaluation: Optional[TaskEvaluation] = None
score: float = 0.0
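
# Typical trainer-side read of an observation (a sketch; the field names come
# from the models above, while the reward shaping itself happens server-side):
#
#     obs = client.step(action)  # hypothetical OpenEnv client call
#     shaped = obs.reward_summary.step_reward
#     solved = obs.metrics.f1 >= 0.8 and obs.metrics.false_positives == 0
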
class PythonReviewState(State):
"""Full server-side state exposed by `/state`."""
task_id: Optional[str] = None
difficulty: Optional[Difficulty] = None
snippet_id: Optional[str] = None
current_step: int = Field(default=0, ge=0)
max_steps: int = Field(default=0, ge=0)
done: bool = False
filename: Optional[str] = None
review_history: List[ReviewComment] = Field(default_factory=list)
metrics: EpisodeMetrics = Field(default_factory=EpisodeMetrics)
    last_feedback: str = ""


class TaskListResponse(BaseModel):
    """Response payload listing the available task families."""

    tasks: List[TaskMetadata] = Field(default_factory=list)


class MetricsResponse(BaseModel):
    """Response payload exposing current-episode metrics."""

    task_id: Optional[str] = None
    snippet_id: Optional[str] = None
    done: bool = False
    metrics: EpisodeMetrics = Field(default_factory=EpisodeMetrics)


class HealthResponse(BaseModel):
    """Health-check response for the environment server."""

    status: str = "ok"
    environment: str = "python_code_review_env"
    task_count: int = Field(default=0, ge=0)
    active_task_id: Optional[str] = None
    active_snippet_id: Optional[str] = None
    active_episode_id: Optional[str] = None


# Backwards-compatible aliases for earlier template-derived names.
PythonAction = PythonReviewAction
PythonObservation = PythonReviewObservation
PythonState = PythonReviewState
CodeReviewAction = PythonReviewAction
CodeReviewObservation = PythonReviewObservation
CodeReviewConfig = PythonEnvConfig
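
# Round-trip sketch for transport, using Pydantic v2's JSON helpers (assuming
# the usual OpenEnv JSON exchange of actions and observations):
#
#     payload = PythonReviewAction(action_type=ActionType.APPROVE).model_dump_json()
#     restored = PythonReviewAction.model_validate_json(payload)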