Spaces:
Sleeping
Sleeping
improve: 20 tasks, richer keywords, enhanced reward/grader, bigram matching, compelling README
b83c8ad | from __future__ import annotations | |
| import json | |
| from pathlib import Path | |
| from sql_query_reviewer.models import ( | |
| IdentifiedIssue, | |
| SQLReviewAction, | |
| SQLReviewObservation, | |
| SQLReviewState, | |
| StepResult, | |
| TaskRecord, | |
| ) | |
| from server.grader import grade_episode, match_issue, validate_fix | |
| from server.reward import compute_reward | |
class SQLReviewEnvironment:
    """Step-based environment in which an agent reviews one SQL query task.

    Tasks are loaded once, at construction, from ``*_tasks.json`` files.  An
    episode is started with :meth:`reset`, advanced with :meth:`step`, and
    ends when the agent approves the query or the task's ``max_steps`` is
    reached.
    """

    def __init__(self, task_directory: Path | None = None) -> None:
        """Load all tasks and leave the environment without an active episode.

        Args:
            task_directory: Directory containing the ``*_tasks.json`` files.
                A falsy value selects ``<this module's directory>/../tasks``.

        Raises:
            RuntimeError: If no task could be loaded from the directory.
        """
        self.task_directory = task_directory or Path(__file__).resolve().parent.parent / "tasks"
        self.tasks = self._load_tasks()
        # Sorted ids give the round-robin in ``_next_task_id`` a stable order.
        self.task_order = sorted(self.tasks)
        self.current_task: TaskRecord | None = None
        self.current_state: SQLReviewState | None = None
        self._reset_index = 0

    def available_task_ids(self) -> list[str]:
        """Return a fresh list of the known task ids in sorted order."""
        return list(self.task_order)

    def reset(self, task_id: str | None = None) -> StepResult:
        """Start a new episode and return its initial result.

        Args:
            task_id: Id of the task to load.  A falsy value (``None`` or an
                empty string) selects the next task in round-robin order.

        Returns:
            A ``StepResult`` carrying the first observation, a reward of
            ``0.0``, ``done=False`` and an empty ``info`` dict.

        Raises:
            ValueError: If the selected id is not a loaded task.
        """
        selected_task_id = task_id or self._next_task_id()
        if selected_task_id not in self.tasks:
            raise ValueError(f"Unknown task_id: {selected_task_id}")

        self.current_task = self.tasks[selected_task_id]
        self.current_state = SQLReviewState(task_id=self.current_task.task_id)
        observation = self._build_observation(
            feedback="Review this SQL query and identify correctness, performance, or security issues."
        )
        return StepResult(observation=observation, reward=0.0, done=False, info={})

    def step(self, action: SQLReviewAction) -> StepResult:
        """Apply one agent action to the active episode.

        The action is dispatched on ``action.action_type``:
        ``"identify_issue"``, ``"suggest_fix"`` and ``"approve"`` have
        dedicated handlers; any other value is answered with the task's
        schema context.  After the handler runs, the reward is added to the
        episode total, the step limit is enforced, and a finished episode is
        graded.

        Args:
            action: The action chosen by the agent.

        Returns:
            A ``StepResult`` with the new observation, this step's reward,
            the ``done`` flag and an ``info`` dict (which includes
            ``"final_score"`` once the episode is finished).

        Raises:
            RuntimeError: If no episode is active or the episode is already
                finished.
        """
        task = self._require_task()
        state = self._require_state()
        if state.done:
            raise RuntimeError("Episode already finished. Call reset() before taking more steps.")

        # The step is counted before the handler runs, so it is consumed even
        # if the handler raises.
        state.step_count += 1

        if action.action_type == "identify_issue":
            reward, feedback, info = self._handle_identify_issue(action, task, state)
        elif action.action_type == "suggest_fix":
            reward, feedback, info = self._handle_suggest_fix(action, task, state)
        elif action.action_type == "approve":
            reward, feedback, info = self._handle_approve(action, task, state)
        else:
            reward, feedback, info = self._handle_context_request(action, task)

        state.total_reward += reward

        # Running out of steps ends the episode unless it already ended.
        if state.step_count >= task.max_steps and not state.done:
            state.done = True
            feedback = f"{feedback} Maximum step count reached."

        if state.done:
            state.final_score = grade_episode(
                found_issue_ids={issue.issue_id for issue in state.issues_identified},
                ground_truth_issues=task.ground_truth_issues,
                total_steps=state.step_count,
                max_steps=task.max_steps,
                false_positive_count=state.false_positive_count,
            )
            info["final_score"] = state.final_score

        observation = self._build_observation(feedback=feedback)
        return StepResult(observation=observation, reward=reward, done=state.done, info=info)

    def _handle_identify_issue(
        self,
        action: SQLReviewAction,
        task: TaskRecord,
        state: SQLReviewState,
    ) -> tuple[float, str, dict[str, object]]:
        """Score an ``identify_issue`` action as a duplicate, a miss or a new match."""
        found_ids = {issue.issue_id for issue in state.issues_identified}

        # First pass with an empty set as the third argument -- presumably the
        # ids to exclude from matching (TODO confirm in server.grader) -- so a
        # report that best matches an already-found issue is seen as a repeat.
        repeated_issue, repeated_score = match_issue(action, task.ground_truth_issues, set())
        if repeated_issue is not None and repeated_issue.id in found_ids:
            reward = compute_reward(action, repeated_issue, duplicate_issue=True)
            feedback = f"Issue '{repeated_issue.id}' was already identified earlier in the episode."
            info: dict[str, object] = {
                "match_score": round(repeated_score, 3),
                "match_type": "duplicate",
                "issue_id": repeated_issue.id,
            }
            return reward, feedback, info

        # Second pass, this time handing over the ids found so far.
        matched_issue, score = match_issue(action, task.ground_truth_issues, found_ids)
        if matched_issue is None:
            state.false_positive_count += 1
            reward = compute_reward(action, None)
            return (
                reward,
                "No matching issue found for that description.",
                {"match_score": round(score, 3), "match_type": "none"},
            )

        fix_valid = validate_fix(action.suggested_fix, matched_issue)
        state.issues_identified.append(
            IdentifiedIssue(
                issue_id=matched_issue.id,
                category=matched_issue.category,
                description=matched_issue.description,
            )
        )
        # The count passed to the reward already includes the issue just added.
        reward = compute_reward(
            action,
            matched_issue,
            fix_valid=fix_valid,
            issues_found_count=len(state.issues_identified),
            schema_available=bool(task.schema_info),
        )
        remaining = len(task.ground_truth_issues) - len(state.issues_identified)
        feedback = f"Matched {matched_issue.category} issue '{matched_issue.id}'. {remaining} issue(s) remaining."
        info = {
            "match_score": round(score, 3),
            "match_type": "fuzzy",
            "severity": matched_issue.severity,
            "issue_id": matched_issue.id,
            "all_issues_found": remaining == 0,
        }
        if fix_valid and action.suggested_fix:
            state.fixes_suggested.append(action.suggested_fix)
        return reward, feedback, info

    def _handle_suggest_fix(
        self,
        action: SQLReviewAction,
        task: TaskRecord,
        state: SQLReviewState,
    ) -> tuple[float, str, dict[str, object]]:
        """Score a ``suggest_fix`` action against the most recently identified issue."""
        if not state.issues_identified:
            # Nothing has been identified yet, so there is no issue to fix.
            reward = compute_reward(action, None, has_previous_issue=False)
            return reward, "Identify an issue before suggesting a fix.", {}

        last_issue_id = state.issues_identified[-1].issue_id
        last_issue = next(issue for issue in task.ground_truth_issues if issue.id == last_issue_id)
        fix_valid = validate_fix(action.suggested_fix, last_issue)
        reward = compute_reward(action, last_issue, fix_valid=fix_valid, has_previous_issue=True)
        if fix_valid:
            feedback = "Fix accepted for the last identified issue."
        else:
            feedback = "Suggested fix did not match the expected remediation."
        info: dict[str, object] = {"issue_id": last_issue.id, "fix_valid": fix_valid}
        if fix_valid and action.suggested_fix:
            state.fixes_suggested.append(action.suggested_fix)
        return reward, feedback, info

    def _handle_approve(
        self,
        action: SQLReviewAction,
        task: TaskRecord,
        state: SQLReviewState,
    ) -> tuple[float, str, dict[str, object]]:
        """Score an ``approve`` action and mark the episode approved and done."""
        found_ids = {issue.issue_id for issue in state.issues_identified}
        remaining_unfound = len(task.ground_truth_issues) - len(found_ids)
        reward = compute_reward(action, None, remaining_unfound=remaining_unfound)
        state.approved = True
        state.done = True
        if remaining_unfound == 0:
            feedback = "Query approved with full issue coverage."
        else:
            feedback = f"Query approved too early. {remaining_unfound} issue(s) were missed."
        return reward, feedback, {"remaining_unfound": remaining_unfound}

    def _handle_context_request(
        self,
        action: SQLReviewAction,
        task: TaskRecord,
    ) -> tuple[float, str, dict[str, object]]:
        """Answer any other action type with the task's schema context."""
        feedback = self._schema_feedback(task)
        reward = compute_reward(action, None, schema_available=bool(task.schema_info))
        return reward, feedback, {"context_shared": bool(task.schema_info)}

    def state(self) -> SQLReviewState:
        """Return a deep copy of the active episode state.

        Raises:
            RuntimeError: If no episode is active.
        """
        return self._require_state().model_copy(deep=True)

    def _load_tasks(self) -> dict[str, TaskRecord]:
        """Read every ``*_tasks.json`` file in the task directory.

        Files are visited in sorted order and each is parsed as a JSON
        iterable of raw task records.  A later record replaces an earlier one
        that has the same ``task_id``.

        Raises:
            RuntimeError: If no task was loaded at all.
        """
        tasks: dict[str, TaskRecord] = {}
        for file_path in sorted(self.task_directory.glob("*_tasks.json")):
            with file_path.open("r", encoding="utf-8") as handle:
                for raw_task in json.load(handle):
                    task = TaskRecord.model_validate(raw_task)
                    tasks[task.task_id] = task
        if not tasks:
            raise RuntimeError(f"No task files found in {self.task_directory}")
        return tasks

    def _next_task_id(self) -> str:
        """Return the next task id in sorted order, wrapping around at the end."""
        task_id = self.task_order[self._reset_index % len(self.task_order)]
        self._reset_index += 1
        return task_id

    def _build_observation(self, feedback: str) -> SQLReviewObservation:
        """Assemble the observation for the active task and state.

        Args:
            feedback: Message describing the outcome of the latest action.

        Raises:
            RuntimeError: If no episode is active.
        """
        task = self._require_task()
        state = self._require_state()
        # Clamp at zero so the remaining-action count is never negative.
        remaining_actions = max(task.max_steps - state.step_count, 0)
        return SQLReviewObservation(
            query=task.query,
            schema_info=task.schema_info,
            context=task.context,
            issues_found_so_far=state.issues_identified,
            remaining_actions=remaining_actions,
            difficulty=task.difficulty,
            feedback=feedback,
        )

    def _schema_feedback(self, task: TaskRecord) -> str:
        """Return a sentence naming the entries of ``task.schema_info`` in sorted order.

        A fixed "not available" sentence is returned when ``schema_info`` is
        empty.
        """
        if not task.schema_info:
            return "No additional schema context is available for this task."
        tables = ", ".join(sorted(task.schema_info))
        return f"Schema context available for: {tables}."

    def _require_task(self) -> TaskRecord:
        """Return the active task, raising ``RuntimeError`` if there is none."""
        if self.current_task is None:
            raise RuntimeError("Environment has no active task. Call reset() first.")
        return self.current_task

    def _require_state(self) -> SQLReviewState:
        """Return the active episode state, raising ``RuntimeError`` if there is none."""
        if self.current_state is None:
            raise RuntimeError("Environment has no active state. Call reset() first.")
        return self.current_state