# sql-query-reviewer/server/environment.py
# hellinferno's picture
# improve: 20 tasks, richer keywords, enhanced reward/grader, bigram matching, compelling README
# b83c8ad
from __future__ import annotations
import json
from pathlib import Path
from sql_query_reviewer.models import (
IdentifiedIssue,
SQLReviewAction,
SQLReviewObservation,
SQLReviewState,
StepResult,
TaskRecord,
)
from server.grader import grade_episode, match_issue, validate_fix
from server.reward import compute_reward
class SQLReviewEnvironment:
    """Episode-based environment for reviewing SQL queries.

    Tasks are loaded once from ``*_tasks.json`` files in ``task_directory``.
    ``reset()`` selects a task and starts a fresh episode state; ``step()``
    applies a single review action, updates the state, and returns the
    resulting observation, reward, done flag, and diagnostic info.
    """

    def __init__(self, task_directory: Path | None = None) -> None:
        """Load every task found in ``task_directory``.

        Args:
            task_directory: Directory scanned for ``*_tasks.json`` files. When
                falsy, the ``tasks`` directory next to this file's parent
                directory is used.

        Raises:
            RuntimeError: If no task could be loaded from the directory.
        """
        self.task_directory = task_directory or Path(__file__).resolve().parent.parent / "tasks"
        self.tasks = self._load_tasks()
        # Sorted ids give reset() a deterministic round-robin order.
        self.task_order = sorted(self.tasks)
        self.current_task: TaskRecord | None = None
        self.current_state: SQLReviewState | None = None
        self._reset_index = 0

    def available_task_ids(self) -> list[str]:
        """Return a copy of the sorted task ids known to this environment."""
        return list(self.task_order)

    def reset(self, task_id: str | None = None) -> StepResult:
        """Start a new episode and return its initial step result.

        Args:
            task_id: Task to load. When falsy, the next task in the sorted
                round-robin order is used.

        Returns:
            A ``StepResult`` with the initial observation, zero reward,
            ``done=False``, and an empty info dict.

        Raises:
            ValueError: If ``task_id`` is not a known task.
        """
        selected_task_id = task_id or self._next_task_id()
        if selected_task_id not in self.tasks:
            raise ValueError(f"Unknown task_id: {selected_task_id}")

        self.current_task = self.tasks[selected_task_id]
        self.current_state = SQLReviewState(task_id=self.current_task.task_id)
        observation = self._build_observation(
            feedback="Review this SQL query and identify correctness, performance, or security issues."
        )
        return StepResult(observation=observation, reward=0.0, done=False, info={})

    def step(self, action: SQLReviewAction) -> StepResult:
        """Apply one review action to the active episode.

        ``identify_issue``, ``suggest_fix`` and ``approve`` are dispatched to
        their handlers; any other action type is answered with the task's
        schema context. The episode ends on ``approve`` or once the task's
        ``max_steps`` is reached, at which point the final score is computed
        and added to ``info["final_score"]``.

        Args:
            action: The action taken by the reviewing agent.

        Returns:
            A ``StepResult`` carrying the new observation, this step's reward,
            the done flag, and diagnostic info.

        Raises:
            RuntimeError: If ``reset()`` has not been called, or the episode
                is already finished.
        """
        task = self._require_task()
        state = self._require_state()
        if state.done:
            raise RuntimeError("Episode already finished. Call reset() before taking more steps.")

        state.step_count += 1

        info: dict[str, object]
        if action.action_type == "identify_issue":
            reward, feedback, info = self._handle_identify_issue(action, task, state)
        elif action.action_type == "suggest_fix":
            reward, feedback, info = self._handle_suggest_fix(action, task, state)
        elif action.action_type == "approve":
            reward, feedback, info = self._handle_approve(action, task, state)
        else:
            reward, feedback, info = self._handle_context_request(action, task)

        state.total_reward += reward

        # Running out of steps ends the episode even without an approval.
        if state.step_count >= task.max_steps and not state.done:
            state.done = True
            feedback = f"{feedback} Maximum step count reached."

        if state.done:
            state.final_score = grade_episode(
                found_issue_ids={issue.issue_id for issue in state.issues_identified},
                ground_truth_issues=task.ground_truth_issues,
                total_steps=state.step_count,
                max_steps=task.max_steps,
                false_positive_count=state.false_positive_count,
            )
            info["final_score"] = state.final_score

        observation = self._build_observation(feedback=feedback)
        return StepResult(observation=observation, reward=reward, done=state.done, info=info)

    def _handle_identify_issue(
        self,
        action: SQLReviewAction,
        task: TaskRecord,
        state: SQLReviewState,
    ) -> tuple[float, str, dict[str, object]]:
        """Score an ``identify_issue`` action; return ``(reward, feedback, info)``."""
        found_ids = {issue.issue_id for issue in state.issues_identified}

        # First match with an empty set so that a description whose best match
        # was already reported is flagged as a duplicate.
        # NOTE(review): presumably the third argument of match_issue is the set
        # of issue ids to exclude -- confirm against server.grader.
        best_issue, best_score = match_issue(action, task.ground_truth_issues, set())
        if best_issue is not None and best_issue.id in found_ids:
            reward = compute_reward(action, best_issue, duplicate_issue=True)
            feedback = f"Issue '{best_issue.id}' was already identified earlier in the episode."
            return reward, feedback, {
                "match_score": round(best_score, 3),
                "match_type": "duplicate",
                "issue_id": best_issue.id,
            }

        matched_issue, score = match_issue(action, task.ground_truth_issues, found_ids)
        if matched_issue is None:
            state.false_positive_count += 1
            reward = compute_reward(action, None)
            feedback = "No matching issue found for that description."
            return reward, feedback, {"match_score": round(score, 3), "match_type": "none"}

        fix_valid = validate_fix(action.suggested_fix, matched_issue)
        # Record the issue before computing the reward: the reward receives
        # the count of issues found including this one.
        state.issues_identified.append(
            IdentifiedIssue(
                issue_id=matched_issue.id,
                category=matched_issue.category,
                description=matched_issue.description,
            )
        )
        reward = compute_reward(
            action,
            matched_issue,
            fix_valid=fix_valid,
            issues_found_count=len(state.issues_identified),
            schema_available=bool(task.schema_info),
        )
        remaining = len(task.ground_truth_issues) - len(state.issues_identified)
        feedback = f"Matched {matched_issue.category} issue '{matched_issue.id}'. {remaining} issue(s) remaining."
        info: dict[str, object] = {
            "match_score": round(score, 3),
            "match_type": "fuzzy",
            "severity": matched_issue.severity,
            "issue_id": matched_issue.id,
            "all_issues_found": remaining == 0,
        }
        if fix_valid and action.suggested_fix:
            state.fixes_suggested.append(action.suggested_fix)
        return reward, feedback, info

    def _handle_suggest_fix(
        self,
        action: SQLReviewAction,
        task: TaskRecord,
        state: SQLReviewState,
    ) -> tuple[float, str, dict[str, object]]:
        """Score a ``suggest_fix`` action against the most recently identified issue."""
        if not state.issues_identified:
            reward = compute_reward(action, None, has_previous_issue=False)
            return reward, "Identify an issue before suggesting a fix.", {}

        last_issue_id = state.issues_identified[-1].issue_id
        last_issue = next(
            (issue for issue in task.ground_truth_issues if issue.id == last_issue_id),
            None,
        )
        if last_issue is None:
            # Identified issues are only ever recorded from this task's ground
            # truth, so this indicates corrupted state; fail with a clear error
            # instead of leaking a bare StopIteration.
            raise RuntimeError(f"Identified issue '{last_issue_id}' is not part of the active task.")

        fix_valid = validate_fix(action.suggested_fix, last_issue)
        reward = compute_reward(action, last_issue, fix_valid=fix_valid, has_previous_issue=True)
        if fix_valid:
            feedback = "Fix accepted for the last identified issue."
        else:
            feedback = "Suggested fix did not match the expected remediation."
        info: dict[str, object] = {"issue_id": last_issue.id, "fix_valid": fix_valid}
        if fix_valid and action.suggested_fix:
            state.fixes_suggested.append(action.suggested_fix)
        return reward, feedback, info

    def _handle_approve(
        self,
        action: SQLReviewAction,
        task: TaskRecord,
        state: SQLReviewState,
    ) -> tuple[float, str, dict[str, object]]:
        """Score an ``approve`` action and mark the episode as approved and done."""
        found_ids = {issue.issue_id for issue in state.issues_identified}
        remaining_unfound = len(task.ground_truth_issues) - len(found_ids)
        reward = compute_reward(action, None, remaining_unfound=remaining_unfound)
        state.approved = True
        state.done = True
        if remaining_unfound == 0:
            feedback = "Query approved with full issue coverage."
        else:
            feedback = f"Query approved too early. {remaining_unfound} issue(s) were missed."
        return reward, feedback, {"remaining_unfound": remaining_unfound}

    def _handle_context_request(
        self,
        action: SQLReviewAction,
        task: TaskRecord,
    ) -> tuple[float, str, dict[str, object]]:
        """Answer any other action type with the task's schema context."""
        feedback = self._schema_feedback(task)
        reward = compute_reward(action, None, schema_available=bool(task.schema_info))
        return reward, feedback, {"context_shared": bool(task.schema_info)}

    def state(self) -> SQLReviewState:
        """Return a deep copy of the active episode state.

        Raises:
            RuntimeError: If ``reset()`` has not been called.
        """
        return self._require_state().model_copy(deep=True)

    def _load_tasks(self) -> dict[str, TaskRecord]:
        """Read and validate all tasks from ``*_tasks.json`` files.

        Files are processed in sorted order; a task id that appears more than
        once keeps the last definition read.

        Raises:
            RuntimeError: If no task was loaded.
        """
        tasks: dict[str, TaskRecord] = {}
        for file_path in sorted(self.task_directory.glob("*_tasks.json")):
            with file_path.open("r", encoding="utf-8") as handle:
                for raw_task in json.load(handle):
                    task = TaskRecord.model_validate(raw_task)
                    tasks[task.task_id] = task
        if not tasks:
            raise RuntimeError(f"No task files found in {self.task_directory}")
        return tasks

    def _next_task_id(self) -> str:
        """Return the next task id in round-robin order and advance the cursor."""
        task_id = self.task_order[self._reset_index % len(self.task_order)]
        self._reset_index += 1
        return task_id

    def _build_observation(self, feedback: str) -> SQLReviewObservation:
        """Build the observation for the active task and state.

        Args:
            feedback: Message describing the outcome of the latest action.
        """
        task = self._require_task()
        state = self._require_state()
        # Clamp at zero so the count never goes negative.
        remaining_actions = max(task.max_steps - state.step_count, 0)
        return SQLReviewObservation(
            query=task.query,
            schema_info=task.schema_info,
            context=task.context,
            issues_found_so_far=state.issues_identified,
            remaining_actions=remaining_actions,
            difficulty=task.difficulty,
            feedback=feedback,
        )

    def _schema_feedback(self, task: TaskRecord) -> str:
        """Describe which schema entries are available for ``task``."""
        if not task.schema_info:
            return "No additional schema context is available for this task."
        tables = ", ".join(sorted(task.schema_info))
        return f"Schema context available for: {tables}."

    def _require_task(self) -> TaskRecord:
        """Return the active task, raising ``RuntimeError`` if there is none."""
        if self.current_task is None:
            raise RuntimeError("Environment has no active task. Call reset() first.")
        return self.current_task

    def _require_state(self) -> SQLReviewState:
        """Return the active state, raising ``RuntimeError`` if there is none."""
        if self.current_state is None:
            raise RuntimeError("Environment has no active state. Call reset() first.")
        return self.current_state