Spaces:
Build error
Build error
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- DEMO_SCRIPT.md +12 -0
- Dockerfile +30 -0
- README.md +62 -10
- __init__.py +36 -0
- analyzers/__init__.py +13 -0
- analyzers/ds_analyzer.py +56 -0
- analyzers/dsa_analyzer.py +48 -0
- analyzers/ml_analyzer.py +61 -0
- analyzers/web_analyzer.py +50 -0
- api/__init__.py +5 -0
- api/main.py +27 -0
- app/__init__.py +1 -0
- app/examples.py +31 -0
- app/streamlit_app.py +100 -0
- client.py +37 -0
- compat.py +13 -0
- graders/__init__.py +5 -0
- graders/bug_fix.py +102 -0
- graders/dispatch.py +32 -0
- graders/optimization.py +122 -0
- graders/shared.py +431 -0
- graders/syntax.py +95 -0
- inference.py +383 -0
- launch.py +35 -0
- models.py +140 -0
- models/__init__.py +5 -0
- models/pytorch_model.py +149 -0
- openenv.yaml +6 -0
- pyproject.toml +46 -0
- schemas/__init__.py +13 -0
- schemas/request.py +19 -0
- schemas/response.py +70 -0
- server/__init__.py +6 -0
- server/app.py +52 -0
- server/demo.py +441 -0
- server/env.py +396 -0
- server/python_env_environment.py +3 -0
- server/requirements.txt +9 -0
- services/__init__.py +7 -0
- services/analysis_service.py +133 -0
- services/reward_service.py +27 -0
- services/suggestion_service.py +28 -0
- tasks/__init__.py +12 -0
- tasks/catalog.py +324 -0
- tests/test_multi_domain_platform.py +52 -0
- tests/test_scoring.py +42 -0
- tests/test_triage_pipeline.py +46 -0
- triage.py +473 -0
- triage_catalog.py +134 -0
- triage_models.py +79 -0
DEMO_SCRIPT.md
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# TorchReview Copilot Demo Script
|
| 2 |
+
|
| 3 |
+
## 60-90 Second Walkthrough
|
| 4 |
+
|
| 5 |
+
1. Open the Hugging Face Space and introduce TorchReview Copilot as an AI-powered code review and improvement system built with PyTorch.
|
| 6 |
+
2. Point to the problem statement: manual code review is slow, inconsistent, and hard to scale.
|
| 7 |
+
3. Select the `Fix the invoice total syntax regression` example to show the app loading a broken code sample together with the context window.
|
| 8 |
+
4. Highlight the **Live Triage Radar**, the ML quality score, and the RL-ready reward score.
|
| 9 |
+
5. Explain that the PyTorch layer uses CodeBERTa embeddings to compare the input against known code-quality patterns from the OpenEnv task catalog.
|
| 10 |
+
6. Scroll to the three-step improvement plan and call out the progression: syntax and bug fixes, edge cases, then scalability.
|
| 11 |
+
7. Switch to the performance example to show the confidence profile and reward changing for a different class of issue.
|
| 12 |
+
8. Close by noting that OpenEnv still powers deterministic validation under the hood, so the demo remains grounded in measurable task outcomes.
|
Dockerfile
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
ENV PYTHONDONTWRITEBYTECODE=1 \
|
| 4 |
+
PYTHONUNBUFFERED=1 \
|
| 5 |
+
PIP_NO_CACHE_DIR=1
|
| 6 |
+
|
| 7 |
+
WORKDIR /app
|
| 8 |
+
|
| 9 |
+
COPY pyproject.toml README.md DEMO_SCRIPT.md openenv.yaml __init__.py client.py compat.py openenv_models.py inference.py triage.py triage_catalog.py triage_models.py launch.py /app/
|
| 10 |
+
COPY api /app/api
|
| 11 |
+
COPY app /app/app
|
| 12 |
+
COPY analyzers /app/analyzers
|
| 13 |
+
COPY models /app/models
|
| 14 |
+
COPY schemas /app/schemas
|
| 15 |
+
COPY server /app/server
|
| 16 |
+
COPY services /app/services
|
| 17 |
+
COPY tasks /app/tasks
|
| 18 |
+
COPY utils /app/utils
|
| 19 |
+
COPY graders /app/graders
|
| 20 |
+
|
| 21 |
+
RUN python -m pip install --upgrade pip && \
|
| 22 |
+
pip install .
|
| 23 |
+
|
| 24 |
+
EXPOSE 8000
|
| 25 |
+
|
| 26 |
+
HEALTHCHECK --interval=30s --timeout=5s --start-period=10s --retries=3 \
|
| 27 |
+
CMD python -c "import urllib.request; urllib.request.urlopen('http://127.0.0.1:8000', timeout=3).read()"
|
| 28 |
+
|
| 29 |
+
ENV ENABLE_WEB_INTERFACE=true
|
| 30 |
+
CMD ["python", "launch.py"]
|
README.md
CHANGED
|
@@ -1,10 +1,62 @@
|
|
| 1 |
-
---
|
| 2 |
-
title:
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: TorchReview Copilot
|
| 3 |
+
colorFrom: yellow
|
| 4 |
+
colorTo: red
|
| 5 |
+
sdk: docker
|
| 6 |
+
pinned: false
|
| 7 |
+
app_port: 8000
|
| 8 |
+
tags:
|
| 9 |
+
- pytorch
|
| 10 |
+
- gradio
|
| 11 |
+
- fastapi
|
| 12 |
+
- openenv
|
| 13 |
+
- code-review
|
| 14 |
+
base_path: /web
|
| 15 |
+
---
|
| 16 |
+
|
| 17 |
+
# TorchReview Copilot
|
| 18 |
+
|
| 19 |
+
TorchReview Copilot is an **AI-powered code review and improvement system using PyTorch** to analyze Python code, predict quality, generate structured improvement suggestions, and compute an RL-ready reward score.
|
| 20 |
+
|
| 21 |
+
It upgrades the original OpenEnv hackathon environment into a judge-friendly product demo: a polished Hugging Face Space on top, with the deterministic OpenEnv validation engine still preserved underneath.
|
| 22 |
+
|
| 23 |
+
**Live demo:** https://huggingface.co/spaces/uvpatel7271/final-python-env
|
| 24 |
+
**Repository:** https://github.com/uvpatel/final-python-env
|
| 25 |
+
|
| 26 |
+
## Problem Statement
|
| 27 |
+
|
| 28 |
+
Engineering teams lose time during incident response and code review because broken Python snippets often arrive with noisy traces, partial test output, and unclear ownership. Before fixing anything, someone still has to answer:
|
| 29 |
+
|
| 30 |
+
- Is this a syntax issue, a logic bug, or a performance regression?
|
| 31 |
+
- How risky is the repair?
|
| 32 |
+
- What should be checked first?
|
| 33 |
+
|
| 34 |
+
That triage step is repetitive, error-prone, and often slows down the actual fix.
|
| 35 |
+
|
| 36 |
+
## Solution
|
| 37 |
+
|
| 38 |
+
TorchReview Copilot turns code, traceback text, and a short context window into a practical code-review report:
|
| 39 |
+
|
| 40 |
+
- **Issue classification:** syntax, logic, or performance
|
| 41 |
+
- **ML quality score:** predicted code quality from PyTorch embeddings
|
| 42 |
+
- **Reward score:** RL-ready score from model quality, lint quality, and complexity penalty
|
| 43 |
+
- **Live Triage Radar:** confidence visualization for all issue classes
|
| 44 |
+
- **Nearest known pattern:** the closest OpenEnv task match
|
| 45 |
+
- **Improvement plan:** step 1 syntax/bug fixes, step 2 edge cases, step 3 scalability
|
| 46 |
+
|
| 47 |
+
## Why PyTorch Matters
|
| 48 |
+
|
| 49 |
+
This project uses **PyTorch for real inference**, not placeholder branching:
|
| 50 |
+
|
| 51 |
+
- `transformers` + `torch` load `huggingface/CodeBERTa-small-v1`
|
| 52 |
+
- embeddings compare code with OpenEnv issue prototypes
|
| 53 |
+
- combines ML + static analysis signals
|
| 54 |
+
|
| 55 |
+
## How It Works
|
| 56 |
+
|
| 57 |
+
`Input → static checks → PyTorch embeddings → prediction → suggestions → reward`
|
| 58 |
+
|
| 59 |
+
## Reward Formula
|
| 60 |
+
|
| 61 |
+
```text
|
| 62 |
+
reward = (0.5 x ML_quality_score) + (0.3 x lint_score) - (0.2 x complexity_penalty)
```
|
__init__.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Public package exports for python_code_review_env."""
|
| 2 |
+
|
| 3 |
+
from .client import PythonCodeReviewEnv, PythonEnv
|
| 4 |
+
from .models import PyTorchCodeAnalyzerModel
|
| 5 |
+
from .Models import (
|
| 6 |
+
PythonAction,
|
| 7 |
+
PythonCodeReviewAction,
|
| 8 |
+
PythonCodeReviewObservation,
|
| 9 |
+
PythonCodeReviewState,
|
| 10 |
+
PythonObservation,
|
| 11 |
+
PythonState,
|
| 12 |
+
)
|
| 13 |
+
from .schemas import AnalyzeCodeRequest, AnalyzeCodeResponse
|
| 14 |
+
from .services import AnalysisService
|
| 15 |
+
from .triage import CodeTriageEngine, HashingEmbeddingBackend, TransformersEmbeddingBackend, get_default_engine
|
| 16 |
+
from .triage_models import TriageResult
|
| 17 |
+
|
| 18 |
+
# Names re-exported as the public API of the package.
__all__ = [
    "PythonAction",
    "PythonObservation",
    "PythonState",
    "PythonCodeReviewAction",
    "PythonCodeReviewObservation",
    "PythonCodeReviewState",
    "PythonCodeReviewEnv",
    "PythonEnv",
    "AnalyzeCodeRequest",
    "AnalyzeCodeResponse",
    "AnalysisService",
    "CodeTriageEngine",
    "HashingEmbeddingBackend",
    "PyTorchCodeAnalyzerModel",
    "TransformersEmbeddingBackend",
    "TriageResult",
    "get_default_engine",
]
|
analyzers/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Domain-specific analyzers for multi-domain code understanding."""
|
| 2 |
+
|
| 3 |
+
from .dsa_analyzer import analyze_dsa_code
|
| 4 |
+
from .ds_analyzer import analyze_data_science_code
|
| 5 |
+
from .ml_analyzer import analyze_ml_code
|
| 6 |
+
from .web_analyzer import analyze_web_code
|
| 7 |
+
|
| 8 |
+
# Analyzer entry points re-exported by this package.
__all__ = [
    "analyze_dsa_code",
    "analyze_data_science_code",
    "analyze_ml_code",
    "analyze_web_code",
]
|
analyzers/ds_analyzer.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Analyzer for data-science oriented Python code."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any, Dict
|
| 6 |
+
|
| 7 |
+
from schemas.response import AnalysisIssue, DomainAnalysis
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def analyze_data_science_code(code: str, parsed: Dict[str, Any], complexity: Dict[str, Any]) -> DomainAnalysis:
    """Inspect pandas and numpy code for vectorization and leakage concerns.

    Args:
        code: Raw source text; every check is a plain substring match on it.
        parsed: Parser output; only the optional ``uses_pandas`` flag is read.
        complexity: Complexity estimate; must contain ``time_complexity``.

    Returns:
        A ``DomainAnalysis`` for the ``data_science`` domain. The score starts
        at 0.72, is reduced per finding, and is floored at 0.05.

    Raises:
        KeyError: If ``complexity`` has no ``time_complexity`` entry.
    """

    issues = []
    suggestions = []
    score = 0.72

    # Evaluated once and reused for both the issue and the highlight value.
    row_wise_iteration = "iterrows(" in code or "itertuples(" in code

    if row_wise_iteration:
        issues.append(
            AnalysisIssue(
                title="Row-wise dataframe iteration detected",
                severity="medium",
                description="Looping through dataframe rows is usually slower and less scalable than vectorized operations.",
            )
        )
        suggestions.append("Use vectorized pandas or numpy expressions instead of row-wise iteration.")
        score -= 0.18

    if "inplace=True" in code:
        suggestions.append("Avoid inplace mutation to keep data pipelines easier to reason about and test.")
        score -= 0.05

    # Heuristic: a fit_transform call with no train_test_split anywhere in the snippet.
    if "fit_transform(" in code and "train_test_split" not in code:
        issues.append(
            AnalysisIssue(
                title="Potential data leakage risk",
                severity="high",
                description="Feature transforms appear before an explicit train/test split.",
            )
        )
        suggestions.append("Split train and validation data before fitting stateful preprocessing steps.")
        score -= 0.2

    # Always return at least one suggestion, even for clean snippets.
    if not suggestions:
        suggestions.append("Add schema assumptions and null-handling checks for production data quality.")

    return DomainAnalysis(
        domain="data_science",
        domain_score=max(0.05, round(score, 4)),
        issues=issues,
        suggestions=suggestions,
        highlights={
            "vectorization_risk": float(row_wise_iteration),
            "time_complexity": complexity["time_complexity"],
            "uses_pandas": float(parsed.get("uses_pandas", False)),
        },
    )
|
analyzers/dsa_analyzer.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Analyzer for DSA and competitive-programming style Python code."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any, Dict
|
| 6 |
+
|
| 7 |
+
from schemas.response import AnalysisIssue, DomainAnalysis
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def analyze_dsa_code(code: str, parsed: Dict[str, Any], complexity: Dict[str, Any]) -> DomainAnalysis:
    """Inspect algorithmic code for brute-force patterns and efficiency risks.

    Args:
        code: Raw source text; used for the substring check on sorting calls.
        parsed: Parser output; the optional ``max_loop_depth`` and
            ``uses_recursion`` entries are read.
        complexity: Complexity estimate; must contain ``time_complexity`` and
            ``space_complexity``.

    Returns:
        A ``DomainAnalysis`` for the ``dsa`` domain. The score starts at 0.7,
        is reduced per finding, and is floored at 0.05.

    Raises:
        KeyError: If ``complexity`` lacks either required entry.
    """

    issues = []
    suggestions = []
    score = 0.7

    loop_depth = parsed.get("max_loop_depth", 0)

    if loop_depth >= 2:
        issues.append(
            AnalysisIssue(
                title="Nested loops suggest brute-force behavior",
                severity="medium",
                description="The implementation scans the input multiple times, which is often avoidable in DSA problems.",
            )
        )
        suggestions.append("Consider replacing nested scans with a hashmap, prefix table, or sorted search strategy.")
        score -= 0.15

    if parsed.get("uses_recursion"):
        suggestions.append("Verify recursion depth and add memoization or iterative conversion if the input size can grow.")
        score -= 0.05

    # Sorting only adds advice; it does not lower the score.
    if "sorted(" in code or ".sort(" in code:
        suggestions.append("Sorting is acceptable here, but validate whether a direct O(n) pass can remove the sort.")

    # Always return at least one suggestion, even for clean snippets.
    if not suggestions:
        suggestions.append("Document the intended time complexity and add edge-case checks for empty input and duplicates.")

    return DomainAnalysis(
        domain="dsa",
        domain_score=max(0.05, round(score, 4)),
        issues=issues,
        suggestions=suggestions,
        highlights={
            "time_complexity": complexity["time_complexity"],
            "space_complexity": complexity["space_complexity"],
            "max_loop_depth": float(loop_depth),
        },
    )
|
analyzers/ml_analyzer.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Analyzer for machine-learning and deep-learning code."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any, Dict
|
| 6 |
+
|
| 7 |
+
from schemas.response import AnalysisIssue, DomainAnalysis
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def analyze_ml_code(code: str, parsed: Dict[str, Any], complexity: Dict[str, Any]) -> DomainAnalysis:
    """Inspect training and inference logic for common ML / DL mistakes.

    Args:
        code: Raw source text; most checks are substring matches on it.
        parsed: Parser output; the optional ``calls_backward``,
            ``calls_optimizer_step`` and ``uses_torch`` entries are read.
        complexity: Complexity estimate; must contain ``time_complexity``.

    Returns:
        A ``DomainAnalysis`` for the ``ml_dl`` domain. The score starts at
        0.74, is reduced per finding, and is floored at 0.05.

    Raises:
        KeyError: If ``complexity`` has no ``time_complexity`` entry.
    """

    issues = []
    suggestions = []
    score = 0.74

    # Both inference checks apply only to snippets that mention torch and "predict".
    torch_inference = "torch" in code and "predict" in code.lower()
    has_eval_mode = "model.eval()" in code
    has_no_grad = "no_grad" in code

    if torch_inference and not has_eval_mode:
        issues.append(
            AnalysisIssue(
                title="Inference path may be missing eval mode",
                severity="high",
                description="Inference code should place the model in eval mode before prediction.",
            )
        )
        suggestions.append("Call model.eval() before inference to disable training-time behavior such as dropout.")
        score -= 0.18

    if torch_inference and not has_no_grad:
        suggestions.append("Wrap inference in torch.no_grad() to reduce memory usage and avoid unnecessary gradient tracking.")
        score -= 0.12

    if parsed.get("calls_backward") and not parsed.get("calls_optimizer_step"):
        issues.append(
            AnalysisIssue(
                title="Backward pass without optimizer step",
                severity="medium",
                description="Gradients are computed, but the optimizer step is not obvious in the snippet.",
            )
        )
        suggestions.append("Ensure optimizer.step() and optimizer.zero_grad() are placed correctly in the training loop.")
        score -= 0.12

    if "CrossEntropyLoss" in code and "softmax(" in code:
        suggestions.append("CrossEntropyLoss expects raw logits; remove the explicit softmax before the loss when possible.")
        score -= 0.05

    # Always return at least one suggestion, even for clean snippets.
    if not suggestions:
        suggestions.append("Add explicit train/eval mode transitions and log validation metrics during training.")

    return DomainAnalysis(
        domain="ml_dl",
        domain_score=max(0.05, round(score, 4)),
        issues=issues,
        suggestions=suggestions,
        highlights={
            "uses_torch": float(parsed.get("uses_torch", False)),
            "has_eval_mode": float(has_eval_mode),
            "has_no_grad": float(has_no_grad),
            "time_complexity": complexity["time_complexity"],
        },
    )
|
analyzers/web_analyzer.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Analyzer for FastAPI and backend web-service code."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any, Dict
|
| 6 |
+
|
| 7 |
+
from schemas.response import AnalysisIssue, DomainAnalysis
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def analyze_web_code(code: str, parsed: Dict[str, Any], complexity: Dict[str, Any]) -> DomainAnalysis:
    """Inspect API code for validation, routing, and backend safety concerns.

    Args:
        code: Raw source text; used for the ``async def`` and raw-request
            substring checks.
        parsed: Parser output; the optional ``route_decorators`` and
            ``uses_pydantic`` entries are read.
        complexity: Complexity estimate; must contain ``time_complexity``.

    Returns:
        A ``DomainAnalysis`` for the ``web`` domain. The score starts at 0.76,
        is reduced per finding, and is floored at 0.05.

    Raises:
        KeyError: If ``complexity`` has no ``time_complexity`` entry.
    """

    issues = []
    suggestions = []
    score = 0.76

    # De-duplicated, so ``route_count`` below counts distinct decorator names.
    route_decorators = set(parsed.get("route_decorators", []))

    if route_decorators and not parsed.get("uses_pydantic"):
        issues.append(
            AnalysisIssue(
                title="Request validation model is missing",
                severity="high",
                description="Route handlers appear present, but no obvious Pydantic validation layer was detected.",
            )
        )
        suggestions.append("Add Pydantic request and response models for strict validation and type-safe contracts.")
        score -= 0.2

    if {"get", "post", "put", "delete"} & route_decorators and "async def" not in code:
        suggestions.append("Prefer async FastAPI endpoints when the route performs I/O or awaits downstream services.")
        score -= 0.08

    if "request.json()" in code or "request.body()" in code:
        suggestions.append("Validate raw request payloads before use; avoid trusting unchecked JSON input.")
        score -= 0.08

    # Always return at least one suggestion, even for clean snippets.
    if not suggestions:
        suggestions.append("Add domain-specific response models and centralize dependency injection for cleaner API structure.")

    return DomainAnalysis(
        domain="web",
        domain_score=max(0.05, round(score, 4)),
        issues=issues,
        suggestions=suggestions,
        highlights={
            "route_count": float(len(route_decorators)),
            "uses_validation": float(parsed.get("uses_pydantic", False)),
            "time_complexity": complexity["time_complexity"],
        },
    )
|
api/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI backend package for the multi-domain analyzer."""
|
| 2 |
+
|
| 3 |
+
from .main import app
|
| 4 |
+
|
| 5 |
+
__all__ = ["app"]
|
api/main.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI backend for the multi-domain AI code analyzer."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from fastapi import FastAPI
|
| 6 |
+
|
| 7 |
+
from schemas.request import AnalyzeCodeRequest
|
| 8 |
+
from schemas.response import AnalyzeCodeResponse
|
| 9 |
+
from services.analysis_service import AnalysisService
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
app = FastAPI(title="Multi-Domain AI Code Analyzer", version="2.0.0")
|
| 13 |
+
analysis_service = AnalysisService()
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@app.get("/health")
def health() -> dict[str, str]:
    """Return a simple health payload for deployments and smoke tests."""

    return {"status": "ok"}
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@app.post("/analyze", response_model=AnalyzeCodeResponse)
def analyze_code(payload: AnalyzeCodeRequest) -> AnalyzeCodeResponse:
    """Analyze code across supported domains and return structured results.

    The validated request body is passed unchanged to the module-level
    ``analysis_service`` and its result is returned as the response.
    """

    return analysis_service.analyze(payload)
|
app/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Streamlit UI package for the multi-domain analyzer."""
|
app/examples.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Example snippets for each supported analysis domain."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
# Sample inputs keyed by display label. Each entry carries the domain hint,
# context window, optional traceback text, and the code snippet to analyze.
# The snippets are deliberately flawed (brute force, row-wise iteration,
# missing eval mode, unvalidated payload) but must stay syntactically valid,
# so their indentation is written out explicitly line by line.
EXAMPLES = {
    "DSA": {
        "domain_hint": "dsa",
        "context_window": "Competitive-programming helper for pair lookup on large arrays.",
        "traceback_text": "",
        "code": (
            "def two_sum(nums, target):\n"
            "    for i in range(len(nums)):\n"
            "        for j in range(i + 1, len(nums)):\n"
            "            if nums[i] + nums[j] == target:\n"
            "                return [i, j]\n"
            "    return []\n"
        ),
    },
    "Data Science": {
        "domain_hint": "data_science",
        "context_window": "Feature engineering step in a churn-prediction notebook.",
        "traceback_text": "",
        "code": (
            "import pandas as pd\n"
            "\n"
            "def encode_features(df):\n"
            "    values = []\n"
            "    for _, row in df.iterrows():\n"
            "        values.append(row['age'] * row['sessions'])\n"
            "    df['score'] = values\n"
            "    return df\n"
        ),
    },
    "ML / DL": {
        "domain_hint": "ml_dl",
        "context_window": "Inference utility for a PyTorch classifier used in a batch review job.",
        "traceback_text": "",
        "code": (
            "import torch\n"
            "\n"
            "class Predictor:\n"
            "    def __init__(self, model):\n"
            "        self.model = model\n"
            "\n"
            "    def predict(self, batch):\n"
            "        outputs = self.model(batch)\n"
            "        return outputs.argmax(dim=1)\n"
        ),
    },
    "Web / FastAPI": {
        "domain_hint": "web",
        "context_window": "Backend endpoint for creating review tasks from user-submitted payloads.",
        "traceback_text": "",
        "code": (
            "from fastapi import FastAPI, Request\n"
            "\n"
            "app = FastAPI()\n"
            "\n"
            "@app.post('/tasks')\n"
            "def create_task(request: Request):\n"
            "    payload = request.json()\n"
            "    return {'task': payload}\n"
        ),
    },
}
|
app/streamlit_app.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Streamlit frontend for the multi-domain analyzer platform."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import streamlit as st
|
| 6 |
+
|
| 7 |
+
from app.examples import EXAMPLES
|
| 8 |
+
from schemas.request import AnalyzeCodeRequest
|
| 9 |
+
from services.analysis_service import AnalysisService
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
analysis_service = AnalysisService()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _analyze(code: str, context_window: str, traceback_text: str, domain_hint: str):
    """Run the analysis service with validated request payloads.

    Builds an ``AnalyzeCodeRequest`` from the four UI field values and returns
    whatever the module-level ``analysis_service`` produces for it.
    """

    request = AnalyzeCodeRequest(
        code=code,
        context_window=context_window,
        traceback_text=traceback_text,
        domain_hint=domain_hint,  # type: ignore[arg-type]
    )
    return analysis_service.analyze(request)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def main() -> None:
    """Render the Streamlit UI.

    Layout: an example picker and a real-time toggle on top, input widgets in
    the left column, headline metrics in the right column, and four detail
    tabs below once a result is available.
    """

    st.set_page_config(page_title="Multi-Domain AI Code Analyzer", layout="wide")
    st.title("Multi-Domain AI Code Analyzer & Improvement System")
    st.caption("PyTorch-powered code review across DSA, Data Science, ML/DL, and Web backend code.")

    example_name = st.selectbox("Example input", list(EXAMPLES.keys()))
    example = EXAMPLES[example_name]
    auto_analyze = st.toggle("Real-time scoring", value=True)

    # Single source for the selectable domains and for the default index lookup.
    domain_options = ["auto", "dsa", "data_science", "ml_dl", "web"]

    left, right = st.columns([1.2, 1.0])
    with left:
        code = st.text_area("Code input", value=example["code"], height=420)
        context_window = st.text_area("Context window", value=example["context_window"], height=100)
        traceback_text = st.text_area("Optional traceback / runtime hint", value=example["traceback_text"], height=100)
        domain_hint = st.selectbox(
            "Domain hint",
            domain_options,
            index=domain_options.index(example["domain_hint"]),
        )
        analyze_clicked = st.button("Analyze Code", type="primary")

    # Analyze on an explicit click, or on every rerun while real-time scoring is on.
    result = None
    if code and (analyze_clicked or auto_analyze):
        result = _analyze(code, context_window, traceback_text, domain_hint)

    with right:
        if result is None:
            st.info("Paste code or load an example to start analysis.")
        else:
            metric_cols = st.columns(4)
            metric_cols[0].metric("Detected domain", result.detected_domain)
            metric_cols[1].metric("ML score", f"{result.score_breakdown.ml_score:.0%}")
            metric_cols[2].metric("Domain score", f"{result.score_breakdown.domain_score:.0%}")
            metric_cols[3].metric("Reward", f"{result.score_breakdown.reward:.0%}")
            st.bar_chart(result.domain_confidences)
            st.caption(result.summary)

    if result is not None:
        overview_tab, suggestions_tab, domain_tab, static_tab = st.tabs(
            ["Overview", "Suggestions", "Domain Detail", "Static Analysis"]
        )

        with overview_tab:
            st.subheader("Improvement Plan")
            for step in result.improvement_plan:
                st.write(f"- {step}")
            st.subheader("Complexity")
            st.write(
                {
                    "time_complexity": result.static_analysis.time_complexity,
                    "space_complexity": result.static_analysis.space_complexity,
                    "cyclomatic_complexity": result.static_analysis.cyclomatic_complexity,
                }
            )

        with suggestions_tab:
            st.subheader("Suggestions")
            for suggestion in result.domain_analysis.suggestions:
                st.write(f"- {suggestion}")
            if result.domain_analysis.issues:
                st.subheader("Issues")
                for issue in result.domain_analysis.issues:
                    st.write(f"- [{issue.severity}] {issue.title}: {issue.description}")

        with domain_tab:
            st.subheader("Domain Highlights")
            st.json(result.domain_analysis.highlights)
            st.write(f"Domain score: {result.domain_analysis.domain_score:.0%}")

        with static_tab:
            st.subheader("Static Analysis")
            st.json(result.static_analysis.model_dump())
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
# Render the UI only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
client.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Client helpers for python_code_review_env."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Dict
|
| 6 |
+
|
| 7 |
+
from openenv.core import EnvClient
|
| 8 |
+
from openenv.core.client_types import StepResult
|
| 9 |
+
|
| 10 |
+
from .Models import (
|
| 11 |
+
PythonCodeReviewAction,
|
| 12 |
+
PythonCodeReviewObservation,
|
| 13 |
+
PythonCodeReviewState,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class PythonCodeReviewEnv(
    EnvClient[PythonCodeReviewAction, PythonCodeReviewObservation, PythonCodeReviewState]
):
    """Typed client for the code review environment."""

    def _step_payload(self, action: PythonCodeReviewAction) -> Dict:
        """Serialize ``action`` to a plain dict, omitting fields that are ``None``."""
        return action.model_dump(exclude_none=True)

    def _parse_result(self, payload: Dict) -> StepResult[PythonCodeReviewObservation]:
        """Build a ``StepResult`` from a raw step response.

        A missing ``observation`` entry is validated as an empty dict. ``reward``
        is ``None`` when absent, and ``done`` falls back to the observation's own
        ``done`` flag when the payload does not carry one.
        """
        observation = PythonCodeReviewObservation.model_validate(payload.get("observation", {}))
        return StepResult(
            observation=observation,
            reward=payload.get("reward"),
            done=payload.get("done", observation.done),
        )

    def _parse_state(self, payload: Dict) -> PythonCodeReviewState:
        """Validate a raw state payload into a ``PythonCodeReviewState``."""
        return PythonCodeReviewState.model_validate(payload)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
PythonEnv = PythonCodeReviewEnv
|
compat.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Compatibility helpers expected by validator-oriented scripts."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def install_openenv_fastmcp_compat() -> None:
    """Install runtime shims when needed.

    The current environment does not require any monkey-patching, so this is a
    deliberate no-op kept for validator compatibility.
    """

    return None
|
graders/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Deterministic graders for python_code_review_env."""
|
| 2 |
+
|
| 3 |
+
from .dispatch import grade_task
|
| 4 |
+
|
| 5 |
+
__all__ = ["grade_task"]
|
graders/bug_fix.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Bug-fix task grader."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
try:
|
| 6 |
+
from ..Models import TaskGrade
|
| 7 |
+
from ..tasks.catalog import ReviewTask
|
| 8 |
+
except ImportError:
|
| 9 |
+
from Models import TaskGrade
|
| 10 |
+
from tasks.catalog import ReviewTask
|
| 11 |
+
|
| 12 |
+
from .shared import (
|
| 13 |
+
base_grade,
|
| 14 |
+
compile_code,
|
| 15 |
+
component_score,
|
| 16 |
+
execute_cases,
|
| 17 |
+
quality_metrics,
|
| 18 |
+
shaped_score,
|
| 19 |
+
similarity_score,
|
| 20 |
+
summarize_results,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def grade_bug_fix_task(
    task: ReviewTask,
    code: str,
    *,
    include_hidden: bool,
    timeout_s: float = 2.0,
) -> TaskGrade:
    """Score a bug-fix submission against the public or the full test suite.

    The submission is compiled and analysed for quality first. Code that does
    not compile is scored only by its textual similarity to the reference
    solution. Code that compiles is run against the public cases (plus the
    hidden cases when ``include_hidden`` is true); a timeout or a worker error
    yields a low, quality-weighted score, otherwise the score blends the test
    pass rate with the quality score.
    """

    compiles, syntax_problem = compile_code(code)
    metrics = quality_metrics(code, task.function_name)
    details = {
        "compile_error": syntax_problem,
        "quality_notes": metrics["quality_notes"],
        "style_score": metrics["style_score"],
        "visibility": "full" if include_hidden else "public",
    }

    if not compiles:
        closeness = similarity_score(code, task.reference_code)
        details.update(test_results=[], test_summary="Code does not compile.")
        planned_tests = len(task.public_cases)
        if include_hidden:
            planned_tests += len(task.hidden_cases)
        return base_grade(
            score=shaped_score(0.02 + 0.12 * closeness),
            syntax_score=component_score(0.01),
            tests_passed=0,
            tests_total=planned_tests,
            quality_score=component_score(0.01),
            runtime_score=component_score(0.01),
            timed_out=False,
            details=details,
        )

    cases = task.public_cases + (task.hidden_cases if include_hidden else [])
    outcome = execute_cases(code, task.function_name, cases, timeout_s=timeout_s)

    hit_timeout = bool(outcome.get("timed_out"))
    if hit_timeout or "error" in outcome:
        # Timeouts and worker errors share one shape; only the progress
        # weights and the timed_out flag differ between them.
        details.update(test_results=[], test_summary=outcome["error"])
        offset, weight = (0.12, 0.18) if hit_timeout else (0.1, 0.2)
        return base_grade(
            score=shaped_score(offset + weight * metrics["score"]),
            syntax_score=component_score(0.95),
            tests_passed=0,
            tests_total=len(cases),
            quality_score=metrics["score"],
            runtime_score=component_score(0.01),
            timed_out=hit_timeout,
            details=details,
        )

    report = outcome["data"]
    # max(..., 1) guards the division when no test case was executed.
    pass_rate = report["passed"] / max(report["total"], 1)
    details["test_results"] = report["results"]
    details["test_summary"] = summarize_results("Test results", report["results"])
    progress = min(1.0, 0.05 + 0.8 * pass_rate + 0.15 * metrics["score"])
    return base_grade(
        score=shaped_score(progress),
        syntax_score=component_score(0.95),
        tests_passed=report["passed"],
        tests_total=report["total"],
        quality_score=metrics["score"],
        runtime_score=component_score(0.01),
        timed_out=False,
        details=details,
    )
|
graders/dispatch.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Task grader dispatch."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
try:
|
| 6 |
+
from ..Models import TaskGrade
|
| 7 |
+
from ..tasks.catalog import ReviewTask
|
| 8 |
+
except ImportError:
|
| 9 |
+
from Models import TaskGrade
|
| 10 |
+
from tasks.catalog import ReviewTask
|
| 11 |
+
|
| 12 |
+
from .bug_fix import grade_bug_fix_task
|
| 13 |
+
from .optimization import grade_optimization_task
|
| 14 |
+
from .syntax import grade_syntax_task
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def grade_task(
    task: ReviewTask,
    code: str,
    *,
    include_hidden: bool,
    timeout_s: float = 3.0,
) -> TaskGrade:
    """Route a submission to the grader that matches ``task.task_kind``.

    Supported kinds are ``"syntax_fix"``, ``"bug_fix"`` and ``"optimization"``.
    The syntax grader is not passed ``include_hidden``; the other two are.

    Raises:
        ValueError: If the task kind is not one of the supported kinds.
    """

    kind = task.task_kind
    # Each entry defers the call so only the selected grader actually runs.
    graders = {
        "syntax_fix": lambda: grade_syntax_task(task, code, timeout_s=timeout_s),
        "bug_fix": lambda: grade_bug_fix_task(
            task, code, include_hidden=include_hidden, timeout_s=timeout_s
        ),
        "optimization": lambda: grade_optimization_task(
            task, code, include_hidden=include_hidden, timeout_s=timeout_s
        ),
    }
    run_grader = graders.get(kind)
    if run_grader is None:
        raise ValueError(f"Unsupported task kind: {kind}")
    return run_grader()
|
graders/optimization.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Optimization task grader."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
try:
|
| 6 |
+
from ..Models import TaskGrade
|
| 7 |
+
from ..tasks.catalog import ReviewTask
|
| 8 |
+
except ImportError:
|
| 9 |
+
from Models import TaskGrade
|
| 10 |
+
from tasks.catalog import ReviewTask
|
| 11 |
+
|
| 12 |
+
from .shared import (
|
| 13 |
+
base_grade,
|
| 14 |
+
benchmark_candidate,
|
| 15 |
+
compile_code,
|
| 16 |
+
component_score,
|
| 17 |
+
execute_cases,
|
| 18 |
+
quality_metrics,
|
| 19 |
+
shaped_score,
|
| 20 |
+
similarity_score,
|
| 21 |
+
summarize_results,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def grade_optimization_task(
    task: ReviewTask,
    code: str,
    *,
    include_hidden: bool,
    timeout_s: float = 3.0,
) -> TaskGrade:
    """Score an optimization/refactor submission on correctness, quality, and runtime.

    Code that does not compile is scored only by similarity to the reference
    solution. Code that compiles is run against the public cases (plus hidden
    cases when ``include_hidden`` is true). The benchmark runs only when hidden
    cases are included and every test passed; otherwise the runtime component
    stays at its minimum and the benchmark is reported as deferred.
    """

    deferred_note = "Benchmark deferred until hidden evaluation."

    compiles, syntax_problem = compile_code(code)
    metrics = quality_metrics(code, task.function_name)
    details = {
        "compile_error": syntax_problem,
        "quality_notes": metrics["quality_notes"],
        "style_score": metrics["style_score"],
        "visibility": "full" if include_hidden else "public",
    }

    if not compiles:
        closeness = similarity_score(code, task.reference_code)
        details.update(test_results=[], test_summary="Code does not compile.")
        planned_tests = len(task.public_cases)
        if include_hidden:
            planned_tests += len(task.hidden_cases)
        return base_grade(
            score=shaped_score(0.02 + 0.1 * closeness),
            syntax_score=component_score(0.01),
            tests_passed=0,
            tests_total=planned_tests,
            quality_score=component_score(0.01),
            runtime_score=component_score(0.01),
            timed_out=False,
            details=details,
        )

    cases = task.public_cases + (task.hidden_cases if include_hidden else [])
    outcome = execute_cases(code, task.function_name, cases, timeout_s=timeout_s)

    hit_timeout = bool(outcome.get("timed_out"))
    if hit_timeout or "error" in outcome:
        # Timeouts and worker errors differ only in the quality weight and
        # in the timed_out flag.
        details.update(test_results=[], test_summary=outcome["error"])
        weight = 0.18 if hit_timeout else 0.2
        return base_grade(
            score=shaped_score(0.1 + weight * metrics["score"]),
            syntax_score=component_score(0.95),
            tests_passed=0,
            tests_total=len(cases),
            quality_score=metrics["score"],
            runtime_score=component_score(0.01),
            timed_out=hit_timeout,
            details=details,
        )

    report = outcome["data"]
    # max(..., 1) guards the division when no test case was executed.
    pass_rate = report["passed"] / max(report["total"], 1)
    runtime_score = component_score(0.01)
    benchmark_summary = deferred_note
    timed_out = False

    if include_hidden and pass_rate == 1.0:
        benchmark = benchmark_candidate(task, code, timeout_s=timeout_s)
        timed_out = benchmark.get("timed_out", False)
        benchmark_summary = benchmark["details"]
        # A benchmark that timed out never earns more than the minimum.
        runtime_score = component_score(0.01) if timed_out else benchmark["runtime_score"]

    details["test_results"] = report["results"]
    details["test_summary"] = summarize_results("Test results", report["results"])
    details["benchmark"] = benchmark_summary

    # The runtime component counts only when a benchmark actually ran.
    if benchmark_summary == deferred_note:
        runtime_progress = 0.0
    else:
        runtime_progress = runtime_score

    if include_hidden:
        progress = min(
            1.0,
            0.05 + 0.6 * pass_rate + 0.2 * metrics["score"] + 0.15 * runtime_progress,
        )
    else:
        progress = min(1.0, 0.05 + 0.7 * pass_rate + 0.25 * metrics["score"])

    return base_grade(
        score=shaped_score(progress),
        syntax_score=component_score(0.95),
        tests_passed=report["passed"],
        tests_total=report["total"],
        quality_score=metrics["score"],
        runtime_score=runtime_score,
        timed_out=timed_out,
        details=details,
    )
|
graders/shared.py
ADDED
|
@@ -0,0 +1,431 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared deterministic grading helpers."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import ast
|
| 6 |
+
import difflib
|
| 7 |
+
import math
|
| 8 |
+
import multiprocessing as mp
|
| 9 |
+
import time
|
| 10 |
+
import traceback
|
| 11 |
+
from typing import Any, Callable, Dict, List
|
| 12 |
+
|
| 13 |
+
try:
|
| 14 |
+
from ..Models import TaskGrade
|
| 15 |
+
from ..tasks.catalog import CallCase, ReviewTask
|
| 16 |
+
except ImportError:
|
| 17 |
+
from Models import TaskGrade
|
| 18 |
+
from tasks.catalog import CallCase, ReviewTask
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
STRICT_SCORE_MIN = 0.01
|
| 22 |
+
STRICT_SCORE_MAX = 0.99
|
| 23 |
+
POOR_SCORE = 0.1
|
| 24 |
+
NEAR_PERFECT_SCORE = 0.95
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def finite_float(value: Any, fallback: float = STRICT_SCORE_MIN) -> float:
    """Convert a value into a finite float with a deterministic fallback.

    Args:
        value: Anything ``float()`` may accept.
        fallback: Returned when ``value`` cannot be converted or converts to
            NaN or an infinity.

    Returns:
        ``float(value)`` when that is a finite number, otherwise ``fallback``.
    """

    try:
        numeric = float(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: float() raises it for ints too large to represent
        # (e.g. 10**400); treat that like any other unusable input.
        return fallback
    return numeric if math.isfinite(numeric) else fallback
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def clamp(value: float, lower: float = 0.0, upper: float = 1.0) -> float:
    """Restrict a value to the closed interval ``[lower, upper]``.

    The value is first passed through ``finite_float`` with ``lower`` as the
    fallback, so unusable input clamps to the lower bound.
    """

    capped = min(upper, finite_float(value, fallback=lower))
    return max(lower, capped)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def strict_score(value: Any, lower: float = STRICT_SCORE_MIN, upper: float = STRICT_SCORE_MAX) -> float:
    """Clamp a score into ``[lower, upper]`` and round it to three decimals.

    With the default bounds the result lies strictly inside (0, 1), which the
    trailing assertion checks. Unusable input falls back to ``lower``.
    """

    numeric = finite_float(value, fallback=lower)
    bounded = max(lower, min(upper, numeric))
    score = round(bounded, 3)
    assert 0 < score < 1, f"Invalid score: {score}"
    return score
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def shaped_score(progress: Any, floor: float = POOR_SCORE, ceiling: float = NEAR_PERFECT_SCORE) -> float:
    """Map progress in [0, 1] linearly onto the band ``[floor, ceiling]``.

    Progress is clamped to [0, 1] first (unusable input counts as 0.0). The
    interpolated score is then limited to the strict score bounds and rounded
    to three decimals.
    """

    fraction = clamp(finite_float(progress, fallback=0.0))
    interpolated = floor + (ceiling - floor) * fraction
    bounded = max(STRICT_SCORE_MIN, min(interpolated, STRICT_SCORE_MAX))
    score = round(bounded, 3)
    assert 0 < score < 1, f"Invalid score: {score}"
    return score
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def score_from_checks(passed: int, total: int, floor: float = POOR_SCORE, ceiling: float = NEAR_PERFECT_SCORE) -> float:
    """Turn a passed/total check count into a shaped score within the given band."""

    pass_ratio = safe_ratio(passed, total)
    return shaped_score(pass_ratio, floor=floor, ceiling=ceiling)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def safe_ratio(numerator: Any, denominator: Any) -> float:
    """Return ``numerator / denominator`` clamped to [0, 1], without raising.

    The denominator is truncated to an integer; a non-positive or unusable
    denominator yields 0.0. An unusable numerator counts as 0.0.
    """

    divisor = int(finite_float(denominator, fallback=0.0))
    if divisor > 0:
        dividend = finite_float(numerator, fallback=0.0)
        return clamp(dividend / divisor)
    return 0.0
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def component_score(value: Any) -> float:
    """Normalize a component score (syntax, quality, runtime) via ``strict_score``."""

    normalized = strict_score(value)
    return normalized
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def compile_code(code: str) -> tuple[bool, str]:
    """Check whether ``code`` compiles as a module.

    Returns:
        ``(True, "")`` on success, otherwise ``(False, message)`` where the
        message describes the syntax error or other compile-time failure.
    """

    try:
        compile(code, "<candidate>", "exec")
    except SyntaxError as exc:
        location = f"line {exc.lineno}, column {exc.offset}"
        return False, f"SyntaxError: {exc.msg} ({location})"
    except Exception as exc:  # pragma: no cover
        return False, f"{type(exc).__name__}: {exc}"
    else:
        return True, ""
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def similarity_score(candidate: str, reference: str) -> float:
    """Return the ``difflib`` similarity ratio of the two texts, in [0, 1].

    Leading and trailing whitespace is stripped from both texts first.
    """

    candidate_text = candidate.strip()
    reference_text = reference.strip()
    matcher = difflib.SequenceMatcher(None, candidate_text, reference_text)
    return matcher.ratio()
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def _queue_worker(
|
| 107 |
+
worker: Callable[[Dict[str, Any]], Dict[str, Any]],
|
| 108 |
+
payload: Dict[str, Any],
|
| 109 |
+
queue: Any,
|
| 110 |
+
) -> None:
|
| 111 |
+
try:
|
| 112 |
+
queue.put({"ok": True, "data": worker(payload)})
|
| 113 |
+
except Exception as exc: # pragma: no cover
|
| 114 |
+
queue.put(
|
| 115 |
+
{
|
| 116 |
+
"ok": False,
|
| 117 |
+
"error": f"{type(exc).__name__}: {exc}",
|
| 118 |
+
"traceback": traceback.format_exc(limit=5),
|
| 119 |
+
}
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def run_with_timeout(
    worker: Callable[[Dict[str, Any]], Dict[str, Any]],
    payload: Dict[str, Any],
    timeout_s: float,
) -> Dict[str, Any]:
    """Execute a worker in a spawned subprocess and terminate it on timeout.

    Args:
        worker: Callable run in the child via ``_queue_worker``.
        payload: Argument passed to ``worker``.
        timeout_s: Wall-clock budget, in seconds, for the child to deliver its
            result and exit.

    Returns:
        ``{"timed_out": True, "error": ...}`` when the budget is exceeded,
        ``{"timed_out": False, "error": ...}`` when the worker failed or exited
        without a result, and ``{"timed_out": False, "data": ...}`` on success.
    """

    from queue import Empty  # raised by multiprocessing queues when get() times out

    ctx = mp.get_context("spawn")
    result_queue = ctx.Queue()
    process = ctx.Process(target=_queue_worker, args=(worker, payload, result_queue))
    process.start()

    deadline = time.monotonic() + timeout_s
    poll_interval_s = 0.05
    message = None
    child_alive = True

    # Drain the queue BEFORE joining: a child that has put a large item on a
    # multiprocessing queue cannot exit until that item is consumed, so joining
    # first would turn a finished run into a spurious timeout.
    while True:
        remaining = deadline - time.monotonic()
        # Sample liveness before reading: whatever a finished child wrote is
        # already in the pipe, so an empty read after this point is final.
        child_alive = process.is_alive()
        try:
            message = result_queue.get(timeout=max(0.0, min(remaining, poll_interval_s)))
            break
        except Empty:
            if not child_alive or remaining <= 0:
                break

    timeout_result = {"timed_out": True, "error": f"Execution exceeded {timeout_s:.1f}s timeout."}

    if message is None:
        if child_alive:
            process.terminate()
            process.join()
            return timeout_result
        process.join()
        return {"timed_out": False, "error": "Worker exited without returning a result."}

    # The result is in hand; the child still has to exit within the budget.
    process.join(max(0.0, deadline - time.monotonic()))
    if process.is_alive():
        process.terminate()
        process.join()
        return timeout_result

    if not message["ok"]:
        return {
            "timed_out": False,
            "error": f"{message['error']}\n{message['traceback']}",
        }
    return {"timed_out": False, "data": message["data"]}
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def _execute_cases_worker(payload: Dict[str, Any]) -> Dict[str, Any]:
|
| 154 |
+
namespace: Dict[str, Any] = {}
|
| 155 |
+
exec(payload["code"], namespace)
|
| 156 |
+
func = namespace[payload["function_name"]]
|
| 157 |
+
results: List[Dict[str, Any]] = []
|
| 158 |
+
|
| 159 |
+
for case in payload["cases"]:
|
| 160 |
+
try:
|
| 161 |
+
actual = func(*case["args"], **case["kwargs"])
|
| 162 |
+
passed = actual == case["expected"]
|
| 163 |
+
actual_repr = repr(actual)
|
| 164 |
+
except Exception as exc:
|
| 165 |
+
passed = False
|
| 166 |
+
actual_repr = f"{type(exc).__name__}: {exc}"
|
| 167 |
+
|
| 168 |
+
results.append(
|
| 169 |
+
{
|
| 170 |
+
"label": case["label"],
|
| 171 |
+
"passed": passed,
|
| 172 |
+
"expected": repr(case["expected"]),
|
| 173 |
+
"actual": actual_repr,
|
| 174 |
+
}
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
passed_total = sum(1 for item in results if item["passed"])
|
| 178 |
+
return {"passed": passed_total, "total": len(results), "results": results}
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def execute_cases(code: str, function_name: str, cases: List[CallCase], timeout_s: float) -> Dict[str, Any]:
    """Run the named function from ``code`` against ``cases`` in a subprocess.

    Each case is flattened to a plain dict (label, args, kwargs, expected)
    before being handed to ``run_with_timeout``, whose result is returned.
    """

    serialized_cases = []
    for case in cases:
        serialized_cases.append(
            {
                "label": case.label,
                "args": case.args,
                "kwargs": case.kwargs,
                "expected": case.expected,
            }
        )

    payload = {
        "code": code,
        "function_name": function_name,
        "cases": serialized_cases,
    }
    return run_with_timeout(_execute_cases_worker, payload, timeout_s=timeout_s)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class _LoopDepthVisitor(ast.NodeVisitor):
|
| 196 |
+
def __init__(self) -> None:
|
| 197 |
+
self.depth = 0
|
| 198 |
+
self.max_depth = 0
|
| 199 |
+
|
| 200 |
+
def _visit_loop(self, node: ast.AST) -> None:
|
| 201 |
+
self.depth += 1
|
| 202 |
+
self.max_depth = max(self.max_depth, self.depth)
|
| 203 |
+
self.generic_visit(node)
|
| 204 |
+
self.depth -= 1
|
| 205 |
+
|
| 206 |
+
def visit_For(self, node: ast.For) -> None: # noqa: N802
|
| 207 |
+
self._visit_loop(node)
|
| 208 |
+
|
| 209 |
+
def visit_While(self, node: ast.While) -> None: # noqa: N802
|
| 210 |
+
self._visit_loop(node)
|
| 211 |
+
|
| 212 |
+
def visit_comprehension(self, node: ast.comprehension) -> None: # noqa: N802
|
| 213 |
+
self._visit_loop(node)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def quality_metrics(code: str, function_name: str) -> Dict[str, Any]:
    """Compute deterministic quality and style metrics from the source and its AST.

    Returns a dict with:
        ``score``: overall quality, built from the checks below.
        ``style_score``: based only on line length and whitespace checks.
        ``quality_notes``: human-readable findings.
        ``max_loop_depth``: deepest loop nesting inside the target function
            (0 when the function is missing, 99 when the code does not compile).

    Only a top-level ``def`` named ``function_name`` counts as the target.
    """

    compiles, problem = compile_code(code)
    if not compiles:
        return {
            "score": component_score(STRICT_SCORE_MIN),
            "style_score": component_score(STRICT_SCORE_MIN),
            "quality_notes": [problem],
            "max_loop_depth": 99,
        }

    target = None
    for candidate in ast.parse(code).body:
        if isinstance(candidate, ast.FunctionDef) and candidate.name == function_name:
            target = candidate
            break
    has_target = target is not None

    notes: List[str] = []
    score = 0.0
    style_score = 0.0

    # Credit for defining the expected function at all.
    if has_target:
        score += 0.2
    else:
        notes.append(f"Expected function {function_name!r} is missing.")

    source_lines = code.splitlines()
    overlong = [number for number, text in enumerate(source_lines, start=1) if len(text) > 88]
    has_trailing_space = any(text != text.rstrip() for text in source_lines)
    has_tabs = any("\t" in text for text in source_lines)

    if overlong:
        notes.append(f"Lines longer than 88 characters: {overlong[:3]}")
    else:
        score += 0.15
        style_score += 0.5

    if has_trailing_space or has_tabs:
        notes.append("Remove tabs or trailing whitespace for cleaner style.")
    else:
        score += 0.15
        style_score += 0.5

    if not has_target:
        max_loop_depth = 0
    else:
        if ast.get_docstring(target):
            score += 0.1
        else:
            notes.append("Add a short docstring to explain the function contract.")

        loops = _LoopDepthVisitor()
        loops.visit(target)
        max_loop_depth = loops.max_depth
        if max_loop_depth <= 1:
            score += 0.15
        elif max_loop_depth == 2:
            score += 0.08
            notes.append("Loop nesting is still higher than necessary.")
        else:
            notes.append("Refactor nested loops to improve readability and runtime.")

        # Share of assigned names that are at least three characters long.
        assigned = [
            node.id
            for node in ast.walk(target)
            if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Store)
        ]
        if assigned:
            descriptive = [name for name in assigned if len(name) >= 3]
            score += 0.1 * (len(descriptive) / len(assigned))

        last_line = target.end_lineno or target.lineno
        body_length = last_line - target.lineno + 1
        if body_length <= 25:
            score += 0.1
        elif body_length <= 40:
            score += 0.05
            notes.append("The function can be shortened or decomposed further.")
        else:
            notes.append("The function is long enough to justify refactoring.")

    # Bonus when the source text contains any of these constructs.
    source_hints = ("Counter(", "defaultdict(", "set(", "dict(", "sorted(", "sum(", " any(", " all(", " for ")
    for hint in source_hints:
        if hint in code:
            score += 0.15
            break

    return {
        "score": component_score(clamp(score)),
        "style_score": component_score(clamp(style_score)),
        "quality_notes": notes,
        "max_loop_depth": max_loop_depth,
    }
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def build_benchmark_events(config: Dict[str, int]) -> List[Dict[str, Any]]:
    """Build a reproducible list of benchmark events, with no randomness.

    ``config["user_pool"]`` users each get ``config["events_per_user"]``
    events. An event is "inactive" when ``user_index + event_index`` is a
    multiple of three and "active" otherwise; every sixth event (indices
    0, 6, 12, ...) is emitted twice.
    """

    user_count = config["user_pool"]
    per_user = config["events_per_user"]
    events: List[Dict[str, Any]] = []

    for user_number in range(user_count):
        user_id = f"user-{user_number:03d}"
        for minute in range(per_user):
            status = "inactive" if (user_number + minute) % 3 == 0 else "active"
            copies = 2 if minute % 6 == 0 else 1
            # Each copy is its own dict so duplicates never share an object.
            events.extend(
                {"user_id": user_id, "status": status, "minute": minute}
                for _ in range(copies)
            )

    return events
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
def _benchmark_worker(payload: Dict[str, Any]) -> Dict[str, Any]:
|
| 329 |
+
candidate_ns: Dict[str, Any] = {}
|
| 330 |
+
baseline_ns: Dict[str, Any] = {}
|
| 331 |
+
exec(payload["candidate_code"], candidate_ns)
|
| 332 |
+
exec(payload["baseline_code"], baseline_ns)
|
| 333 |
+
|
| 334 |
+
candidate = candidate_ns[payload["function_name"]]
|
| 335 |
+
baseline = baseline_ns[payload["function_name"]]
|
| 336 |
+
benchmark_events = payload["events"]
|
| 337 |
+
iterations = payload["iterations"]
|
| 338 |
+
|
| 339 |
+
baseline_output = baseline(benchmark_events)
|
| 340 |
+
candidate_output = candidate(benchmark_events)
|
| 341 |
+
if candidate_output != baseline_output:
|
| 342 |
+
raise AssertionError("Candidate output diverges from baseline on benchmark data.")
|
| 343 |
+
|
| 344 |
+
def _timed(fn: Callable[[Any], Any]) -> float:
|
| 345 |
+
start = time.perf_counter()
|
| 346 |
+
for _ in range(iterations):
|
| 347 |
+
fn(benchmark_events)
|
| 348 |
+
return time.perf_counter() - start
|
| 349 |
+
|
| 350 |
+
baseline_seconds = _timed(baseline)
|
| 351 |
+
candidate_seconds = _timed(candidate)
|
| 352 |
+
return {"baseline_seconds": baseline_seconds, "candidate_seconds": candidate_seconds}
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def benchmark_candidate(task: ReviewTask, code: str, timeout_s: float) -> Dict[str, Any]:
    """Benchmark a candidate solution against the task's starter implementation.

    Returns a dict with ``runtime_score`` and ``details``. ``details`` is a
    message string when no benchmark is configured or the run failed, and a
    timing summary dict otherwise. ``timed_out`` is present in every result
    except the "no benchmark configured" one.
    """

    lowest = STRICT_SCORE_MIN
    settings = task.benchmark_config
    if not settings:
        return {"runtime_score": component_score(lowest), "details": "No benchmark configured."}

    payload = {
        "candidate_code": code,
        "baseline_code": task.starter_code,
        "function_name": task.function_name,
        "events": build_benchmark_events(settings),
        "iterations": settings.get("iterations", 5),
    }
    outcome = run_with_timeout(_benchmark_worker, payload, timeout_s=timeout_s)

    hit_timeout = bool(outcome.get("timed_out"))
    if hit_timeout or "error" in outcome:
        return {
            "runtime_score": component_score(lowest),
            "timed_out": hit_timeout,
            "details": outcome["error"],
        }

    timings = outcome["data"]
    baseline_seconds = float(timings["baseline_seconds"])
    candidate_seconds = float(timings["candidate_seconds"])
    # The 1e-9 floor keeps the division defined for a zero-duration candidate.
    improvement_ratio = baseline_seconds / max(candidate_seconds, 1e-9)
    # No speed-up (ratio 1.0) maps to 0; a 2.5x speed-up or better maps to 1.
    speedup_progress = clamp((improvement_ratio - 1.0) / 1.5)
    summary = {
        "baseline_seconds": round(baseline_seconds, 6),
        "candidate_seconds": round(candidate_seconds, 6),
        "improvement_ratio": round(improvement_ratio, 3),
    }
    return {
        "runtime_score": component_score(speedup_progress),
        "timed_out": False,
        "details": summary,
    }
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
def summarize_results(prefix: str, results: List[Dict[str, Any]]) -> str:
    """Render test results as a short multi-line report.

    The first line is ``prefix``; each following line shows PASS or FAIL, the
    case label, and the expected and actual values. With no results a single
    line stating that no tests were executed is returned.
    """

    if not results:
        return f"{prefix}: no tests were executed."

    def _render(item: Dict[str, Any]) -> str:
        verdict = "PASS" if item["passed"] else "FAIL"
        return f"- {verdict} {item['label']}: expected {item['expected']}, got {item['actual']}"

    report = [prefix]
    report.extend(_render(item) for item in results)
    return "\n".join(report)
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def base_grade(
    *,
    score: float,
    syntax_score: float,
    tests_passed: int,
    tests_total: int,
    quality_score: float,
    runtime_score: float,
    timed_out: bool,
    details: Dict[str, Any],
) -> TaskGrade:
    """Assemble a ``TaskGrade`` after normalizing every score field.

    The overall score goes through ``strict_score`` and the three component
    scores through ``component_score``; test counts, the timeout flag and the
    details are passed along unchanged.
    """

    fields: Dict[str, Any] = {"score": strict_score(score)}
    fields["syntax_score"] = component_score(syntax_score)
    fields["tests_passed"] = tests_passed
    fields["tests_total"] = tests_total
    fields["quality_score"] = component_score(quality_score)
    fields["runtime_score"] = component_score(runtime_score)
    fields["timed_out"] = timed_out
    fields["details"] = details
    return TaskGrade(**fields)
|
graders/syntax.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Syntax task grader."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
try:
|
| 6 |
+
from ..Models import TaskGrade
|
| 7 |
+
from ..tasks.catalog import ReviewTask
|
| 8 |
+
except ImportError:
|
| 9 |
+
from Models import TaskGrade
|
| 10 |
+
from tasks.catalog import ReviewTask
|
| 11 |
+
|
| 12 |
+
from .shared import (
|
| 13 |
+
base_grade,
|
| 14 |
+
compile_code,
|
| 15 |
+
component_score,
|
| 16 |
+
execute_cases,
|
| 17 |
+
quality_metrics,
|
| 18 |
+
shaped_score,
|
| 19 |
+
similarity_score,
|
| 20 |
+
summarize_results,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def grade_syntax_task(task: ReviewTask, code: str, timeout_s: float = 2.0) -> TaskGrade:
    """Score a submission for a syntax-repair task.

    The code is compiled first. If compilation fails, a small score based on
    similarity to the reference code is returned and nothing is executed.
    Otherwise the public and hidden cases are run; a timeout or an execution
    error yields a reduced score with zero passing tests, and a normal run is
    scored from the pass rate plus the quality score.
    """
    compiled, compile_error = compile_code(code)
    quality = quality_metrics(code, task.function_name)
    details = {
        "compile_error": compile_error,
        "quality_notes": quality["quality_notes"],
        "style_score": quality["style_score"],
    }

    def _finish(progress, syntax, passed, total, quality_value, timed_out=False):
        """Build the grade for one outcome; ``details`` is shared with the caller."""
        return base_grade(
            score=shaped_score(progress),
            syntax_score=component_score(syntax),
            tests_passed=passed,
            tests_total=total,
            quality_score=quality_value,
            runtime_score=component_score(0.01),
            timed_out=timed_out,
            details=details,
        )

    if not compiled:
        closeness = similarity_score(code, task.reference_code)
        details["test_results"] = []
        details["test_summary"] = "Code does not compile yet."
        return _finish(
            0.05 + 0.2 * closeness,
            0.01,
            0,
            len(task.public_cases) + len(task.hidden_cases),
            component_score(0.01),
        )

    cases = task.public_cases + task.hidden_cases
    result = execute_cases(code, task.function_name, cases, timeout_s=timeout_s)

    if result.get("timed_out"):
        details["test_results"] = []
        details["test_summary"] = result["error"]
        return _finish(
            0.2 + 0.25 * quality["score"],
            0.95,
            0,
            len(cases),
            quality["score"],
            timed_out=True,
        )

    if "error" in result:
        details["test_results"] = []
        details["test_summary"] = result["error"]
        return _finish(0.18 + 0.2 * quality["score"], 0.95, 0, len(cases), quality["score"])

    data = result["data"]
    details["test_results"] = data["results"]
    details["test_summary"] = summarize_results("Validation checks", data["results"])
    # max(..., 1) guards against division by zero when no cases were run.
    pass_rate = data["passed"] / max(data["total"], 1)
    progress = min(1.0, 0.15 + 0.75 * pass_rate + 0.1 * quality["score"])
    return _finish(progress, 0.95, data["passed"], data["total"], quality["score"])
|
inference.py
ADDED
|
@@ -0,0 +1,383 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Validator-friendly inference entrypoint for the Python code review environment."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import io
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
import time
|
| 11 |
+
from collections.abc import Iterable
|
| 12 |
+
from contextlib import redirect_stderr, redirect_stdout
|
| 13 |
+
from typing import Any
|
| 14 |
+
|
| 15 |
+
from compat import install_openenv_fastmcp_compat
|
| 16 |
+
|
| 17 |
+
try:
|
| 18 |
+
from openai import OpenAI
|
| 19 |
+
except Exception:
|
| 20 |
+
OpenAI = None # type: ignore[assignment]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
install_openenv_fastmcp_compat()
|
| 24 |
+
|
| 25 |
+
try:
|
| 26 |
+
from server.env import PythonCodeReviewEnvironment
|
| 27 |
+
except Exception:
|
| 28 |
+
PythonCodeReviewEnvironment = None # type: ignore[assignment]
|
| 29 |
+
|
| 30 |
+
try:
|
| 31 |
+
from Models import PythonCodeReviewAction
|
| 32 |
+
except Exception:
|
| 33 |
+
PythonCodeReviewAction = None # type: ignore[assignment]
|
| 34 |
+
|
| 35 |
+
try:
|
| 36 |
+
from tasks import get_task, task_ids
|
| 37 |
+
except Exception:
|
| 38 |
+
get_task = None # type: ignore[assignment]
|
| 39 |
+
task_ids = None # type: ignore[assignment]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# Action types this script will send; any other value is replaced with "analyze_code".
ALLOWED_ACTIONS = {
    "analyze_code",
    "edit_code",
    "run_tests",
    "submit_solution",
}
# Model name used when MODEL_NAME is unset or empty.
DEFAULT_MODEL_NAME = "mock-model"
# Per-request timeout handed to the client in run_llm.
API_TIMEOUT_SECONDS = 3.0
# Number of additional attempts after the first failed LLM call.
API_RETRIES = 1
# Base delay between attempts; multiplied by the attempt number in run_llm.
API_RETRY_DELAY_SECONDS = 0.2
# Lower clamp bound, also returned by clamp_score for unparsable or non-finite input.
MIN_SCORE = 0.01
# Score/reward reported when a step or a whole task could not be run.
POOR_SCORE = 0.1
# Upper clamp bound.
MAX_SCORE = 0.99
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def safe_env(name: str, default: str = "") -> str:
    """Return environment variable ``name`` as text, or ``default``.

    ``default`` is returned when the variable is unset or when the lookup
    raises for any reason.
    """
    try:
        raw = os.getenv(name)
        if raw is not None:
            return str(raw)
    except Exception:
        pass
    return default
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def clamp_score(value: Any) -> float:
    """Coerce ``value`` to a float inside ``[MIN_SCORE, MAX_SCORE]``.

    Values that cannot be converted, NaN, and positive or negative infinity
    all map to ``MIN_SCORE``.
    """
    try:
        parsed = float(value)
    except Exception:
        return MIN_SCORE

    is_nan = parsed != parsed  # NaN is the only float unequal to itself
    if is_nan or parsed == float("inf") or parsed == float("-inf"):
        return MIN_SCORE

    numeric = max(MIN_SCORE, min(MAX_SCORE, parsed))
    assert 0 < numeric < 1, f"Invalid score: {numeric}"
    return numeric
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def safe_float(value: Any, default: float = POOR_SCORE) -> float:
    """Return ``float(value)``, or ``default`` when the conversion raises.

    NaN and infinities are returned as-is; no range check is applied here.
    """
    try:
        converted = float(value)
    except Exception:
        converted = default
    return converted
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def safe_text(value: Any, default: str = "") -> str:
    """Render ``value`` as short single-line text.

    All runs of whitespace (including newlines) collapse to single spaces and
    the result is truncated to 240 characters.

    Args:
        value: Any object; converted with ``str()``.
        default: Returned when ``value`` is ``None``, when ``str(value)``
            raises, or when the text is empty after whitespace collapsing.

    Returns:
        The normalized text, or ``default``.
    """
    # Treat None as "no value" (matching safe_code) so the literal "None"
    # never leaks into prompts or identifiers.
    if value is None:
        return default
    try:
        text = str(value)
    except Exception:
        return default
    collapsed = " ".join(text.split())
    return collapsed[:240] if collapsed else default
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def safe_getattr(obj: Any, name: str, default: Any = None) -> Any:
    """Look up ``name`` on ``obj``, returning ``default`` if it is missing or the lookup raises."""
    try:
        found = getattr(obj, name, default)
    except Exception:
        found = default
    return found
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def safe_code(value: Any, default: str = "") -> str:
    """Return ``str(value)`` with whitespace left intact.

    ``default`` is returned when ``value`` is ``None`` or ``str()`` raises.
    """
    if value is not None:
        try:
            return str(value)
        except Exception:
            pass
    return default
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def safe_task_list() -> list[str]:
    """Return the task ids to run.

    Uses ``task_ids()`` when it is callable and yields at least one non-empty
    id; otherwise (or on any error) returns the three built-in ids.
    """
    try:
        if callable(task_ids):
            names = []
            for raw in task_ids():
                label = safe_text(raw, "")
                if label:
                    names.append(label)
            if names:
                return names
    except Exception:
        pass
    return [
        "syntax_fix_invoice_totals",
        "bug_fix_session_windows",
        "optimization_rank_active_users",
    ]
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def safe_reference_code(task_id: str, current_code: str) -> str:
    """Return the task's ``reference_code``, or ``current_code`` if unavailable.

    ``current_code`` is returned when ``get_task`` is not callable, when the
    lookup raises, or when the reference code is blank.
    """
    if not callable(get_task):
        return current_code
    try:
        candidate = safe_code(safe_getattr(get_task(task_id), "reference_code", ""), "")
    except Exception:
        return current_code
    return candidate if candidate.strip() else current_code
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def parse_json_response(raw_text: str) -> dict[str, Any]:
    """Extract an action payload from raw model output.

    The text between the first ``{`` and the last ``}`` is parsed as JSON. A
    dict result yields ``{"action_type", "code", "fallback": False}``, where
    unknown action types become ``"analyze_code"`` and ``code`` is kept only
    for ``"edit_code"``. Anything else yields the ``analyze_code`` fallback
    with ``"fallback": True``.
    """
    fallback = {"action_type": "analyze_code", "code": None, "fallback": True}
    try:
        text = raw_text or ""
        first = text.find("{")
        last = text.rfind("}")
        if first < 0 or last + 1 <= first:
            return fallback

        payload = json.loads(text[first:last + 1])
        if not isinstance(payload, dict):
            return fallback

        action_type = safe_text(payload.get("action_type", "analyze_code"), "analyze_code")
        raw_code = payload.get("code")
        if action_type not in ALLOWED_ACTIONS:
            action_type = "analyze_code"

        keep_code = action_type == "edit_code" and raw_code is not None
        code = safe_code(raw_code, "") if keep_code else None
        return {"action_type": action_type, "code": code, "fallback": False}
    except Exception:
        return fallback
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def build_prompt(observation: Any) -> str:
    """Render the LLM prompt for ``observation``.

    The prompt lists the allowed actions followed by the task description,
    clamped score, errors, test feedback, up to four visible tests and the
    current code. A short generic prompt is returned if rendering raises.
    """
    try:
        task_description = safe_text(safe_getattr(observation, "task_description", ""), "No task description.")
        errors = safe_text(safe_getattr(observation, "errors", ""), "none")
        tests = safe_text(safe_getattr(observation, "test_results", ""), "not available")
        score = clamp_score(safe_getattr(observation, "score", POOR_SCORE))
        current_code = safe_code(safe_getattr(observation, "current_code", ""), "")

        visible_tests = safe_getattr(observation, "visible_tests", [])
        # Strings are iterable but must not be split into characters.
        usable = isinstance(visible_tests, Iterable) and not isinstance(visible_tests, (str, bytes))
        if not usable:
            visible_tests = []
        bullet_lines = []
        for item in list(visible_tests)[:4]:
            bullet_lines.append(f"- {safe_text(item, 'unknown test')}")
        visible_block = "\n".join(bullet_lines) or "- none"

        sections = [
            "Return exactly one JSON object with keys action_type and optional code.\n",
            "Allowed action_type values: analyze_code, edit_code, run_tests, submit_solution.\n",
            "Prefer one safe next action only.\n",
            f"Task: {task_description}\n",
            f"Score: {score:.4f}\n",
            f"Errors: {errors}\n",
            f"Tests: {tests}\n",
            f"Visible tests:\n{visible_block}\n",
            f"Code:\n{current_code}\n",
        ]
        return "".join(sections)
    except Exception:
        return (
            "Return exactly one JSON object with keys action_type and optional code. "
            "Use analyze_code if unsure."
        )
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def create_client() -> Any | None:
    """Build an OpenAI-compatible client, or return ``None``.

    ``None`` is returned when the ``openai`` import failed, when
    ``API_BASE_URL`` is unset or empty, or when constructing the client
    raises. The API key is read from ``HF_TOKEN``, then ``OPENAI_API_KEY``,
    then falls back to ``"dummy"``.
    """
    base_url = "" if OpenAI is None else safe_env("API_BASE_URL", "")
    if not base_url:
        return None

    fallback_key = safe_env("OPENAI_API_KEY", "dummy")
    api_key = safe_env("HF_TOKEN", fallback_key)
    try:
        client = OpenAI(base_url=base_url, api_key=api_key)
    except Exception:
        client = None
    return client
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def run_llm(client: Any | None, model: str, prompt: str) -> dict[str, Any]:
    """Ask the model for the next action.

    Returns the parsed action payload from the first attempt that completes.
    When ``client`` is ``None`` or every attempt raises, the ``analyze_code``
    fallback payload (``"fallback": True``) is returned. Up to
    ``API_RETRIES`` extra attempts are made, sleeping between them.
    """
    fallback = {"action_type": "analyze_code", "code": None, "fallback": True}
    if client is None:
        return fallback

    for attempt in range(API_RETRIES + 1):
        try:
            # Silence anything the client library prints during the call.
            with redirect_stdout(io.StringIO()), redirect_stderr(io.StringIO()):
                timed_client = client.with_options(timeout=API_TIMEOUT_SECONDS)
                response = timed_client.chat.completions.create(
                    model=model,
                    messages=[{"role": "user", "content": prompt}],
                    temperature=0,
                    max_tokens=300,
                )
            content = safe_getattr(response.choices[0].message, "content", "")
            return parse_json_response(safe_code(content, ""))
        except Exception:
            is_last_attempt = attempt >= API_RETRIES
            if not is_last_attempt:
                time.sleep(API_RETRY_DELAY_SECONDS * (attempt + 1))

    return fallback
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def make_action(action_payload: dict[str, Any]) -> Any:
    """Turn an action payload into an environment action without raising.

    Args:
        action_payload: Mapping with ``action_type`` and optional ``code``.
            Unknown action types become ``"analyze_code"``; ``code`` is kept
            only for ``"edit_code"``.

    Returns:
        A ``PythonCodeReviewAction`` when that class is importable and can be
        constructed. Otherwise a plain ``{"action_type", "code"}`` dict is
        returned, which is already the behaviour when the class is missing.
    """
    action_type = safe_text(action_payload.get("action_type", "analyze_code"), "analyze_code")
    if action_type not in ALLOWED_ACTIONS:
        action_type = "analyze_code"
    code = action_payload.get("code") if action_type == "edit_code" else None

    if PythonCodeReviewAction is None:
        return {"action_type": action_type, "code": code}

    try:
        return PythonCodeReviewAction(action_type=action_type, code=code)
    except Exception:
        pass

    # The requested action was rejected; fall back to the harmless default.
    # This constructor call can fail too, so it is guarded to keep the
    # function's no-raise contract.
    try:
        return PythonCodeReviewAction(action_type="analyze_code", code=None)
    except Exception:
        return {"action_type": "analyze_code", "code": None}
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def safe_step(env: Any, action: Any) -> Any:
    """Call ``env.step(action)`` with stdout/stderr silenced.

    Returns the step result, or ``None`` if the call raises.
    """
    out_sink, err_sink = io.StringIO(), io.StringIO()
    try:
        with redirect_stdout(out_sink), redirect_stderr(err_sink):
            result = env.step(action)
    except Exception:
        return None
    return result
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def safe_reset(env: Any, task_id: str) -> Any:
    """Call ``env.reset(task_id=task_id)`` with stdout/stderr silenced.

    Returns the reset result, or ``None`` if the call raises.
    """
    out_sink, err_sink = io.StringIO(), io.StringIO()
    try:
        with redirect_stdout(out_sink), redirect_stderr(err_sink):
            result = env.reset(task_id=task_id)
    except Exception:
        return None
    return result
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def observation_reward(observation: Any) -> float:
    """Return the clamped step reward carried by ``observation``.

    ``observation.reward`` is preferred; when it is missing or ``None`` the
    value comes from ``observation.reward_details.value``. Unusable values
    fall back to ``POOR_SCORE`` before clamping.
    """
    raw = safe_getattr(observation, "reward", None)
    if raw is None:
        breakdown = safe_getattr(observation, "reward_details", None)
        raw = safe_getattr(breakdown, "value", POOR_SCORE)
    return clamp_score(safe_float(raw, POOR_SCORE))
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def fallback_first_action(task_id: str) -> dict[str, Any]:
    """Return the fixed opening action for ``task_id``.

    The invoice-totals syntax task opens with ``analyze_code``; every other
    task opens with ``run_tests``.
    """
    opening = "analyze_code" if task_id == "syntax_fix_invoice_totals" else "run_tests"
    return {"action_type": opening, "code": None}
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def select_first_action(task_id: str, llm_action: dict[str, Any]) -> dict[str, Any]:
    """Choose the opening action for a task.

    The model's suggestion is used unless its action type is unknown, is
    ``submit_solution``, or is ``edit_code`` with blank code; in those cases
    the deterministic fallback for ``task_id`` is returned.
    """
    action_type = safe_text(llm_action.get("action_type", ""), "")
    code = llm_action.get("code")

    usable = action_type in ALLOWED_ACTIONS and action_type != "submit_solution"
    if usable and action_type == "edit_code":
        usable = bool(safe_code(code, "").strip())

    if not usable:
        return fallback_first_action(task_id)
    return {"action_type": action_type, "code": code}
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def emit_start(task_id: str) -> None:
    """Print the ``[START]`` line for ``task_id`` and flush stdout."""
    line = f"[START] task={task_id}"
    print(line, flush=True)
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def emit_step(step_index: int, reward: float) -> None:
    """Print the ``[STEP]`` line (reward to four decimals) and flush stdout."""
    line = f"[STEP] step={step_index} reward={reward:.4f}"
    print(line, flush=True)
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def emit_end(task_id: str, score: float, steps: int) -> None:
    """Print the ``[END]`` line and flush stdout.

    The score is clamped with ``clamp_score`` and the step count is floored
    at zero before printing.
    """
    line = f"[END] task={task_id} score={clamp_score(score):.4f} steps={max(int(steps), 0)}"
    print(line, flush=True)
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def run_task(task_id: str, client: Any | None, model: str) -> None:
    """Run one task and print its START / STEP / END lines.

    The plan is fixed: an opening action (model suggestion or fallback), an
    ``edit_code`` with the task's reference code, then ``submit_solution``.
    If the environment cannot be created or reset, one ``POOR_SCORE`` step
    and a ``POOR_SCORE`` end line are printed instead.
    """
    emit_start(task_id)

    def _emit_setup_failure() -> None:
        """Report a task that could not be started at all."""
        emit_step(1, POOR_SCORE)
        emit_end(task_id, POOR_SCORE, 1)

    if PythonCodeReviewEnvironment is None:
        _emit_setup_failure()
        return

    try:
        with redirect_stdout(io.StringIO()), redirect_stderr(io.StringIO()):
            env = PythonCodeReviewEnvironment(verbose=False)
    except Exception:
        _emit_setup_failure()
        return

    observation = safe_reset(env, task_id)
    if observation is None:
        _emit_setup_failure()
        return

    llm_action = run_llm(client, model, build_prompt(observation))
    starting_code = safe_code(safe_getattr(observation, "current_code", ""), "")
    reference_code = safe_reference_code(task_id, starting_code)
    plan = (
        select_first_action(task_id, llm_action),
        {"action_type": "edit_code", "code": reference_code},
        {"action_type": "submit_solution", "code": None},
    )

    latest = observation
    steps_taken = 0
    for planned in plan:
        if steps_taken > 0 and bool(safe_getattr(latest, "done", False)):
            break

        if planned["action_type"] == "edit_code":
            existing = safe_code(safe_getattr(latest, "current_code", ""), "").strip()
            proposed = safe_code(planned.get("code"), "").strip()
            # Skip edits that are blank or would not change the code.
            if not proposed or existing == proposed:
                continue

        outcome = safe_step(env, make_action(planned))
        steps_taken += 1
        if outcome is None:
            # The step failed; report it and finish with the last good score.
            emit_step(steps_taken, POOR_SCORE)
            break

        latest = outcome
        emit_step(steps_taken, observation_reward(latest))

    emit_end(task_id, clamp_score(safe_getattr(latest, "score", POOR_SCORE)), steps_taken)
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def main() -> int:
    """Run every benchmark task and print structured START / STEP / END lines.

    The model name comes from ``MODEL_NAME`` (falling back to
    ``DEFAULT_MODEL_NAME``) and the client from ``create_client()``. A task
    that raises unexpectedly is closed with a ``POOR_SCORE`` step and end
    line so the remaining tasks still run.

    Returns:
        Always ``0``.
    """
    model_name = safe_env("MODEL_NAME", DEFAULT_MODEL_NAME) or DEFAULT_MODEL_NAME
    client = create_client()
    for task_id in safe_task_list():
        try:
            run_task(task_id, client, model_name)
        except Exception:
            # run_task prints [START] as its very first statement, so the
            # block is already open here. Printing [START] again would give
            # the task two START lines; only close the open block.
            emit_step(1, POOR_SCORE)
            emit_end(task_id, POOR_SCORE, 1)
    return 0
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
if __name__ == "__main__":
|
| 383 |
+
sys.exit(main())
|
launch.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Launch the FastAPI backend and Streamlit UI in one Docker container."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import subprocess
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def main() -> int:
    """Start the API backend in the background and run Streamlit in the foreground.

    Returns:
        The exit code of the Streamlit process.

    The uvicorn process is always stopped when Streamlit exits or fails to
    start. If it does not exit within 10 seconds of ``terminate()`` it is
    killed, so shutdown cannot raise ``TimeoutExpired`` (which would replace
    the Streamlit exit code) or leave the backend running.
    """
    api_process = subprocess.Popen(
        ["uvicorn", "api.main:app", "--host", "0.0.0.0", "--port", "8001"],
    )
    try:
        return subprocess.call(
            [
                "streamlit",
                "run",
                "app/streamlit_app.py",
                "--server.port",
                "8000",
                "--server.address",
                "0.0.0.0",
                "--server.headless",
                "true",
            ]
        )
    finally:
        api_process.terminate()
        try:
            api_process.wait(timeout=10)
        except subprocess.TimeoutExpired:
            # Graceful shutdown took too long; force it and reap the child.
            api_process.kill()
            api_process.wait()
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
if __name__ == "__main__":
|
| 35 |
+
sys.exit(main())
|
models.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Typed models for the python_code_review_env environment."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any, Dict, List, Literal, Optional
|
| 6 |
+
|
| 7 |
+
from pydantic import BaseModel, Field
|
| 8 |
+
|
| 9 |
+
from openenv.core.env_server.types import Action, Observation, State
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# Closed sets of string values used by the models in this module.
Difficulty = Literal["easy", "medium", "hard"]
TaskKind = Literal["syntax_fix", "bug_fix", "optimization"]
ActionType = Literal["analyze_code", "edit_code", "run_tests", "submit_solution"]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class HistoryEntry(BaseModel):
    """One environment transition recorded for the agent."""

    # Step index; must be >= 0.
    step: int = Field(..., ge=0)
    action_type: ActionType
    status: str = Field(..., description="Short outcome summary.")
    # Must lie strictly between 0 and 1.
    reward: float = Field(..., gt=0.0, lt=1.0, description="Reward returned for the step.")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class RewardDetails(BaseModel):
    """Transparent reward decomposition for debugging and training."""

    # Net reward; required and validated to lie strictly between 0 and 1.
    value: float = Field(..., gt=0.0, lt=1.0, description="Clamped net reward in (0.0, 1.0).")
    # Individual reward components; all optional and default to 0.0.
    syntax_reward: float = Field(default=0.0)
    test_reward: float = Field(default=0.0)
    correctness_bonus: float = Field(default=0.0)
    quality_bonus: float = Field(default=0.0)
    progress_delta: float = Field(default=0.0)
    # Penalty components; all optional and default to 0.0.
    invalid_action_penalty: float = Field(default=0.0)
    timeout_penalty: float = Field(default=0.0)
    regression_penalty: float = Field(default=0.0)
    stagnation_penalty: float = Field(default=0.0)
    reason: str = Field(..., description="Human-readable reward explanation.")
    # Scores are validated to lie strictly between 0 and 1.
    prev_score: float = Field(default=0.01, gt=0.0, lt=1.0)
    curr_score: float = Field(default=0.01, gt=0.0, lt=1.0)
    code_changed: bool = Field(default=False)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class PythonCodeReviewAction(Action):
    """Action schema exposed to the agent."""

    action_type: ActionType = Field(..., description="Environment action to take.")
    # Optional; defaults to None when the action carries no source code.
    code: Optional[str] = Field(
        default=None,
        description="Updated Python source for edit_code or submit_solution actions.",
    )
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class PythonCodeReviewObservation(Observation):
    """Observation returned by reset and step."""

    task_id: str = Field(..., description="Stable task identifier.")
    title: str = Field(..., description="Human-readable task title.")
    difficulty: Difficulty
    task_kind: TaskKind
    task_description: str = Field(..., description="Task instructions shown to the agent.")
    current_code: str = Field(..., description="Latest code under review.")
    errors: str = Field(default="", description="Syntax or execution errors.")
    test_results: str = Field(default="", description="Public test and benchmark feedback.")
    visible_tests: List[str] = Field(default_factory=list)
    history: List[HistoryEntry] = Field(default_factory=list)
    # Required; must be >= 0.
    attempts_remaining: int = Field(..., ge=0)
    last_action_status: str = Field(default="")
    # Required; must lie strictly between 0 and 1.
    score: float = Field(..., gt=0.0, lt=1.0)
    # When not supplied, defaults to a 0.1 reward with reason "Environment reset."
    reward_details: RewardDetails = Field(
        default_factory=lambda: RewardDetails(value=0.1, reason="Environment reset.")
    )
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class PythonCodeReviewState(State):
    """Internal environment state exposed through /state."""

    # Task fields default to None.
    task_id: Optional[str] = Field(default=None)
    difficulty: Optional[Difficulty] = Field(default=None)
    task_kind: Optional[TaskKind] = Field(default=None)
    attempts_remaining: int = Field(default=0, ge=0)
    current_code: str = Field(default="")
    errors: str = Field(default="")
    test_results: str = Field(default="")
    history: List[HistoryEntry] = Field(default_factory=list)
    # Must lie strictly between 0 and 1; defaults to 0.01.
    score: float = Field(default=0.01, gt=0.0, lt=1.0)
    done: bool = Field(default=False)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class TaskDescriptor(BaseModel):
    """Static task metadata."""

    # Required fields.
    task_id: str
    title: str
    difficulty: Difficulty
    task_kind: TaskKind
    task_description: str
    starter_code: str
    # Optional fields; default to empty lists or empty strings.
    visible_tests: List[str] = Field(default_factory=list)
    repo_summary: str = Field(default="")
    changed_files: List[str] = Field(default_factory=list)
    available_files: List[str] = Field(default_factory=list)
    goal: str = Field(default="")
    # Required; must be >= 1.
    max_steps: int = Field(..., ge=1)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class TaskSummary(BaseModel):
    """Compact task listing entry."""

    task_id: str
    difficulty: Difficulty
    title: str
    # Optional; defaults to an empty string.
    goal: str = Field(default="")
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class TaskGrade(BaseModel):
    """Deterministic grader output."""

    # Every *_score field is validated to lie strictly between 0 and 1.
    score: float = Field(..., gt=0.0, lt=1.0)
    syntax_score: float = Field(default=0.01, gt=0.0, lt=1.0)
    # Test counts must be >= 0.
    tests_passed: int = Field(default=0, ge=0)
    tests_total: int = Field(default=0, ge=0)
    quality_score: float = Field(default=0.01, gt=0.0, lt=1.0)
    runtime_score: float = Field(default=0.01, gt=0.0, lt=1.0)
    timed_out: bool = Field(default=False)
    # Free-form grader details; defaults to an empty dict.
    details: Dict[str, Any] = Field(default_factory=dict)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class HealthResponse(BaseModel):
    """Health payload for smoke tests."""

    # The only accepted value is "ok".
    status: Literal["ok"] = "ok"
    environment: str = "python_code_review_env"
    # Must be >= 0; defaults to 0.
    task_count: int = Field(default=0, ge=0)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
# Shorter alias names for the three environment models.
PythonAction = PythonCodeReviewAction
PythonObservation = PythonCodeReviewObservation
PythonState = PythonCodeReviewState
|
models/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PyTorch-backed model wrappers for the analyzer platform."""
|
| 2 |
+
|
| 3 |
+
from .pytorch_model import PyTorchCodeAnalyzerModel
|
| 4 |
+
|
| 5 |
+
__all__ = ["PyTorchCodeAnalyzerModel"]
|
models/pytorch_model.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PyTorch + transformers model wrapper for multi-domain code scoring."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import hashlib
|
| 6 |
+
from typing import Dict, List, Sequence
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
from transformers import AutoModel, AutoTokenizer
|
| 13 |
+
except Exception:
|
| 14 |
+
AutoModel = None # type: ignore[assignment]
|
| 15 |
+
AutoTokenizer = None # type: ignore[assignment]
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
DOMAIN_PROTOTYPES: Dict[str, List[str]] = {
|
| 19 |
+
"dsa": [
|
| 20 |
+
"Binary search, hashmap optimization, recursion, dynamic programming, arrays, trees, graphs, stack, queue, complexity.",
|
| 21 |
+
"Competitive programming algorithm with loops, memoization, prefix sums, and asymptotic analysis.",
|
| 22 |
+
],
|
| 23 |
+
"data_science": [
|
| 24 |
+
"Pandas dataframe transformation, numpy vectorization, feature leakage, train test split, iterrows misuse.",
|
| 25 |
+
"Data cleaning pipeline using pandas, numpy, aggregation, joins, and vectorized operations.",
|
| 26 |
+
],
|
| 27 |
+
"ml_dl": [
|
| 28 |
+
"PyTorch model, training loop, optimizer, backward pass, eval mode, no_grad, loss function, dataloader.",
|
| 29 |
+
"Machine learning inference and training code with torch, sklearn, tensors, gradients, and model checkpoints.",
|
| 30 |
+
],
|
| 31 |
+
"web": [
|
| 32 |
+
"FastAPI endpoint, request validation, Pydantic models, async routes, API security, backend service design.",
|
| 33 |
+
"REST API backend with routers, dependency injection, input validation, serialization, and error handling.",
|
| 34 |
+
],
|
| 35 |
+
"general": [
|
| 36 |
+
"General Python utility code with readable structure, typing, tests, and maintainable abstractions.",
|
| 37 |
+
],
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
QUALITY_ANCHORS: Dict[str, List[str]] = {
|
| 41 |
+
"high": [
|
| 42 |
+
"Readable typed Python code with validation, efficient algorithms, vectorized operations, safe inference, and clean API boundaries.",
|
| 43 |
+
"Production-ready code with small functions, docstrings, low complexity, and clear error handling.",
|
| 44 |
+
],
|
| 45 |
+
"low": [
|
| 46 |
+
"Brute-force nested loops, missing validation, unsafe input handling, missing eval mode, missing no_grad, and code smells.",
|
| 47 |
+
"Hard to maintain code with high complexity, repeated scans, mutable side effects, and unclear structure.",
|
| 48 |
+
],
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class _HashEmbeddingBackend:
|
| 53 |
+
"""Torch-native fallback when pretrained weights cannot be loaded."""
|
| 54 |
+
|
| 55 |
+
def __init__(self, dimensions: int = 128) -> None:
|
| 56 |
+
self.dimensions = dimensions
|
| 57 |
+
self.model_id = "hashed-token-fallback"
|
| 58 |
+
self.backend_name = "hashed-token-fallback"
|
| 59 |
+
self.notes = ["Using hashed embeddings because pretrained transformer weights are unavailable."]
|
| 60 |
+
|
| 61 |
+
def embed_texts(self, texts: Sequence[str]) -> torch.Tensor:
|
| 62 |
+
matrix = torch.zeros((len(texts), self.dimensions), dtype=torch.float32)
|
| 63 |
+
for row_index, text in enumerate(texts):
|
| 64 |
+
tokens = text.lower().split()[:512]
|
| 65 |
+
if not tokens:
|
| 66 |
+
matrix[row_index, 0] = 1.0
|
| 67 |
+
continue
|
| 68 |
+
for token in tokens:
|
| 69 |
+
digest = hashlib.md5(token.encode("utf-8")).hexdigest()
|
| 70 |
+
bucket = int(digest[:8], 16) % self.dimensions
|
| 71 |
+
sign = -1.0 if int(digest[8:10], 16) % 2 else 1.0
|
| 72 |
+
matrix[row_index, bucket] += sign
|
| 73 |
+
return F.normalize(matrix + 1e-6, dim=1)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class PyTorchCodeAnalyzerModel:
    """Score code using pretrained transformer embeddings plus prototype similarity.

    The encoder is loaded lazily on first use. When ``transformers`` is not
    importable or the weights cannot be loaded, the hashed-token fallback
    backend is used instead and the reason is recorded in ``notes``.
    """

    def __init__(self, model_id: str = "huggingface/CodeBERTa-small-v1") -> None:
        """Record the requested model id; nothing is loaded yet."""
        self.model_id = model_id
        self.backend_name = model_id
        self.notes: List[str] = []
        self._tokenizer = None
        self._model = None
        self._fallback = _HashEmbeddingBackend()
        # Prototype embeddings keyed by bucket name, computed once per bucket.
        self._prototype_cache: Dict[str, torch.Tensor] = {}

    def _ensure_loaded(self) -> None:
        """Attempt to load the tokenizer and encoder exactly once."""
        # Non-empty notes mean a load was already attempted (success or fallback).
        if self._model is not None or self.notes:
            return
        if AutoTokenizer is None or AutoModel is None:
            self.backend_name = self._fallback.backend_name
            self.notes = list(self._fallback.notes)
            return
        try:
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_id)
            self._model = AutoModel.from_pretrained(self.model_id)
            self._model.eval()
            self.notes.append(f"Loaded pretrained encoder `{self.model_id}`.")
        except Exception as exc:
            # Best-effort by design: any load failure degrades to the fallback
            # backend. Drop partially loaded pieces so the state is consistent.
            self._tokenizer = None
            self._model = None
            self.backend_name = self._fallback.backend_name
            self.notes = list(self._fallback.notes) + [f"Pretrained load failed: {type(exc).__name__}: {exc}"]

    def _embed_texts(self, texts: Sequence[str]) -> torch.Tensor:
        """Return one L2-normalised embedding row per input text."""
        self._ensure_loaded()
        if self._model is None or self._tokenizer is None:
            return self._fallback.embed_texts(texts)
        encoded = self._tokenizer(list(texts), padding=True, truncation=True, max_length=256, return_tensors="pt")
        with torch.no_grad():
            outputs = self._model(**encoded)
            hidden = outputs.last_hidden_state
            # Mean-pool over the positions selected by the attention mask.
            mask = encoded["attention_mask"].unsqueeze(-1)
            pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        return F.normalize(pooled, dim=1)

    def _prototype_matrix(self, bucket: str, texts: Sequence[str]) -> torch.Tensor:
        """Return the cached embedding matrix for ``bucket``, embedding ``texts`` on first use."""
        if bucket not in self._prototype_cache:
            self._prototype_cache[bucket] = self._embed_texts(texts)
        return self._prototype_cache[bucket]

    def predict(self, code: str, context_window: str, static_summary: Dict[str, object]) -> Dict[str, object]:
        """Predict domain probabilities and a model quality score.

        Args:
            code: Source code to score; only the first 4000 stripped characters are used.
            context_window: Extra context; only the first 1000 stripped characters are used.
            static_summary: Static-analysis hints, interpolated into the document as text.

        Returns:
            A dict with ``domain_scores`` (per-domain value in [0, 1]),
            ``ml_quality_score`` (value in [0, 1]), ``backend_name``,
            ``model_id`` and ``notes``.
        """
        document = (
            f"Code:\n{code.strip()[:4000]}\n\n"
            f"Context:\n{context_window.strip()[:1000]}\n\n"
            f"Static hints:\n{static_summary}\n"
        )
        candidate = self._embed_texts([document])

        domain_scores: Dict[str, float] = {}
        for domain, texts in DOMAIN_PROTOTYPES.items():
            matrix = self._prototype_matrix(f"domain:{domain}", texts)
            # Best match across the domain's prototype sentences.
            similarity = torch.matmul(candidate, matrix.T).max().item()
            # Map cosine similarity from [-1, 1] to [0, 1]. Floating-point error
            # can land marginally outside that range, so clamp before rounding.
            scaled = min(1.0, max(0.0, (similarity + 1.0) / 2.0))
            domain_scores[domain] = round(scaled, 4)

        high_matrix = self._prototype_matrix("quality:high", QUALITY_ANCHORS["high"])
        low_matrix = self._prototype_matrix("quality:low", QUALITY_ANCHORS["low"])
        high_similarity = torch.matmul(candidate, high_matrix.T).max().item()
        low_similarity = torch.matmul(candidate, low_matrix.T).max().item()
        # Squash the high-vs-low margin through a sigmoid; 4.0 sharpens the contrast.
        ml_quality_score = torch.sigmoid(torch.tensor((high_similarity - low_similarity) * 4.0)).item()
        ml_quality_score = min(1.0, max(0.0, float(ml_quality_score)))

        return {
            "domain_scores": domain_scores,
            "ml_quality_score": round(ml_quality_score, 4),
            "backend_name": self.backend_name,
            "model_id": self.model_id,
            "notes": list(self.notes),
        }
|
openenv.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: python_code_review_env
|
| 3 |
+
type: space
|
| 4 |
+
runtime: fastapi
|
| 5 |
+
app: server.app:app
|
| 6 |
+
port: 8000
|
pyproject.toml
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=68", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "openenv-python-code-review-env"
|
| 7 |
+
version = "1.0.0"
|
| 8 |
+
description = "TorchReview Copilot: AI-powered Python code triage with PyTorch and OpenEnv validation."
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
requires-python = ">=3.10"
|
| 11 |
+
dependencies = [
|
| 12 |
+
"fastapi>=0.111.0",
|
| 13 |
+
"gradio>=5.26.0",
|
| 14 |
+
"openai>=1.76.0",
|
| 15 |
+
"openenv-core[core]>=0.2.2",
|
| 16 |
+
"pytest>=8.0.0",
|
| 17 |
+
"streamlit>=1.44.0",
|
| 18 |
+
"torch>=2.2.0",
|
| 19 |
+
"transformers>=4.45.0",
|
| 20 |
+
"uvicorn>=0.30.0",
|
| 21 |
+
]
|
| 22 |
+
|
| 23 |
+
[project.optional-dependencies]
|
| 24 |
+
dev = [
|
| 25 |
+
"pytest-cov>=4.0.0",
|
| 26 |
+
]
|
| 27 |
+
|
| 28 |
+
[project.scripts]
|
| 29 |
+
server = "python_env.server.app:main"
|
| 30 |
+
|
| 31 |
+
[tool.setuptools]
|
| 32 |
+
include-package-data = true
|
| 33 |
+
packages = [
|
| 34 |
+
"python_env",
|
| 35 |
+
"python_env.server",
|
| 36 |
+
"python_env.tasks",
|
| 37 |
+
"python_env.graders",
|
| 38 |
+
"python_env.api",
|
| 39 |
+
"python_env.app",
|
| 40 |
+
"python_env.analyzers",
|
| 41 |
+
"python_env.models",
|
| 42 |
+
"python_env.schemas",
|
| 43 |
+
"python_env.services",
|
| 44 |
+
"python_env.utils",
|
| 45 |
+
]
|
| 46 |
+
package-dir = { "python_env" = ".", "python_env.server" = "server", "python_env.tasks" = "tasks", "python_env.graders" = "graders", "python_env.api" = "api", "python_env.app" = "app", "python_env.analyzers" = "analyzers", "python_env.models" = "models", "python_env.schemas" = "schemas", "python_env.services" = "services", "python_env.utils" = "utils" }
|
schemas/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Public schemas for the multi-domain analysis platform."""
|
| 2 |
+
|
| 3 |
+
from .request import AnalyzeCodeRequest
|
| 4 |
+
from .response import AnalyzeCodeResponse, AnalysisIssue, DomainAnalysis, ScoreBreakdown, StaticAnalysisSummary
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"AnalyzeCodeRequest",
|
| 8 |
+
"AnalyzeCodeResponse",
|
| 9 |
+
"AnalysisIssue",
|
| 10 |
+
"DomainAnalysis",
|
| 11 |
+
"ScoreBreakdown",
|
| 12 |
+
"StaticAnalysisSummary",
|
| 13 |
+
]
|
schemas/request.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Request schemas for code analysis endpoints and UI."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Literal
|
| 6 |
+
|
| 7 |
+
from pydantic import BaseModel, Field
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
DomainHint = Literal["auto", "dsa", "data_science", "ml_dl", "web"]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class AnalyzeCodeRequest(BaseModel):
    """Validated request body for multi-domain code analysis."""

    # Required and must contain at least one character.
    code: str = Field(
        ...,
        min_length=1,
        description="Source code to analyze.",
    )
    # Optional free text, capped at 2000 characters.
    context_window: str = Field(
        default="",
        max_length=2000,
        description="Optional repository or task context.",
    )
    # Optional free text, capped at 2000 characters.
    traceback_text: str = Field(
        default="",
        max_length=2000,
        description="Optional runtime or test failure output.",
    )
    # One of the DomainHint literals; defaults to "auto".
    domain_hint: DomainHint = Field(
        default="auto",
        description="Optional domain override when auto detection is not desired.",
    )
|
schemas/response.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Response schemas for the multi-domain analysis platform."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Dict, List, Literal
|
| 6 |
+
|
| 7 |
+
from pydantic import BaseModel, Field
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
DomainType = Literal["dsa", "data_science", "ml_dl", "web", "general"]
|
| 11 |
+
Severity = Literal["low", "medium", "high"]
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class AnalysisIssue(BaseModel):
    """A single issue or risk detected in the analyzed snippet."""

    # Short headline for the issue.
    title: str
    # One of the Severity literals ("low", "medium", "high").
    severity: Severity
    # Longer explanation of the issue.
    description: str
    # Optional line number hint; None when not provided.
    line_hint: int | None = None
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class StaticAnalysisSummary(BaseModel):
    """Static-analysis signals, independent of the detected domain."""

    # Syntax check outcome and the accompanying message (empty by default).
    syntax_valid: bool
    syntax_error: str = ""

    # Required structural metrics with lower bounds enforced by validation.
    cyclomatic_complexity: int = Field(..., ge=1)
    line_count: int = Field(..., ge=0)
    max_loop_depth: int = Field(..., ge=0)

    # Complexity estimates as free text; "Unknown" when not supplied.
    time_complexity: str = "Unknown"
    space_complexity: str = "Unknown"

    # Each instance gets its own fresh list through default_factory.
    detected_imports: List[str] = Field(default_factory=list)
    code_smells: List[str] = Field(default_factory=list)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class DomainAnalysis(BaseModel):
    """Analysis payload produced by a domain-specific analyzer."""

    # One of the DomainType literals.
    domain: DomainType
    # Required score, validated to lie within [0.0, 1.0].
    domain_score: float = Field(..., ge=0.0, le=1.0)

    # Collections default to fresh empty containers per instance.
    issues: List[AnalysisIssue] = Field(default_factory=list)
    suggestions: List[str] = Field(default_factory=list)
    highlights: Dict[str, float | str] = Field(default_factory=dict)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class ScoreBreakdown(BaseModel):
    """Inputs to the reward plus the final normalized score."""

    # Every field is required and validated to lie within [0.0, 1.0].
    ml_score: float = Field(..., ge=0.0, le=1.0)
    domain_score: float = Field(..., ge=0.0, le=1.0)
    lint_score: float = Field(..., ge=0.0, le=1.0)
    complexity_penalty: float = Field(..., ge=0.0, le=1.0)
    reward: float = Field(..., ge=0.0, le=1.0)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class AnalyzeCodeResponse(BaseModel):
    """Top-level structured result consumed by the API and the UI."""

    # Domain verdict and the per-domain confidence mapping.
    detected_domain: DomainType
    domain_confidences: Dict[str, float]

    # Nested result sections.
    score_breakdown: ScoreBreakdown
    static_analysis: StaticAnalysisSummary
    domain_analysis: DomainAnalysis

    # Improvement steps; a fresh empty list when omitted.
    improvement_plan: List[str] = Field(default_factory=list)

    # Backend and model identifiers plus a text summary.
    model_backend: str
    model_id: str
    summary: str

    # Context text; empty by default.
    context_window: str = ""
    # Required, non-negative duration in milliseconds.
    analysis_time_ms: float = Field(..., ge=0.0)
|
server/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Server exports for python_code_review_env."""
|
| 2 |
+
|
| 3 |
+
from .app import app
|
| 4 |
+
from .env import PythonCodeReviewEnvironment
|
| 5 |
+
|
| 6 |
+
__all__ = ["app", "PythonCodeReviewEnvironment"]
|
server/app.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI + Gradio entrypoint for TorchReview Copilot."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
try:
|
| 6 |
+
from openenv.core.env_server.http_server import create_app
|
| 7 |
+
except Exception as exc: # pragma: no cover
|
| 8 |
+
raise ImportError(
|
| 9 |
+
"openenv-core is required to run the API server. Install project dependencies first."
|
| 10 |
+
) from exc
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
import gradio as gr
|
| 14 |
+
except Exception:
|
| 15 |
+
gr = None # type: ignore[assignment]
|
| 16 |
+
|
| 17 |
+
try:
|
| 18 |
+
from ..Models import PythonCodeReviewAction, PythonCodeReviewObservation
|
| 19 |
+
from .env import PythonCodeReviewEnvironment
|
| 20 |
+
from .demo import build_demo
|
| 21 |
+
except ImportError:
|
| 22 |
+
from Models import PythonCodeReviewAction, PythonCodeReviewObservation
|
| 23 |
+
from server.env import PythonCodeReviewEnvironment
|
| 24 |
+
from server.demo import build_demo
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def build_application():
    """Build the OpenEnv API app and mount the Gradio demo on it when Gradio is importable."""
    openenv_app = create_app(
        PythonCodeReviewEnvironment,
        PythonCodeReviewAction,
        PythonCodeReviewObservation,
        env_name="python_code_review_env",
        max_concurrent_envs=4,
    )
    if gr is not None:
        # Serve the demo UI from the root path of the same application.
        return gr.mount_gradio_app(openenv_app, build_demo(), path="/")
    # Gradio failed to import at module load: expose the API alone.
    return openenv_app
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
app = build_application()
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def main(host: str = "0.0.0.0", port: int = 8000) -> None:
    """Serve the module-level ``app`` with uvicorn on ``host``:``port``."""
    # Imported lazily so importing this module does not require uvicorn.
    from uvicorn import run as serve

    serve(app, host=host, port=port)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
if __name__ == "__main__":
|
| 52 |
+
main()
|
server/demo.py
ADDED
|
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Gradio UI for TorchReview Copilot."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from html import escape
|
| 6 |
+
|
| 7 |
+
import gradio as gr
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
from ..triage import get_default_engine
|
| 11 |
+
except ImportError:
|
| 12 |
+
from triage import get_default_engine
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
CSS = """
|
| 16 |
+
:root {
|
| 17 |
+
--paper: #f6f1e8;
|
| 18 |
+
--ink: #162521;
|
| 19 |
+
--accent: #d95d39;
|
| 20 |
+
--panel: #fffdf8;
|
| 21 |
+
--border: #d6c4b8;
|
| 22 |
+
--muted: #5f6f67;
|
| 23 |
+
--good: #2d7d62;
|
| 24 |
+
--warn: #b76516;
|
| 25 |
+
--high: #b23a48;
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
body, .gradio-container {
|
| 29 |
+
background:
|
| 30 |
+
radial-gradient(circle at top left, rgba(247, 197, 159, 0.35), transparent 35%),
|
| 31 |
+
linear-gradient(135deg, #f9f6ef 0%, #efe5d3 100%);
|
| 32 |
+
color: var(--ink);
|
| 33 |
+
font-family: Georgia, "Times New Roman", serif;
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
.gradio-container {
|
| 37 |
+
max-width: 1260px !important;
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
.hero-card,
|
| 41 |
+
.metric-card,
|
| 42 |
+
.subtle-card {
|
| 43 |
+
background: rgba(255, 253, 248, 0.95);
|
| 44 |
+
border: 1px solid var(--border);
|
| 45 |
+
border-radius: 20px;
|
| 46 |
+
box-shadow: 0 16px 40px rgba(22, 37, 33, 0.08);
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
.hero-card {
|
| 50 |
+
padding: 28px 30px;
|
| 51 |
+
margin-bottom: 12px;
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
.metric-card,
|
| 55 |
+
.subtle-card {
|
| 56 |
+
padding: 20px 22px;
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
.eyebrow {
|
| 60 |
+
text-transform: uppercase;
|
| 61 |
+
letter-spacing: 0.12em;
|
| 62 |
+
font-size: 12px;
|
| 63 |
+
color: var(--accent);
|
| 64 |
+
margin-bottom: 10px;
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
.hero-title {
|
| 68 |
+
font-size: 44px;
|
| 69 |
+
line-height: 1.05;
|
| 70 |
+
margin: 0 0 10px;
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
.hero-copy {
|
| 74 |
+
margin: 0;
|
| 75 |
+
font-size: 18px;
|
| 76 |
+
line-height: 1.55;
|
| 77 |
+
color: var(--muted);
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
.summary-title {
|
| 81 |
+
display: flex;
|
| 82 |
+
justify-content: space-between;
|
| 83 |
+
gap: 12px;
|
| 84 |
+
align-items: center;
|
| 85 |
+
margin-bottom: 14px;
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
.pill {
|
| 89 |
+
display: inline-block;
|
| 90 |
+
padding: 6px 12px;
|
| 91 |
+
border-radius: 999px;
|
| 92 |
+
font-size: 12px;
|
| 93 |
+
text-transform: uppercase;
|
| 94 |
+
letter-spacing: 0.08em;
|
| 95 |
+
background: #efe5d3;
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
.pill.low { color: var(--good); }
|
| 99 |
+
.pill.medium { color: var(--warn); }
|
| 100 |
+
.pill.high { color: var(--high); }
|
| 101 |
+
|
| 102 |
+
.summary-grid {
|
| 103 |
+
display: grid;
|
| 104 |
+
grid-template-columns: repeat(2, minmax(0, 1fr));
|
| 105 |
+
gap: 12px;
|
| 106 |
+
margin-top: 16px;
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
.summary-stat {
|
| 110 |
+
background: #fff7ef;
|
| 111 |
+
border-radius: 14px;
|
| 112 |
+
padding: 12px 14px;
|
| 113 |
+
border: 1px solid rgba(214, 196, 184, 0.8);
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
.summary-stat strong {
|
| 117 |
+
display: block;
|
| 118 |
+
font-size: 12px;
|
| 119 |
+
text-transform: uppercase;
|
| 120 |
+
letter-spacing: 0.08em;
|
| 121 |
+
color: var(--muted);
|
| 122 |
+
margin-bottom: 6px;
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
.radar-wrap {
|
| 126 |
+
display: grid;
|
| 127 |
+
gap: 12px;
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
.bar {
|
| 131 |
+
display: grid;
|
| 132 |
+
gap: 6px;
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
.bar-head {
|
| 136 |
+
display: flex;
|
| 137 |
+
justify-content: space-between;
|
| 138 |
+
font-size: 13px;
|
| 139 |
+
color: var(--muted);
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
.bar-track {
|
| 143 |
+
width: 100%;
|
| 144 |
+
height: 12px;
|
| 145 |
+
background: #f2e5d6;
|
| 146 |
+
border-radius: 999px;
|
| 147 |
+
overflow: hidden;
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
.bar-fill {
|
| 151 |
+
height: 100%;
|
| 152 |
+
border-radius: 999px;
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
.matched-box {
|
| 156 |
+
background: #fff7ef;
|
| 157 |
+
border: 1px solid rgba(214, 196, 184, 0.8);
|
| 158 |
+
border-radius: 16px;
|
| 159 |
+
padding: 14px;
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
.how-grid {
|
| 163 |
+
display: grid;
|
| 164 |
+
grid-template-columns: repeat(4, minmax(0, 1fr));
|
| 165 |
+
gap: 12px;
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
.how-step {
|
| 169 |
+
background: rgba(255, 253, 248, 0.9);
|
| 170 |
+
border: 1px solid var(--border);
|
| 171 |
+
border-radius: 18px;
|
| 172 |
+
padding: 16px;
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
@media (max-width: 900px) {
|
| 176 |
+
.hero-title {
|
| 177 |
+
font-size: 34px;
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
.summary-grid,
|
| 181 |
+
.how-grid {
|
| 182 |
+
grid-template-columns: 1fr;
|
| 183 |
+
}
|
| 184 |
+
}
|
| 185 |
+
"""
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def _default_outputs() -> tuple[str, str, str, str, str]:
|
| 189 |
+
return (
|
| 190 |
+
"<div class='metric-card'><div class='eyebrow'>Awaiting Analysis</div><p class='hero-copy'>Paste Python code, add an optional traceback, or load one of the built-in examples.</p></div>",
|
| 191 |
+
"<div class='metric-card'><div class='eyebrow'>Live Triage Radar</div><p class='hero-copy'>Confidence bars will appear after the first analysis run.</p></div>",
|
| 192 |
+
"### Improvement Plan\nAnalyze a sample to generate syntax, edge-case, and scalability recommendations.",
|
| 193 |
+
"### Known Pattern Match\nThe nearest OpenEnv task will be highlighted here after inference runs.",
|
| 194 |
+
"### Model Notes\nBackend and extracted signal details will appear here.",
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def _summary_html(result) -> str:
|
| 199 |
+
issue = escape(result.issue_label.title())
|
| 200 |
+
summary = escape(result.summary)
|
| 201 |
+
next_action = escape(result.suggested_next_action)
|
| 202 |
+
return f"""
|
| 203 |
+
<div class="metric-card">
|
| 204 |
+
<div class="summary-title">
|
| 205 |
+
<div>
|
| 206 |
+
<div class="eyebrow">TorchReview Verdict</div>
|
| 207 |
+
<h3 style="margin:0;font-size:30px;">{issue} Issue</h3>
|
| 208 |
+
</div>
|
| 209 |
+
<span class="pill {escape(result.repair_risk)}">{escape(result.repair_risk)} repair risk</span>
|
| 210 |
+
</div>
|
| 211 |
+
<p class="hero-copy">{summary}</p>
|
| 212 |
+
<div class="summary-grid">
|
| 213 |
+
<div class="summary-stat">
|
| 214 |
+
<strong>Reward Score</strong>
|
| 215 |
+
{result.reward_score:.0%}
|
| 216 |
+
</div>
|
| 217 |
+
<div class="summary-stat">
|
| 218 |
+
<strong>ML Quality</strong>
|
| 219 |
+
{result.ml_quality_score:.0%}
|
| 220 |
+
</div>
|
| 221 |
+
<div class="summary-stat">
|
| 222 |
+
<strong>Matched Pattern</strong>
|
| 223 |
+
{escape(result.matched_pattern.title)}
|
| 224 |
+
</div>
|
| 225 |
+
<div class="summary-stat">
|
| 226 |
+
<strong>Inference Backend</strong>
|
| 227 |
+
{escape(result.model_backend)}
|
| 228 |
+
</div>
|
| 229 |
+
<div class="summary-stat">
|
| 230 |
+
<strong>Lint Score</strong>
|
| 231 |
+
{result.lint_score:.0%}
|
| 232 |
+
</div>
|
| 233 |
+
<div class="summary-stat">
|
| 234 |
+
<strong>Complexity Penalty</strong>
|
| 235 |
+
{result.complexity_penalty:.0%}
|
| 236 |
+
</div>
|
| 237 |
+
<div class="summary-stat">
|
| 238 |
+
<strong>Next Action</strong>
|
| 239 |
+
{next_action}
|
| 240 |
+
</div>
|
| 241 |
+
</div>
|
| 242 |
+
</div>
|
| 243 |
+
"""
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def _radar_html(result) -> str:
|
| 247 |
+
colors = {
|
| 248 |
+
"syntax": "#d95d39",
|
| 249 |
+
"logic": "#4f772d",
|
| 250 |
+
"performance": "#355070",
|
| 251 |
+
}
|
| 252 |
+
bars = []
|
| 253 |
+
for label, score in result.confidence_scores.items():
|
| 254 |
+
bars.append(
|
| 255 |
+
f"""
|
| 256 |
+
<div class="bar">
|
| 257 |
+
<div class="bar-head"><span>{escape(label.title())}</span><span>{score:.0%}</span></div>
|
| 258 |
+
<div class="bar-track">
|
| 259 |
+
<div class="bar-fill" style="width:{score * 100:.1f}%; background:{colors.get(label, '#d95d39')};"></div>
|
| 260 |
+
</div>
|
| 261 |
+
</div>
|
| 262 |
+
"""
|
| 263 |
+
)
|
| 264 |
+
return f"""
|
| 265 |
+
<div class="metric-card radar-wrap">
|
| 266 |
+
<div class="eyebrow">Live Triage Radar</div>
|
| 267 |
+
{''.join(bars)}
|
| 268 |
+
<div class="matched-box">
|
| 269 |
+
<strong>Nearest Known Pattern:</strong> {escape(result.matched_pattern.title)}<br>
|
| 270 |
+
<span style="color:#5f6f67;">{escape(result.matched_pattern.summary)}</span>
|
| 271 |
+
</div>
|
| 272 |
+
</div>
|
| 273 |
+
"""
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def _plan_markdown(result) -> str:
|
| 277 |
+
plan_lines = "\n".join(f"{index + 1}. {step}" for index, step in enumerate(result.repair_plan))
|
| 278 |
+
return (
|
| 279 |
+
"### Improvement Plan\n"
|
| 280 |
+
f"**Primary issue:** `{result.issue_label}`\n\n"
|
| 281 |
+
f"{plan_lines}\n\n"
|
| 282 |
+
f"**Suggested next action:** {result.suggested_next_action}"
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def _match_markdown(result) -> str:
|
| 287 |
+
return (
|
| 288 |
+
"### Known Pattern Match\n"
|
| 289 |
+
f"**Task:** `{result.matched_pattern.task_id}` \n"
|
| 290 |
+
f"**Title:** {result.matched_pattern.title} \n"
|
| 291 |
+
f"**Why it matched:** {result.matched_pattern.rationale} \n"
|
| 292 |
+
f"**Similarity:** {result.matched_pattern.similarity:.0%}"
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def _model_markdown(result) -> str:
|
| 297 |
+
signal_lines = "\n".join(
|
| 298 |
+
f"- `{signal.name}` -> {signal.value} ({signal.impact}, weight {signal.weight:.2f}): {signal.evidence}"
|
| 299 |
+
for signal in result.extracted_signals
|
| 300 |
+
) or "- No strong static signals were extracted."
|
| 301 |
+
notes = "\n".join(f"- {item}" for item in result.inference_notes) or "- No additional backend notes."
|
| 302 |
+
return (
|
| 303 |
+
"### Model Notes\n"
|
| 304 |
+
f"- **Model backend:** `{result.model_backend}`\n"
|
| 305 |
+
f"- **Model id:** `{result.model_id}`\n"
|
| 306 |
+
f"- **Analysis time:** `{result.analysis_time_ms:.2f} ms`\n\n"
|
| 307 |
+
"### Reward Formula\n"
|
| 308 |
+
f"- `reward = (0.5 x {result.ml_quality_score:.2f}) + (0.3 x {result.lint_score:.2f}) - (0.2 x {result.complexity_penalty:.2f})`\n"
|
| 309 |
+
f"- **Final reward:** `{result.reward_score:.2f}`\n\n"
|
| 310 |
+
"### Extracted Signals\n"
|
| 311 |
+
f"{signal_lines}\n\n"
|
| 312 |
+
"### Backend Notes\n"
|
| 313 |
+
f"{notes}"
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def analyze_inputs(code: str, traceback_text: str, context_window: str) -> tuple[str, str, str, str, str]:
    """Triage the inputs with the default engine and render the five UI panels.

    Falsy inputs (for example ``None`` from an empty widget) are passed to the
    engine as empty strings. The result order is: summary HTML, radar HTML,
    plan markdown, match markdown, model markdown.
    """
    engine = get_default_engine()
    result = engine.triage(code or "", traceback_text or "", context_window or "")
    summary_panel = _summary_html(result)
    radar_panel = _radar_html(result)
    plan_panel = _plan_markdown(result)
    match_panel = _match_markdown(result)
    model_panel = _model_markdown(result)
    return (summary_panel, radar_panel, plan_panel, match_panel, model_panel)
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def load_example(example_key: str) -> tuple[str, str, str, str, str, str, str, str, str]:
    """Populate the UI from a built-in example and immediately analyze it.

    Args:
        example_key: Key into the default engine's ``example_map()``.

    Returns:
        The example's code, traceback and context, then a Markdown header
        describing the scenario, then the five panels from ``analyze_inputs``.

    Raises:
        KeyError: if ``example_key`` is not a known example.
    """
    example = get_default_engine().example_map()[example_key]
    outputs = analyze_inputs(example.code, example.traceback_text, example.context_window)
    # Two trailing spaces before the newline are the Markdown hard line break;
    # with a single space the title, summary and label fold into one paragraph.
    header = (
        "### Example Scenario\n"
        f"**{example.title}**  \n"
        f"{example.summary}  \n"
        f"Label target: `{example.label}`"
    )
    return (example.code, example.traceback_text, example.context_window, header, *outputs)
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
def build_demo() -> gr.Blocks:
    """Assemble the TorchReview Copilot Gradio interface.

    The first column contains the example picker, the scenario header, the
    three input widgets and the action buttons; the second column contains the
    five result panels.  Changing the example, pressing either button, or
    loading the page triggers the matching callback wired at the bottom.
    """
    example_lookup = get_default_engine().example_map()
    initial_example = next(iter(example_lookup.values()))
    soft_theme = gr.themes.Soft(primary_hue="orange", secondary_hue="amber")

    hero_markup = """
            <div class="hero-card">
                <div class="eyebrow">Meta PyTorch OpenEnv Hackathon Demo</div>
                <h1 class="hero-title">TorchReview Copilot</h1>
                <p class="hero-copy">
                    AI-powered code review and improvement system using PyTorch to score code quality, surface bugs,
                    and generate a three-step improvement plan. OpenEnv stays underneath as the deterministic validation engine.
                </p>
            </div>
            """
    how_it_works_markup = """
            <div class="subtle-card" style="margin-top: 12px;">
                <div class="eyebrow">How It Works</div>
                <div class="how-grid">
                    <div class="how-step"><strong>Input</strong><br>Code plus optional traceback or benchmark signal.</div>
                    <div class="how-step"><strong>Processing</strong><br>Static checks extract parser, lint, complexity, and runtime clues.</div>
                    <div class="how-step"><strong>Model</strong><br>CodeBERTa embeddings run through PyTorch and score code quality against known OpenEnv patterns.</div>
                    <div class="how-step"><strong>Output</strong><br>Confidence radar, reward score, and a three-step improvement plan.</div>
                </div>
            </div>
            """

    def _cleared_state():
        # Empty inputs, a neutral header, and the default result panels.
        return (
            "",
            "",
            "",
            "### Example Scenario\nChoose a built-in example or paste custom code.",
            *_default_outputs(),
        )

    with gr.Blocks(theme=soft_theme, css=CSS, title="TorchReview Copilot") as demo:
        gr.HTML(hero_markup)

        with gr.Row():
            with gr.Column(scale=6):
                example_choice = gr.Radio(
                    choices=[(item.title, item.key) for item in example_lookup.values()],
                    value=initial_example.key,
                    label="Try a built-in failure scenario",
                    info="Switching examples updates the Live Triage Radar immediately.",
                )
                example_header = gr.Markdown()
                code_input = gr.Code(
                    value=initial_example.code,
                    language="python",
                    lines=18,
                    label="Python code under review",
                )
                traceback_input = gr.Textbox(
                    value=initial_example.traceback_text,
                    lines=7,
                    label="Optional traceback / failing test output",
                    placeholder="Paste stack traces, assertion failures, or benchmark notes here.",
                )
                context_input = gr.Textbox(
                    value=initial_example.context_window,
                    lines=4,
                    label="Context window",
                    placeholder="Describe expected behavior, constraints, or repository context.",
                )
                with gr.Row():
                    analyze_button = gr.Button("Analyze & Score Code", variant="primary")
                    clear_button = gr.Button("Clear Inputs", variant="secondary")

            with gr.Column(scale=5):
                summary_html = gr.HTML()
                radar_html = gr.HTML()
                plan_markdown = gr.Markdown()
                match_markdown = gr.Markdown()
                model_markdown = gr.Markdown()

        gr.HTML(how_it_works_markup)

        # Component groups shared by the event handlers below.
        input_widgets = (code_input, traceback_input, context_input)
        result_panels = (summary_html, radar_html, plan_markdown, match_markdown, model_markdown)
        form_and_results = (*input_widgets, example_header, *result_panels)

        example_choice.change(
            fn=load_example,
            inputs=example_choice,
            outputs=list(form_and_results),
            show_progress="hidden",
        )
        analyze_button.click(
            fn=analyze_inputs,
            inputs=list(input_widgets),
            outputs=list(result_panels),
            show_progress="minimal",
        )
        clear_button.click(
            fn=_cleared_state,
            inputs=None,
            outputs=list(form_and_results),
            show_progress="hidden",
        )
        demo.load(
            fn=load_example,
            inputs=example_choice,
            outputs=list(form_and_results),
            show_progress="hidden",
        )

    return demo
|
server/env.py
ADDED
|
@@ -0,0 +1,396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""OpenEnv environment implementation for Python code review tasks."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any, Dict, Optional, Tuple
|
| 6 |
+
from uuid import uuid4
|
| 7 |
+
|
| 8 |
+
from openenv.core.env_server.interfaces import Environment
|
| 9 |
+
from openenv.core.env_server.types import EnvironmentMetadata
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
from ..graders import grade_task
|
| 13 |
+
from ..graders.shared import component_score, safe_ratio, strict_score
|
| 14 |
+
from ..Models import (
|
| 15 |
+
HistoryEntry,
|
| 16 |
+
PythonCodeReviewAction,
|
| 17 |
+
PythonCodeReviewObservation,
|
| 18 |
+
PythonCodeReviewState,
|
| 19 |
+
RewardDetails,
|
| 20 |
+
TaskGrade,
|
| 21 |
+
)
|
| 22 |
+
from ..tasks import ReviewTask, list_tasks, select_task
|
| 23 |
+
except ImportError:
|
| 24 |
+
from graders import grade_task
|
| 25 |
+
from graders.shared import component_score, safe_ratio, strict_score
|
| 26 |
+
from Models import (
|
| 27 |
+
HistoryEntry,
|
| 28 |
+
PythonCodeReviewAction,
|
| 29 |
+
PythonCodeReviewObservation,
|
| 30 |
+
PythonCodeReviewState,
|
| 31 |
+
RewardDetails,
|
| 32 |
+
TaskGrade,
|
| 33 |
+
)
|
| 34 |
+
from tasks import ReviewTask, list_tasks, select_task
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _empty_grade() -> TaskGrade:
    """Build the placeholder grade used before any grading has happened.

    Each component score is ``component_score(0.01)`` and both test counters
    are zero.
    """
    score_fields = ("score", "syntax_score", "quality_score", "runtime_score")
    placeholder_scores = {field: component_score(0.01) for field in score_fields}
    return TaskGrade(tests_passed=0, tests_total=0, **placeholder_scores)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _reward_value(value: float) -> float:
    """Normalize a raw reward by delegating to ``strict_score``."""
    normalized = strict_score(value)
    return normalized
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class PythonCodeReviewEnvironment(
    Environment[PythonCodeReviewAction, PythonCodeReviewObservation, PythonCodeReviewState]
):
    """Structured environment for deterministic Python code review workflows.

    An episode loads one ``ReviewTask``, keeps a working copy of its code, and
    accepts four action types: ``edit_code``, ``submit_solution``,
    ``run_tests`` and ``analyze_code``.  After every valid action the working
    copy is re-graded with ``grade_task`` and the change in grade is turned
    into a shaped reward.  The episode ends on ``submit_solution`` or when the
    task's ``max_steps`` budget is used up.
    """

    # NOTE(review): presumably read by the OpenEnv server to decide whether
    # several sessions may share this environment class -- confirm upstream.
    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self, verbose: bool = False, **_: Any) -> None:
        """Create the environment and load an initial task via ``reset()``.

        Args:
            verbose: Stored on the instance; not read elsewhere in this class.
            **_: Extra keyword arguments are accepted and ignored.
        """
        super().__init__()
        self.verbose = verbose
        # Provisional values so every attribute exists before reset() runs.
        self._task: ReviewTask = list_tasks()[0]
        self._current_code: str = self._task.starter_code
        self._history: list[HistoryEntry] = []
        self._last_reward = RewardDetails(value=0.1, reason="Environment initialized.")
        self._current_grade = _empty_grade()
        self._state = PythonCodeReviewState(episode_id=str(uuid4()), step_count=0)
        self.reset()

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> PythonCodeReviewObservation:
        """Start a new episode and return its first observation.

        Args:
            seed: Forwarded to ``select_task``.
            episode_id: Identifier for the new episode; a fresh UUID4 string is
                generated when it is missing or empty.
            **kwargs: Only ``task_id`` is read; it is forwarded to
                ``select_task``.

        Returns:
            The observation for the freshly loaded task, graded without hidden
            tests and carrying an initial reward value of 0.1.
        """
        task_id = kwargs.get("task_id")
        self._task = select_task(seed=seed, task_id=task_id)
        self._current_code = self._task.starter_code
        self._history = []
        self._last_reward = RewardDetails(value=0.1, reason="Environment reset.")
        # Baseline grade of the starter code, using the non-hidden checks only.
        self._current_grade = grade_task(self._task, self._current_code, include_hidden=False)

        self._state = PythonCodeReviewState(
            episode_id=episode_id or str(uuid4()),
            step_count=0,
            task_id=self._task.task_id,
            difficulty=self._task.difficulty,
            task_kind=self._task.task_kind,
            attempts_remaining=self._task.max_steps,
            current_code=self._current_code,
            errors=self._format_errors(self._current_grade),
            test_results=self._format_test_results(self._current_grade),
            history=[],
            score=self._current_grade.score,
            done=False,
        )
        return self._build_observation(
            grade=self._current_grade,
            status=f"Loaded task {self._task.task_id}.",
            reward_details=self._last_reward,
        )

    def step(
        self,
        action: PythonCodeReviewAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> PythonCodeReviewObservation:
        """Apply ``action`` and return only the resulting observation.

        Reward, done flag and info dict from the transition are discarded here;
        use ``step_result`` to receive them as well.
        """
        observation, _, _, _ = self._step_transition(action, timeout_s=timeout_s, **kwargs)
        return observation

    def step_result(
        self,
        action: PythonCodeReviewAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> Tuple[PythonCodeReviewObservation, float, bool, Dict[str, Any]]:
        """Gym-style helper used by local scripts and tests.

        Returns:
            ``(observation, reward, done, info)`` where ``info`` holds the
            task id and the current score.
        """

        return self._step_transition(action, timeout_s=timeout_s, **kwargs)

    def _step_transition(
        self,
        action: PythonCodeReviewAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> Tuple[PythonCodeReviewObservation, float, bool, Dict[str, Any]]:
        """Run one environment transition.

        Args:
            action: The agent action; ``action_type`` selects the branch and
                ``code`` carries the payload for edits and submissions.
            timeout_s: Grading timeout; a missing or falsy value falls back to
                3.0 seconds.
            **kwargs: Accepted but not used by this method.

        Returns:
            ``(observation, reward, done, info)`` for the step.
        """
        # A finished episode is not advanced: no grading, no history entry and
        # no step count change, only a reward derived from the final score.
        if self._state.done:
            reward = RewardDetails(
                value=_reward_value(0.05 + 0.25 * self._current_grade.score),
                reason="Episode already finished. Call reset() to continue.",
            )
            observation = self._build_observation(
                grade=self._current_grade,
                status="Episode already finished.",
                reward_details=reward,
            )
            return observation, reward.value, observation.done, {"task_id": observation.task_id, "score": observation.score}

        previous_grade = self._current_grade
        status = ""
        invalid_action = False
        code_changed = False
        use_hidden_grading = False

        if action.action_type == "edit_code":
            if not action.code or not action.code.strip():
                invalid_action = True
                status = "edit_code requires a non-empty code payload."
            else:
                code_changed = action.code != self._current_code
                self._current_code = action.code
                status = "Updated working copy from agent patch."
        elif action.action_type == "submit_solution":
            # A submission may carry replacement code; without one the current
            # working copy is what gets graded.
            if action.code is not None and action.code.strip():
                code_changed = action.code != self._current_code
                self._current_code = action.code
            use_hidden_grading = True
            status = "Submission received for final grading."
        elif action.action_type == "run_tests":
            status = "Executed public validation suite."
        elif action.action_type == "analyze_code":
            status = "Generated static review summary."
        else:  # pragma: no cover
            invalid_action = True
            status = f"Unsupported action_type: {action.action_type}"

        # Invalid actions still consume one step of the budget.
        self._state.step_count += 1

        if invalid_action:
            # Nothing changed, so the previous grade is reused instead of re-grading.
            current_grade = previous_grade
        else:
            current_grade = grade_task(
                self._task,
                self._current_code,
                include_hidden=use_hidden_grading,
                timeout_s=timeout_s or 3.0,
            )
            # Replace the provisional status with one built from the new grade.
            if action.action_type == "analyze_code":
                status = self._analysis_status(current_grade)
            elif action.action_type == "run_tests":
                status = self._run_tests_status(current_grade, use_hidden_grading)
            elif action.action_type == "submit_solution":
                status = self._submission_status(current_grade)

        # The episode ends on a submission or once the step budget is used up.
        done = use_hidden_grading or self._state.step_count >= self._task.max_steps
        if self._state.step_count >= self._task.max_steps and not use_hidden_grading:
            status = f"{status} Step budget exhausted."

        reward_details = self._compute_reward(
            previous_grade=previous_grade,
            current_grade=current_grade,
            action=action,
            invalid_action=invalid_action,
            timed_out=current_grade.timed_out,
            code_changed=code_changed,
            final_submission=use_hidden_grading,
        )

        self._history.append(
            HistoryEntry(
                step=self._state.step_count,
                action_type=action.action_type,
                status=status,
                reward=reward_details.value,
            )
        )

        self._current_grade = current_grade
        self._last_reward = reward_details
        attempts_remaining = max(self._task.max_steps - self._state.step_count, 0)

        # The state must be updated before the observation is built, because
        # _build_observation reads attempts_remaining and done from it.
        self._state.task_id = self._task.task_id
        self._state.difficulty = self._task.difficulty
        self._state.task_kind = self._task.task_kind
        self._state.attempts_remaining = attempts_remaining
        self._state.current_code = self._current_code
        self._state.errors = self._format_errors(current_grade)
        self._state.test_results = self._format_test_results(current_grade)
        self._state.history = list(self._history)
        self._state.score = current_grade.score
        self._state.done = done

        observation = self._build_observation(
            grade=current_grade,
            status=status,
            reward_details=reward_details,
        )
        return observation, reward_details.value, observation.done, {"task_id": observation.task_id, "score": observation.score}

    @property
    def state(self) -> PythonCodeReviewState:
        """Return the current episode state object (not a copy)."""
        return self._state

    def _build_observation(
        self,
        *,
        grade: TaskGrade,
        status: str,
        reward_details: RewardDetails,
    ) -> PythonCodeReviewObservation:
        """Assemble an observation from the task, the given grade and the state.

        Args:
            grade: Grade whose score, errors and test results are reported.
            status: Text stored as ``last_action_status``.
            reward_details: Reward breakdown; its ``value`` is also exposed as
                the observation's ``reward``.
        """
        return PythonCodeReviewObservation(
            task_id=self._task.task_id,
            title=self._task.title,
            difficulty=self._task.difficulty,
            task_kind=self._task.task_kind,
            task_description=self._task.task_description,
            current_code=self._current_code,
            errors=self._format_errors(grade),
            test_results=self._format_test_results(grade),
            # Lists are copied so the observation does not alias internal state.
            visible_tests=list(self._task.visible_tests),
            history=list(self._history),
            attempts_remaining=self._state.attempts_remaining,
            last_action_status=status,
            score=grade.score,
            reward=reward_details.value,
            done=self._state.done,
            reward_details=reward_details,
            metadata={
                "goal": self._task.goal,
                "repo_summary": self._task.repo_summary,
                "changed_files": self._task.changed_files,
                "available_files": self._task.available_files,
                "grade_details": grade.details,
            },
        )

    def _compute_reward(
        self,
        *,
        previous_grade: TaskGrade,
        current_grade: TaskGrade,
        action: PythonCodeReviewAction,
        invalid_action: bool,
        timed_out: bool,
        code_changed: bool,
        final_submission: bool,
    ) -> RewardDetails:
        """Turn the change between two grades into a shaped reward.

        The raw value is ``0.1 + 0.45 * current score`` plus the bonuses
        (syntax fixed, test pass-rate gain, score gain, quality gain, full
        correctness on submission) minus the penalties (invalid action,
        timeout, score regression, unchanged ``edit_code`` patch).  The raw
        value is then passed through ``_reward_value``.

        Returns:
            A ``RewardDetails`` holding the final value, every individual
            component and a comma-separated ``reason`` naming the components
            that were non-zero.
        """
        prev_score = previous_grade.score
        curr_score = current_grade.score
        prev_rate = safe_ratio(previous_grade.tests_passed, previous_grade.tests_total)
        curr_rate = safe_ratio(current_grade.tests_passed, current_grade.tests_total)

        # Bonuses only reward improvement: each delta is floored at zero.
        syntax_reward = 0.14 if previous_grade.syntax_score < 0.9 and current_grade.syntax_score >= 0.9 else 0.0
        test_reward = round(max(curr_rate - prev_rate, 0.0) * 0.22, 3)
        progress_delta = round(max(curr_score - prev_score, 0.0) * 0.35, 3)
        quality_bonus = round(max(current_grade.quality_score - previous_grade.quality_score, 0.0) * 0.08, 3)
        # Granted only when a final submission crosses the 0.94 score threshold.
        correctness_bonus = 0.12 if final_submission and curr_score >= 0.94 and prev_score < 0.94 else 0.0

        invalid_action_penalty = 0.12 if invalid_action else 0.0
        timeout_penalty = 0.14 if timed_out else 0.0
        regression_penalty = round(max(prev_score - curr_score, 0.0) * 0.2, 3)
        # An edit_code action that leaves the working copy unchanged is penalized.
        stagnation_penalty = 0.06 if action.action_type == "edit_code" and not code_changed else 0.0

        raw_value = (
            0.1
            + 0.45 * curr_score
            + syntax_reward
            + test_reward
            + progress_delta
            + quality_bonus
            + correctness_bonus
            - invalid_action_penalty
            - timeout_penalty
            - regression_penalty
            - stagnation_penalty
        )
        value = _reward_value(raw_value)

        # Human-readable explanation listing every component that contributed.
        reason_parts = []
        if syntax_reward:
            reason_parts.append("syntax fixed")
        if test_reward:
            reason_parts.append("public test progress")
        if progress_delta:
            reason_parts.append("overall score improved")
        if quality_bonus:
            reason_parts.append("code quality improved")
        if correctness_bonus:
            reason_parts.append("full correctness bonus")
        if invalid_action_penalty:
            reason_parts.append("invalid action penalty")
        if timeout_penalty:
            reason_parts.append("timeout penalty")
        if regression_penalty:
            reason_parts.append("regression penalty")
        if stagnation_penalty:
            reason_parts.append("unchanged patch penalty")
        if not reason_parts:
            reason_parts.append("no meaningful state change")

        return RewardDetails(
            value=value,
            syntax_reward=syntax_reward,
            test_reward=test_reward,
            correctness_bonus=correctness_bonus,
            quality_bonus=quality_bonus,
            progress_delta=progress_delta,
            invalid_action_penalty=invalid_action_penalty,
            timeout_penalty=timeout_penalty,
            regression_penalty=regression_penalty,
            stagnation_penalty=stagnation_penalty,
            reason=", ".join(reason_parts),
            prev_score=prev_score,
            curr_score=curr_score,
            code_changed=code_changed,
        )

    def _format_errors(self, grade: TaskGrade) -> str:
        """Return the grade's ``compile_error`` detail, or a success message when it is blank."""
        compile_error = str(grade.details.get("compile_error", "")).strip()
        if compile_error:
            return compile_error
        return "Code parses successfully."

    def _format_test_results(self, grade: TaskGrade) -> str:
        """Join the test summary and any benchmark detail into one newline-separated text.

        A dict benchmark is rendered from its ``candidate_seconds``,
        ``baseline_seconds`` and ``improvement_ratio`` entries; a non-empty
        string benchmark is appended as-is.
        """
        parts = [grade.details.get("test_summary", "No test feedback available.")]
        benchmark = grade.details.get("benchmark")
        if isinstance(benchmark, dict):
            parts.append(
                "Benchmark: "
                f"candidate {benchmark['candidate_seconds']}s vs baseline {benchmark['baseline_seconds']}s "
                f"(x{benchmark['improvement_ratio']})."
            )
        elif isinstance(benchmark, str) and benchmark:
            parts.append(f"Benchmark: {benchmark}")
        # Empty entries are dropped so the result has no blank lines.
        return "\n".join(part for part in parts if part)

    def _analysis_status(self, grade: TaskGrade) -> str:
        """Summarize syntax score, test counts and quality, followed by the first quality note."""
        notes = grade.details.get("quality_notes", [])
        quality_note = notes[0] if notes else "No major static quality issues detected."
        return (
            f"Syntax score {grade.syntax_score:.2f}; "
            f"public tests {grade.tests_passed}/{grade.tests_total}; "
            f"quality {grade.quality_score:.2f}. {quality_note}"
        )

    def _run_tests_status(self, grade: TaskGrade, include_hidden: bool) -> str:
        """Report how many tests passed, labelled "full" or "public" per ``include_hidden``."""
        visibility = "full" if include_hidden else "public"
        return f"Ran {visibility} tests: {grade.tests_passed}/{grade.tests_total} passed."

    def _submission_status(self, grade: TaskGrade) -> str:
        """Describe a graded submission; runtime is included only when benchmark details are a dict."""
        runtime_text = ""
        if isinstance(grade.details.get("benchmark"), dict):
            runtime_text = f" runtime {grade.runtime_score:.2f};"
        return (
            f"Submission graded with score {grade.score:.2f}; "
            f"tests {grade.tests_passed}/{grade.tests_total};"
            f"{runtime_text} quality {grade.quality_score:.2f}."
        )

    def get_metadata(self) -> EnvironmentMetadata:
        """Return the static name, description and version of this environment."""
        return EnvironmentMetadata(
            name="python_code_review_env",
            description="Production-style Python code review environment with deterministic grading.",
            version="1.0.0",
        )
|
server/python_env_environment.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Backward-compatible import shim for the environment class."""
|
| 2 |
+
|
| 3 |
+
from .env import PythonCodeReviewEnvironment, PythonCodeReviewEnvironment as PythonEnvironment
|
server/requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv-core[core]>=0.2.2
|
| 2 |
+
fastapi>=0.111.0
|
| 3 |
+
gradio>=5.26.0
|
| 4 |
+
uvicorn>=0.30.0
|
| 5 |
+
pytest>=8.0.0
|
| 6 |
+
openai>=1.76.0
|
| 7 |
+
streamlit>=1.44.0
|
| 8 |
+
torch>=2.2.0
|
| 9 |
+
transformers>=4.45.0
|
services/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Service layer for orchestrating analysis, suggestions, and rewards."""
|
| 2 |
+
|
| 3 |
+
from .analysis_service import AnalysisService
|
| 4 |
+
from .reward_service import RewardService
|
| 5 |
+
from .suggestion_service import SuggestionService
|
| 6 |
+
|
| 7 |
+
__all__ = ["AnalysisService", "RewardService", "SuggestionService"]
|
services/analysis_service.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Orchestration layer for multi-domain code analysis."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import time
|
| 6 |
+
from typing import Any, Callable, Dict
|
| 7 |
+
|
| 8 |
+
from analyzers import analyze_data_science_code, analyze_dsa_code, analyze_ml_code, analyze_web_code
|
| 9 |
+
from models import PyTorchCodeAnalyzerModel
|
| 10 |
+
from schemas.request import AnalyzeCodeRequest
|
| 11 |
+
from schemas.response import AnalyzeCodeResponse, DomainAnalysis, StaticAnalysisSummary
|
| 12 |
+
from services.reward_service import RewardService
|
| 13 |
+
from services.suggestion_service import SuggestionService
|
| 14 |
+
from utils import estimate_complexity, parse_code_structure
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _lint_score(parsed: Dict[str, Any]) -> float:
|
| 18 |
+
"""Convert structural smells into a normalized lint-style score."""
|
| 19 |
+
|
| 20 |
+
score = 1.0
|
| 21 |
+
if not parsed.get("syntax_valid", True):
|
| 22 |
+
score -= 0.45
|
| 23 |
+
score -= min(parsed.get("long_lines", 0), 5) * 0.03
|
| 24 |
+
if parsed.get("tabs_used"):
|
| 25 |
+
score -= 0.1
|
| 26 |
+
if parsed.get("trailing_whitespace_lines"):
|
| 27 |
+
score -= 0.05
|
| 28 |
+
if parsed.get("docstring_ratio", 0.0) == 0.0 and parsed.get("function_names"):
|
| 29 |
+
score -= 0.08
|
| 30 |
+
return round(max(0.0, min(1.0, score)), 4)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class AnalysisService:
    """Coordinates parsing, model scoring, domain analysis and reward shaping."""

    def __init__(self) -> None:
        """Create the model, the helper services and the domain-analyzer registry."""
        self.model = PyTorchCodeAnalyzerModel()
        self.reward_service = RewardService()
        self.suggestion_service = SuggestionService()
        registry: Dict[str, Callable[[str, Dict[str, Any], Dict[str, Any]], DomainAnalysis]] = dict(
            dsa=analyze_dsa_code,
            data_science=analyze_data_science_code,
            ml_dl=analyze_ml_code,
            web=analyze_web_code,
        )
        self._analyzers = registry

    def _heuristic_domain_scores(self, parsed: Dict[str, Any], code: str) -> Dict[str, float]:
        """Compute per-domain prior scores from parser flags and keyword hits.

        Every domain starts at 0.2.  Flags in ``parsed`` and substring matches
        in ``code`` add fixed increments; each final score is capped at 0.99
        and rounded to four decimals.
        """
        dsa = 0.2
        dsa += 0.15 if parsed.get("uses_recursion") else 0.0
        dsa += 0.15 if parsed.get("max_loop_depth", 0) >= 1 else 0.0

        data_science = 0.2
        data_science += 0.35 if parsed.get("uses_pandas") or parsed.get("uses_numpy") else 0.0

        ml_dl = 0.2
        ml_dl += 0.35 if parsed.get("uses_torch") or parsed.get("uses_sklearn") else 0.0

        web = 0.2
        web += 0.35 if parsed.get("uses_fastapi") or parsed.get("uses_flask") else 0.0
        web += 0.1 if parsed.get("route_decorators") else 0.0

        # Library names are matched case-insensitively; loop keywords are not.
        lowered = code.lower()
        if "fastapi" in lowered:
            web += 0.1
        if "pandas" in lowered or "numpy" in lowered:
            data_science += 0.1
        if "torch" in lowered:
            ml_dl += 0.1
        if "while" in code or "for" in code:
            dsa += 0.05

        priors = {
            "dsa": dsa,
            "data_science": data_science,
            "ml_dl": ml_dl,
            "web": web,
            "general": 0.2,
        }
        capped: Dict[str, float] = {}
        for domain, raw in priors.items():
            capped[domain] = round(min(raw, 0.99), 4)
        return capped

    def analyze(self, request: AnalyzeCodeRequest) -> AnalyzeCodeResponse:
        """Execute the full pipeline for one request and return the response.

        Steps: parse the code, estimate complexity, query the model, blend the
        model's domain scores with the heuristic priors (60/40), run the
        analyzer for the chosen domain, then compute the reward breakdown and
        the improvement plan.  A ``domain_hint`` other than ``"auto"``
        overrides the blended detection.
        """
        started = time.perf_counter()
        source = request.code
        parsed = parse_code_structure(source)
        complexity = estimate_complexity(parsed, source)
        model_prediction = self.model.predict(source, request.context_window, parsed)
        heuristic_scores = self._heuristic_domain_scores(parsed, source)

        model_domain_scores = model_prediction["domain_scores"]
        combined_scores = {
            domain: round((0.6 * float(model_domain_scores.get(domain, 0.2))) + (0.4 * heuristic_score), 4)
            for domain, heuristic_score in heuristic_scores.items()
        }

        if request.domain_hint != "auto":
            detected_domain = request.domain_hint
        else:
            detected_domain = max(combined_scores, key=combined_scores.get)

        analyzer = self._analyzers.get(detected_domain)
        if analyzer is None:
            # No dedicated analyzer for this domain: fall back to a generic result.
            domain_analysis = DomainAnalysis(
                domain="general",
                domain_score=0.6,
                issues=[],
                suggestions=["Add stronger domain-specific context for deeper analysis."],
                highlights={},
            )
        else:
            domain_analysis = analyzer(source, parsed, complexity)

        score_breakdown = self.reward_service.compute(
            ml_score=float(model_prediction["ml_quality_score"]),
            domain_score=domain_analysis.domain_score,
            lint_score=_lint_score(parsed),
            complexity_penalty=float(complexity["complexity_penalty"]),
        )
        static_analysis = StaticAnalysisSummary(
            syntax_valid=bool(parsed["syntax_valid"]),
            syntax_error=str(parsed["syntax_error"]),
            cyclomatic_complexity=int(complexity["cyclomatic_complexity"]),
            line_count=int(parsed["line_count"]),
            max_loop_depth=int(parsed["max_loop_depth"]),
            time_complexity=str(complexity["time_complexity"]),
            space_complexity=str(complexity["space_complexity"]),
            detected_imports=list(parsed["imports"]),
            code_smells=list(parsed["code_smells"]),
        )
        improvement_plan = self.suggestion_service.build_improvement_plan(
            domain_analysis=domain_analysis,
            static_analysis=static_analysis,
        )

        summary_parts = [
            f"Detected `{detected_domain}` code with a model score of {score_breakdown.ml_score:.0%}, ",
            f"domain score {score_breakdown.domain_score:.0%}, and final reward {score_breakdown.reward:.0%}.",
        ]
        elapsed_ms = round((time.perf_counter() - started) * 1000.0, 2)
        return AnalyzeCodeResponse(
            detected_domain=detected_domain,  # type: ignore[arg-type]
            domain_confidences=combined_scores,
            score_breakdown=score_breakdown,
            static_analysis=static_analysis,
            domain_analysis=domain_analysis,
            improvement_plan=improvement_plan,
            model_backend=str(model_prediction["backend_name"]),
            model_id=str(model_prediction["model_id"]),
            summary="".join(summary_parts),
            context_window=request.context_window,
            analysis_time_ms=elapsed_ms,
        )
|
services/reward_service.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Reward shaping logic for RL-ready code analysis scores."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from schemas.response import ScoreBreakdown
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class RewardService:
    """Blend model, domain, lint, and complexity signals into a single reward."""

    def compute(self, *, ml_score: float, domain_score: float, lint_score: float, complexity_penalty: float) -> ScoreBreakdown:
        """Return the rounded inputs together with the clamped weighted reward.

        The raw reward is ``0.4 * ml_score + 0.2 * domain_score + 0.2 * lint_score
        - 0.2 * complexity_penalty``; it is limited to the range ``[0.0, 1.0]``.
        Every value placed in the returned ``ScoreBreakdown`` is rounded to four
        decimal places.
        """

        weighted = (0.4 * ml_score) + (0.2 * domain_score) + (0.2 * lint_score) - (0.2 * complexity_penalty)
        # Cap at 1.0 first, then floor at 0.0.
        capped = min(1.0, weighted)
        reward = max(0.0, capped)

        rounded = {
            "ml_score": round(ml_score, 4),
            "domain_score": round(domain_score, 4),
            "lint_score": round(lint_score, 4),
            "complexity_penalty": round(complexity_penalty, 4),
            "reward": round(reward, 4),
        }
        return ScoreBreakdown(**rounded)
|
services/suggestion_service.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Suggestion and improvement-plan generation for analyzed code."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from schemas.response import DomainAnalysis, StaticAnalysisSummary
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class SuggestionService:
    """Turn analysis output into a short, ordered list of improvement steps."""

    def build_improvement_plan(self, *, domain_analysis: DomainAnalysis, static_analysis: StaticAnalysisSummary) -> list[str]:
        """Return exactly three plan steps: correctness, edge cases, scalability.

        Step 1 quotes the first domain issue's description, or a generic
        correctness reminder when there are no issues; if the code failed to
        parse it instead asks for the syntax error to be fixed. Step 2 is fixed
        text. Step 3 is fixed text with the first domain suggestion appended as
        a priority hint when one exists.
        """

        if domain_analysis.issues:
            headline = domain_analysis.issues[0].description
        else:
            headline = "Stabilize correctness first and keep the public behavior explicit."

        correctness = f"Step 1 - Correctness and safety: {headline}"
        edge_cases = "Step 2 - Edge cases: test empty inputs, boundary values, malformed payloads, and failure-mode behavior explicitly."
        scalability = "Step 3 - Scalability: reduce repeated scans, lower cyclomatic complexity, and benchmark the path on realistic input sizes."

        if domain_analysis.suggestions:
            hint = domain_analysis.suggestions[0]
            scalability = f"{scalability} Priority hint: {hint}"

        # A syntax error takes precedence over any domain-level issue in step 1.
        if not static_analysis.syntax_valid:
            correctness = f"Step 1 - Correctness and safety: fix the syntax error first ({static_analysis.syntax_error})."

        return [correctness, edge_cases, scalability]
|
tasks/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Task catalog for python_code_review_env."""
|
| 2 |
+
|
| 3 |
+
from .catalog import ReviewTask, get_task, list_tasks, select_task
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def task_ids() -> list[str]:
    """Collect the ``task_id`` of every task, in the order ``list_tasks()`` yields them."""

    identifiers = []
    for task in list_tasks():
        identifiers.append(task.task_id)
    return identifiers
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
__all__ = ["ReviewTask", "get_task", "list_tasks", "select_task", "task_ids"]
|
tasks/catalog.py
ADDED
|
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Deterministic task definitions for the code review environment."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from textwrap import dedent
|
| 7 |
+
from typing import Any, Dict, List, Optional
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _code(value: str) -> str:
    """Dedent a snippet, trim surrounding whitespace, and terminate it with one newline."""
    trimmed = dedent(value).strip()
    return f"{trimmed}\n"
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass(frozen=True)
class CallCase:
    """A single labelled function call, with its expected result, for graders to execute."""

    # Human-readable name of the case.
    label: str
    # Positional arguments for the call; empty by default.
    args: tuple[Any, ...] = field(default=())
    # Keyword arguments for the call; every instance gets its own fresh dict.
    kwargs: Dict[str, Any] = field(default_factory=dict)
    # Value the call is expected to produce.
    expected: Any = field(default=None)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass(frozen=True)
class ReviewTask:
    """Immutable description of one code-review task in the catalog."""

    # --- identity and classification ---
    task_id: str
    title: str
    difficulty: str  # the catalog uses "easy", "medium", and "hard"
    task_kind: str  # the catalog uses "syntax_fix", "bug_fix", and "optimization"
    task_description: str

    # --- code under review ---
    starter_code: str
    reference_code: str
    function_name: str

    # --- checks ---
    visible_tests: List[str]  # expressions written as strings, e.g. "f([]) == 0"
    public_cases: List[CallCase]
    hidden_cases: List[CallCase]

    # --- repository context ---
    repo_summary: str
    changed_files: List[str]
    available_files: List[str]

    # --- episode settings ---
    goal: str
    max_steps: int
    # Optional integer benchmark settings; only the optimization task in this catalog sets it.
    benchmark_config: Optional[Dict[str, int]] = None
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
TASKS: List[ReviewTask] = [
|
| 48 |
+
ReviewTask(
|
| 49 |
+
task_id="syntax_fix_invoice_totals",
|
| 50 |
+
title="Fix the invoice total syntax regression",
|
| 51 |
+
difficulty="easy",
|
| 52 |
+
task_kind="syntax_fix",
|
| 53 |
+
task_description=(
|
| 54 |
+
"A recent refactor broke the helper that normalizes invoice totals before "
|
| 55 |
+
"daily reconciliation. Repair the Python syntax so the function compiles "
|
| 56 |
+
"and returns the correct total for mixed integer and string inputs."
|
| 57 |
+
),
|
| 58 |
+
starter_code=_code(
|
| 59 |
+
"""
|
| 60 |
+
def normalize_invoice_totals(records):
|
| 61 |
+
cleaned = []
|
| 62 |
+
for record in records
|
| 63 |
+
if "total" not in record:
|
| 64 |
+
continue
|
| 65 |
+
value = int(record["total"])
|
| 66 |
+
cleaned.append(value)
|
| 67 |
+
return sum(cleaned
|
| 68 |
+
"""
|
| 69 |
+
),
|
| 70 |
+
reference_code=_code(
|
| 71 |
+
'''
|
| 72 |
+
def normalize_invoice_totals(records):
|
| 73 |
+
"""Return the sum of invoice totals that are present in the payload."""
|
| 74 |
+
cleaned = []
|
| 75 |
+
for record in records:
|
| 76 |
+
if "total" not in record:
|
| 77 |
+
continue
|
| 78 |
+
cleaned.append(int(record["total"]))
|
| 79 |
+
return sum(cleaned)
|
| 80 |
+
'''
|
| 81 |
+
),
|
| 82 |
+
function_name="normalize_invoice_totals",
|
| 83 |
+
visible_tests=[
|
| 84 |
+
"normalize_invoice_totals([{'total': '4'}, {'total': 5}, {}]) == 9",
|
| 85 |
+
"normalize_invoice_totals([]) == 0",
|
| 86 |
+
],
|
| 87 |
+
public_cases=[
|
| 88 |
+
CallCase(
|
| 89 |
+
label="mixed string and int totals",
|
| 90 |
+
args=([{"total": "4"}, {"total": 5}, {}],),
|
| 91 |
+
expected=9,
|
| 92 |
+
),
|
| 93 |
+
CallCase(label="empty input", args=([],), expected=0),
|
| 94 |
+
],
|
| 95 |
+
hidden_cases=[
|
| 96 |
+
CallCase(
|
| 97 |
+
label="skip missing totals",
|
| 98 |
+
args=([{}, {"total": "2"}, {"total": "8"}],),
|
| 99 |
+
expected=10,
|
| 100 |
+
),
|
| 101 |
+
CallCase(
|
| 102 |
+
label="handle negative adjustments",
|
| 103 |
+
args=([{"total": "11"}, {"total": -3}],),
|
| 104 |
+
expected=8,
|
| 105 |
+
),
|
| 106 |
+
],
|
| 107 |
+
repo_summary=(
|
| 108 |
+
"services/billing/reconciliation.py computes end-of-day invoice totals for "
|
| 109 |
+
"a CPU-only batch job."
|
| 110 |
+
),
|
| 111 |
+
changed_files=["services/billing/reconciliation.py"],
|
| 112 |
+
available_files=["services/billing/reconciliation.py", "tests/test_reconciliation.py"],
|
| 113 |
+
goal="Restore a compiling implementation for invoice total normalization.",
|
| 114 |
+
max_steps=6,
|
| 115 |
+
),
|
| 116 |
+
ReviewTask(
|
| 117 |
+
task_id="bug_fix_session_windows",
|
| 118 |
+
title="Repair session window collapsing logic",
|
| 119 |
+
difficulty="medium",
|
| 120 |
+
task_kind="bug_fix",
|
| 121 |
+
task_description=(
|
| 122 |
+
"The session aggregator regressed after a cleanup pass. Public tests expose "
|
| 123 |
+
"incorrect boundary handling and the final session is missing. Fix the logic "
|
| 124 |
+
"without changing the function contract."
|
| 125 |
+
),
|
| 126 |
+
starter_code=_code(
|
| 127 |
+
"""
|
| 128 |
+
def collapse_sessions(events, idle_timeout_minutes):
|
| 129 |
+
if not events:
|
| 130 |
+
return []
|
| 131 |
+
|
| 132 |
+
sessions = []
|
| 133 |
+
current_start = events[0]["minute"]
|
| 134 |
+
current_end = current_start
|
| 135 |
+
|
| 136 |
+
for event in events[1:]:
|
| 137 |
+
minute = event["minute"]
|
| 138 |
+
if minute - current_end > idle_timeout_minutes:
|
| 139 |
+
sessions.append((current_start, current_end))
|
| 140 |
+
current_start = minute
|
| 141 |
+
current_end = minute
|
| 142 |
+
|
| 143 |
+
return sessions
|
| 144 |
+
"""
|
| 145 |
+
),
|
| 146 |
+
reference_code=_code(
|
| 147 |
+
'''
|
| 148 |
+
def collapse_sessions(events, idle_timeout_minutes):
|
| 149 |
+
"""Collapse activity events into inclusive session windows."""
|
| 150 |
+
if not events:
|
| 151 |
+
return []
|
| 152 |
+
|
| 153 |
+
sessions = []
|
| 154 |
+
current_start = events[0]["minute"]
|
| 155 |
+
current_end = current_start
|
| 156 |
+
|
| 157 |
+
for event in events[1:]:
|
| 158 |
+
minute = event["minute"]
|
| 159 |
+
if minute - current_end >= idle_timeout_minutes:
|
| 160 |
+
sessions.append((current_start, current_end))
|
| 161 |
+
current_start = minute
|
| 162 |
+
current_end = minute
|
| 163 |
+
|
| 164 |
+
sessions.append((current_start, current_end))
|
| 165 |
+
return sessions
|
| 166 |
+
'''
|
| 167 |
+
),
|
| 168 |
+
function_name="collapse_sessions",
|
| 169 |
+
visible_tests=[
|
| 170 |
+
"collapse_sessions([{'minute': 1}, {'minute': 3}, {'minute': 8}], 4) == [(1, 3), (8, 8)]",
|
| 171 |
+
"collapse_sessions([{'minute': 5}, {'minute': 9}], 4) == [(5, 5), (9, 9)]",
|
| 172 |
+
],
|
| 173 |
+
public_cases=[
|
| 174 |
+
CallCase(
|
| 175 |
+
label="split when idle timeout is exceeded",
|
| 176 |
+
args=([{"minute": 1}, {"minute": 3}, {"minute": 8}], 4),
|
| 177 |
+
expected=[(1, 3), (8, 8)],
|
| 178 |
+
),
|
| 179 |
+
CallCase(
|
| 180 |
+
label="boundary is inclusive",
|
| 181 |
+
args=([{"minute": 5}, {"minute": 9}], 4),
|
| 182 |
+
expected=[(5, 5), (9, 9)],
|
| 183 |
+
),
|
| 184 |
+
],
|
| 185 |
+
hidden_cases=[
|
| 186 |
+
CallCase(
|
| 187 |
+
label="single continuous session",
|
| 188 |
+
args=([{"minute": 2}, {"minute": 4}, {"minute": 5}], 4),
|
| 189 |
+
expected=[(2, 5)],
|
| 190 |
+
),
|
| 191 |
+
CallCase(label="empty input", args=([], 10), expected=[]),
|
| 192 |
+
CallCase(
|
| 193 |
+
label="multiple boundaries",
|
| 194 |
+
args=([{"minute": 1}, {"minute": 5}, {"minute": 9}, {"minute": 14}], 4),
|
| 195 |
+
expected=[(1, 1), (5, 5), (9, 9), (14, 14)],
|
| 196 |
+
),
|
| 197 |
+
],
|
| 198 |
+
repo_summary=(
|
| 199 |
+
"analytics/sessionizer.py condenses sorted clickstream events into user "
|
| 200 |
+
"sessions for downstream retention reports."
|
| 201 |
+
),
|
| 202 |
+
changed_files=["analytics/sessionizer.py"],
|
| 203 |
+
available_files=["analytics/sessionizer.py", "tests/test_sessionizer.py"],
|
| 204 |
+
goal="Make session collapsing match the expected timeout semantics.",
|
| 205 |
+
max_steps=8,
|
| 206 |
+
),
|
| 207 |
+
ReviewTask(
|
| 208 |
+
task_id="optimization_rank_active_users",
|
| 209 |
+
title="Optimize the active-user ranking pipeline",
|
| 210 |
+
difficulty="hard",
|
| 211 |
+
task_kind="optimization",
|
| 212 |
+
task_description=(
|
| 213 |
+
"The reporting job is correct enough for small fixtures but too slow for the "
|
| 214 |
+
"daily production export. Preserve the API, keep the output deterministic, "
|
| 215 |
+
"and refactor the implementation for speed and readability."
|
| 216 |
+
),
|
| 217 |
+
starter_code=_code(
|
| 218 |
+
"""
|
| 219 |
+
def rank_active_users(events):
|
| 220 |
+
users = []
|
| 221 |
+
for event in events:
|
| 222 |
+
if event["status"] == "active":
|
| 223 |
+
found = False
|
| 224 |
+
for existing in users:
|
| 225 |
+
if existing == event["user_id"]:
|
| 226 |
+
found = True
|
| 227 |
+
if not found:
|
| 228 |
+
users.append(event["user_id"])
|
| 229 |
+
|
| 230 |
+
totals = []
|
| 231 |
+
for user in users:
|
| 232 |
+
count = 0
|
| 233 |
+
for event in events:
|
| 234 |
+
if event["status"] == "active" and event["user_id"] == user:
|
| 235 |
+
count = count + 1
|
| 236 |
+
totals.append((user, count))
|
| 237 |
+
|
| 238 |
+
totals.sort(key=lambda item: (-item[1], item[0]))
|
| 239 |
+
return totals
|
| 240 |
+
"""
|
| 241 |
+
),
|
| 242 |
+
reference_code=_code(
|
| 243 |
+
'''
|
| 244 |
+
from collections import Counter
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def rank_active_users(events):
|
| 248 |
+
"""Return users ranked by number of active events."""
|
| 249 |
+
counts = Counter(
|
| 250 |
+
event["user_id"]
|
| 251 |
+
for event in events
|
| 252 |
+
if event["status"] == "active"
|
| 253 |
+
)
|
| 254 |
+
return sorted(counts.items(), key=lambda item: (-item[1], item[0]))
|
| 255 |
+
'''
|
| 256 |
+
),
|
| 257 |
+
function_name="rank_active_users",
|
| 258 |
+
visible_tests=[
|
| 259 |
+
"rank_active_users([{'user_id': 'b', 'status': 'active'}, {'user_id': 'a', 'status': 'active'}, {'user_id': 'b', 'status': 'inactive'}]) == [('a', 1), ('b', 1)]",
|
| 260 |
+
"rank_active_users([{'user_id': 'u1', 'status': 'active'}, {'user_id': 'u1', 'status': 'active'}, {'user_id': 'u2', 'status': 'active'}]) == [('u1', 2), ('u2', 1)]",
|
| 261 |
+
],
|
| 262 |
+
public_cases=[
|
| 263 |
+
CallCase(
|
| 264 |
+
label="inactive events are ignored",
|
| 265 |
+
args=([{"user_id": "b", "status": "active"}, {"user_id": "a", "status": "active"}, {"user_id": "b", "status": "inactive"}],),
|
| 266 |
+
expected=[("a", 1), ("b", 1)],
|
| 267 |
+
),
|
| 268 |
+
CallCase(
|
| 269 |
+
label="counts repeated active users",
|
| 270 |
+
args=([{"user_id": "u1", "status": "active"}, {"user_id": "u1", "status": "active"}, {"user_id": "u2", "status": "active"}],),
|
| 271 |
+
expected=[("u1", 2), ("u2", 1)],
|
| 272 |
+
),
|
| 273 |
+
],
|
| 274 |
+
hidden_cases=[
|
| 275 |
+
CallCase(
|
| 276 |
+
label="stable alphabetical tie-break",
|
| 277 |
+
args=([{"user_id": "u3", "status": "active"}, {"user_id": "u2", "status": "active"}, {"user_id": "u3", "status": "active"}, {"user_id": "u2", "status": "active"}],),
|
| 278 |
+
expected=[("u2", 2), ("u3", 2)],
|
| 279 |
+
),
|
| 280 |
+
CallCase(label="empty input", args=([],), expected=[]),
|
| 281 |
+
CallCase(
|
| 282 |
+
label="mixed active and inactive states",
|
| 283 |
+
args=([{"user_id": "x", "status": "inactive"}, {"user_id": "x", "status": "active"}, {"user_id": "y", "status": "active"}, {"user_id": "x", "status": "active"}],),
|
| 284 |
+
expected=[("x", 2), ("y", 1)],
|
| 285 |
+
),
|
| 286 |
+
],
|
| 287 |
+
repo_summary=(
|
| 288 |
+
"reports/activity_rankings.py feeds a nightly export that runs on a small CPU "
|
| 289 |
+
"instance and has become too slow after customer growth."
|
| 290 |
+
),
|
| 291 |
+
changed_files=["reports/activity_rankings.py"],
|
| 292 |
+
available_files=["reports/activity_rankings.py", "tests/test_activity_rankings.py"],
|
| 293 |
+
goal="Keep the output stable while improving runtime and code quality.",
|
| 294 |
+
max_steps=10,
|
| 295 |
+
benchmark_config={"user_pool": 240, "events_per_user": 36, "iterations": 8},
|
| 296 |
+
),
|
| 297 |
+
]
|
| 298 |
+
|
| 299 |
+
TASK_BY_ID = {task.task_id: task for task in TASKS}
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def list_tasks() -> List[ReviewTask]:
    """Return a new list holding every entry of ``TASKS`` in its defined order."""

    # Copy so callers cannot mutate the module-level list through the result.
    return [*TASKS]
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def get_task(task_id: str) -> ReviewTask:
    """Look up a task in ``TASK_BY_ID``.

    Raises:
        ValueError: if ``task_id`` is not a known identifier; the original
            ``KeyError`` is attached as the cause.
    """

    try:
        task = TASK_BY_ID[task_id]
    except KeyError as missing:
        message = f"Unknown task_id: {task_id}"
        raise ValueError(message) from missing
    return task
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def select_task(seed: Optional[int] = None, task_id: Optional[str] = None) -> ReviewTask:
    """Pick a task by explicit id when given, otherwise by seed.

    A truthy ``task_id`` wins and is resolved through ``get_task``. Without
    one, a missing seed selects the first task and any other seed selects
    ``TASKS[seed % len(TASKS)]``.
    """

    if task_id:
        return get_task(task_id)

    position = 0 if seed is None else seed % len(TASKS)
    return TASKS[position]
|
tests/test_multi_domain_platform.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from fastapi.testclient import TestClient
|
| 4 |
+
|
| 5 |
+
from api.main import app
|
| 6 |
+
from schemas.request import AnalyzeCodeRequest
|
| 7 |
+
from services.analysis_service import AnalysisService
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def test_analysis_service_detects_web_code() -> None:
    """A minimal FastAPI snippet is detected as web code with a bounded reward and a three-step plan."""
    snippet = "from fastapi import FastAPI\napp = FastAPI()\n\n@app.get('/health')\ndef health():\n    return {'status': 'ok'}\n"

    outcome = AnalysisService().analyze(AnalyzeCodeRequest(code=snippet, domain_hint="auto"))

    assert outcome.detected_domain == "web"
    assert 0.0 <= outcome.score_breakdown.reward <= 1.0
    assert len(outcome.improvement_plan) == 3
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def test_analysis_service_detects_dsa_code() -> None:
    """A nested-loop pair search is detected as DSA code with at least quadratic time complexity."""
    snippet = "def has_pair(nums, target):\n    for i in range(len(nums)):\n        for j in range(i + 1, len(nums)):\n            if nums[i] + nums[j] == target:\n                return True\n    return False\n"

    outcome = AnalysisService().analyze(AnalyzeCodeRequest(code=snippet, domain_hint="auto"))

    assert outcome.detected_domain == "dsa"
    accepted_complexities = {"O(n^2)", "O(n^3)"}
    assert outcome.static_analysis.time_complexity in accepted_complexities
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def test_api_analyze_endpoint_returns_valid_payload() -> None:
    """POST /analyze answers 200 with the domain and score-breakdown keys present."""
    request_body = {
        "code": "import torch\n\ndef predict(model, x):\n    return model(x)\n",
        "context_window": "Inference helper for a classifier",
        "traceback_text": "",
        "domain_hint": "auto",
    }

    response = TestClient(app).post("/analyze", json=request_body)

    assert response.status_code == 200
    body = response.json()
    for required_key in ("detected_domain", "score_breakdown"):
        assert required_key in body
|
tests/test_scoring.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from graders import grade_task
|
| 4 |
+
from Models import PythonCodeReviewAction
|
| 5 |
+
from server.env import PythonCodeReviewEnvironment
|
| 6 |
+
from tasks import list_tasks
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def assert_open_unit_interval(value: float) -> None:
    """Assert that ``value`` lies strictly between 0 and 1 (both ends excluded)."""
    strictly_inside = 0 < value < 1
    assert strictly_inside, f"Invalid score: {value}"
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def test_task_grades_stay_strictly_between_zero_and_one() -> None:
    """Starter and reference grades of every task keep all four scores inside (0, 1)."""
    score_fields = ("score", "syntax_score", "quality_score", "runtime_score")

    for task in list_tasks():
        # Starter code is graded on public cases only; reference code also on hidden ones.
        starter_grade = grade_task(task, task.starter_code, include_hidden=False)
        reference_grade = grade_task(task, task.reference_code, include_hidden=True)

        for grade in (starter_grade, reference_grade):
            for field_name in score_fields:
                assert_open_unit_interval(getattr(grade, field_name))
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def test_environment_scores_and_rewards_stay_in_open_interval() -> None:
    """Reset, a no-op edit, and a reference submission all report scores and rewards inside (0, 1)."""

    def step_and_check(action: PythonCodeReviewAction) -> None:
        # Apply one action, then validate the observation score, step reward, and reward details.
        observation, reward, _, _ = env.step_result(action)
        assert_open_unit_interval(observation.score)
        assert_open_unit_interval(reward)
        assert_open_unit_interval(observation.reward_details.value)

    env = PythonCodeReviewEnvironment(verbose=False)
    initial = env.reset(task_id="bug_fix_session_windows")

    assert_open_unit_interval(initial.score)
    assert_open_unit_interval(initial.reward_details.value)

    # Resubmitting the current code unchanged is the no-op edit.
    step_and_check(PythonCodeReviewAction(action_type="edit_code", code=initial.current_code))
    step_and_check(PythonCodeReviewAction(action_type="submit_solution", code=env._task.reference_code))
|
tests/test_triage_pipeline.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from fastapi.testclient import TestClient
|
| 4 |
+
|
| 5 |
+
from triage import CodeTriageEngine, HashingEmbeddingBackend
|
| 6 |
+
from triage_catalog import build_examples
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def test_hashing_backend_returns_normalized_embeddings() -> None:
    """The hashing backend yields one unit-length row per input text at the requested width."""
    samples = ["def foo():\n    return 1", "for x in items:\n    pass"]

    vectors = HashingEmbeddingBackend(dimensions=32).embed_texts(samples)

    assert vectors.shape == (2, 32)
    for vector in vectors:
        length = float(vector.norm().item())
        assert round(length, 5) == 1.0
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def test_examples_map_to_expected_labels_with_fallback_backend() -> None:
    """Each catalog example is triaged to its own label, with a reward inside [0, 1]."""
    catalog_examples = build_examples()
    fallback_engine = CodeTriageEngine(backend=HashingEmbeddingBackend())

    for sample in catalog_examples:
        outcome = fallback_engine.triage(sample.code, sample.traceback_text, sample.context_window)
        assert outcome.issue_label == sample.label
        assert 0.0 <= outcome.reward_score <= 1.0
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def test_syntax_example_exposes_parser_signal() -> None:
    """The syntax example reports a failing parse signal, its own pattern, and a syntax-first plan."""
    syntax_example = next(filter(lambda item: item.label == "syntax", build_examples()))
    fallback_engine = CodeTriageEngine(backend=HashingEmbeddingBackend())

    outcome = fallback_engine.triage(
        syntax_example.code,
        syntax_example.traceback_text,
        syntax_example.context_window,
    )

    reported = [(signal.name, signal.value) for signal in outcome.extracted_signals]
    assert ("syntax_parse", "fails") in reported
    assert outcome.matched_pattern.task_id == syntax_example.task_id
    first_step = outcome.repair_plan[0]
    assert first_step.startswith("Step 1 - Syntax checking and bug fixes")
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def test_composed_app_preserves_health_route() -> None:
    """The composed application still serves GET /health with an "ok" status."""
    from server.app import build_application

    response = TestClient(build_application()).get("/health")

    assert response.status_code == 200
    reported_status = response.json()["status"]
    assert reported_status == "ok"
|
triage.py
ADDED
|
@@ -0,0 +1,473 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PyTorch-backed triage pipeline for TorchReview Copilot."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import ast
|
| 6 |
+
import hashlib
|
| 7 |
+
import os
|
| 8 |
+
import re
|
| 9 |
+
import time
|
| 10 |
+
from functools import lru_cache
|
| 11 |
+
from typing import List, Sequence
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
from transformers import AutoModel, AutoTokenizer
|
| 18 |
+
except Exception:
|
| 19 |
+
AutoModel = None # type: ignore[assignment]
|
| 20 |
+
AutoTokenizer = None # type: ignore[assignment]
|
| 21 |
+
|
| 22 |
+
try:
|
| 23 |
+
from .triage_catalog import build_examples, build_prototypes
|
| 24 |
+
from .triage_models import (
|
| 25 |
+
IssueLabel,
|
| 26 |
+
PrototypeMatch,
|
| 27 |
+
TriageExample,
|
| 28 |
+
TriagePrototype,
|
| 29 |
+
TriageResult,
|
| 30 |
+
TriageSignal,
|
| 31 |
+
)
|
| 32 |
+
except ImportError:
|
| 33 |
+
from triage_catalog import build_examples, build_prototypes
|
| 34 |
+
from triage_models import (
|
| 35 |
+
IssueLabel,
|
| 36 |
+
PrototypeMatch,
|
| 37 |
+
TriageExample,
|
| 38 |
+
TriagePrototype,
|
| 39 |
+
TriageResult,
|
| 40 |
+
TriageSignal,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
MODEL_ID = os.getenv("TRIAGE_MODEL_ID", "huggingface/CodeBERTa-small-v1")
|
| 45 |
+
MODEL_MAX_LENGTH = int(os.getenv("TRIAGE_MODEL_MAX_LENGTH", "256"))
|
| 46 |
+
LABELS: tuple[IssueLabel, ...] = ("syntax", "logic", "performance")
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class _LoopDepthVisitor(ast.NodeVisitor):
    """AST visitor that records how deeply loop constructs are nested.

    ``for`` loops, ``while`` loops, and comprehension clauses each count as one
    level. Once ``visit`` has returned, ``max_depth`` holds the deepest level
    that was reached.
    """

    def __init__(self) -> None:
        self.depth = 0  # nesting level of the node currently being walked
        self.max_depth = 0  # deepest nesting level observed so far

    def _visit_loop(self, node: ast.AST) -> None:
        """Walk the children of ``node`` one level deeper, then restore the level."""
        self.depth += 1
        if self.depth > self.max_depth:
            self.max_depth = self.depth
        self.generic_visit(node)
        self.depth -= 1

    def visit_For(self, node: ast.For) -> None:  # noqa: N802
        """Count a ``for`` loop as one nesting level."""
        self._visit_loop(node)

    def visit_While(self, node: ast.While) -> None:  # noqa: N802
        """Count a ``while`` loop as one nesting level."""
        self._visit_loop(node)

    def visit_comprehension(self, node: ast.comprehension) -> None:  # noqa: N802
        """Count a comprehension ``for`` clause as one nesting level."""
        self._visit_loop(node)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class HashingEmbeddingBackend:
    """Deterministic hashed-token embedder, used when pretrained weights are unavailable."""

    def __init__(self, dimensions: int = 96) -> None:
        fallback_name = "hashed-token-fallback"
        self.dimensions = dimensions
        self.model_id = fallback_name
        self.backend_name = fallback_name
        self.notes = ["Using hashed torch embeddings because pretrained weights are unavailable."]

    def embed_texts(self, texts: Sequence[str]) -> torch.Tensor:
        """Embed each text as an L2-normalised float32 row of width ``dimensions``.

        Each text is lowercased and split into at most 512 tokens. Every token's
        MD5 digest selects one column and a +1/-1 contribution to it. A text
        with no tokens gets a 1.0 in column 0 instead.
        """
        width = self.dimensions
        matrix = torch.zeros((len(texts), width), dtype=torch.float32)

        for index, raw_text in enumerate(texts):
            pieces = re.findall(r"[A-Za-z_]+|\d+|==|!=|<=|>=|\S", raw_text.lower())[:512]
            if pieces:
                for piece in pieces:
                    hexdigest = hashlib.md5(piece.encode("utf-8")).hexdigest()
                    # First 8 hex digits choose the column; the next two choose the sign.
                    column = int(hexdigest[:8], 16) % width
                    is_odd = int(hexdigest[8:10], 16) % 2
                    matrix[index, column] += -1.0 if is_odd else 1.0
            else:
                matrix[index, 0] = 1.0

        # A small offset is added to every entry before normalising each row.
        return F.normalize(matrix + 1e-6, dim=1)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class TransformersEmbeddingBackend:
    """Mean-pooled encoder embeddings with a hashing fallback.

    The tokenizer and model named by ``model_id`` are loaded lazily on the
    first ``embed_texts`` call. When ``force_fallback`` is set, when
    ``transformers`` is unavailable, or when loading raises, texts are
    embedded by ``HashingEmbeddingBackend`` instead and ``backend_name`` /
    ``notes`` are updated to say so.
    """

    def __init__(self, model_id: str = MODEL_ID, force_fallback: bool = False) -> None:
        self.model_id = model_id
        self.force_fallback = force_fallback
        self._fallback = HashingEmbeddingBackend()
        self._tokenizer = None
        self._model = None
        self._load_error = ""
        # Report the fallback identity right away when it was requested.
        self.backend_name = self._fallback.backend_name if force_fallback else model_id
        self.notes: List[str] = list(self._fallback.notes) if force_fallback else []

    def _ensure_loaded(self) -> None:
        """Attempt the pretrained load once and remember a failure."""
        if self.force_fallback:
            return
        if self._model is not None:
            return
        if self._load_error:
            # A previous attempt already failed; do not retry on every call.
            return

        transformers_missing = AutoTokenizer is None or AutoModel is None
        if transformers_missing:
            self._load_error = "transformers is not installed."
        else:
            try:
                self._tokenizer = AutoTokenizer.from_pretrained(self.model_id)
                self._model = AutoModel.from_pretrained(self.model_id)
                self._model.eval()
                self.notes.append(f"Loaded pretrained encoder `{self.model_id}` for inference.")
            except Exception as exc:
                # Any load failure degrades to the hashing backend instead of
                # propagating to the caller.
                self._load_error = f"{type(exc).__name__}: {exc}"

        if not self._load_error:
            return
        self.backend_name = self._fallback.backend_name
        self.notes = [*self._fallback.notes, f"Pretrained load failed: {self._load_error}"]

    def embed_texts(self, texts: Sequence[str]) -> torch.Tensor:
        """Return one L2-normalized embedding row per input text."""
        self._ensure_loaded()
        encoder, tokenizer = self._model, self._tokenizer
        if encoder is None or tokenizer is None:
            return self._fallback.embed_texts(texts)

        batch = tokenizer(
            list(texts),
            padding=True,
            truncation=True,
            max_length=MODEL_MAX_LENGTH,
            return_tensors="pt",
        )
        with torch.no_grad():
            token_states = encoder(**batch).last_hidden_state
            # Broadcast the attention mask over the hidden dimension so that
            # padded positions contribute nothing to the average.
            keep = batch["attention_mask"].unsqueeze(-1)
            summed = (token_states * keep).sum(dim=1)
            token_counts = keep.sum(dim=1).clamp(min=1)
            sentence_vectors = summed / token_counts
        return F.normalize(sentence_vectors, dim=1)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def _sanitize_text(value: str) -> str:
|
| 151 |
+
text = (value or "").strip()
|
| 152 |
+
return text[:4000]
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def _safe_softmax(scores: dict[IssueLabel, float]) -> dict[str, float]:
    """Turn per-label scores into probabilities rounded to four decimals.

    Scores are read in ``LABELS`` order and multiplied by 4.0 before the
    softmax, which sharpens the resulting distribution.
    """
    ordered_scores = [scores[name] for name in LABELS]
    logits = torch.tensor(ordered_scores, dtype=torch.float32) * 4.0
    distribution = torch.softmax(logits, dim=0)
    result: dict[str, float] = {}
    for position, name in enumerate(LABELS):
        result[name] = round(float(distribution[position]), 4)
    return result
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def _loop_depth(code: str) -> int:
|
| 162 |
+
try:
|
| 163 |
+
tree = ast.parse(code)
|
| 164 |
+
except SyntaxError:
|
| 165 |
+
return 0
|
| 166 |
+
visitor = _LoopDepthVisitor()
|
| 167 |
+
visitor.visit(tree)
|
| 168 |
+
return visitor.max_depth
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def _repair_risk(label: IssueLabel, confidence: float, signal_count: int) -> str:
|
| 172 |
+
base = {"syntax": 0.25, "logic": 0.55, "performance": 0.7}[label]
|
| 173 |
+
if confidence < 0.55:
|
| 174 |
+
base += 0.12
|
| 175 |
+
if signal_count >= 4:
|
| 176 |
+
base += 0.08
|
| 177 |
+
if base < 0.4:
|
| 178 |
+
return "low"
|
| 179 |
+
if base < 0.72:
|
| 180 |
+
return "medium"
|
| 181 |
+
return "high"
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def _clamp_unit(value: float) -> float:
|
| 185 |
+
return round(max(0.0, min(1.0, float(value))), 4)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def _lint_score(code: str) -> float:
|
| 189 |
+
stripped_lines = [line.rstrip("\n") for line in code.splitlines()]
|
| 190 |
+
if not stripped_lines:
|
| 191 |
+
return 0.2
|
| 192 |
+
|
| 193 |
+
score = 1.0
|
| 194 |
+
if any(len(line) > 88 for line in stripped_lines):
|
| 195 |
+
score -= 0.15
|
| 196 |
+
if any(line.rstrip() != line for line in stripped_lines):
|
| 197 |
+
score -= 0.1
|
| 198 |
+
if any("\t" in line for line in stripped_lines):
|
| 199 |
+
score -= 0.1
|
| 200 |
+
try:
|
| 201 |
+
tree = ast.parse(code)
|
| 202 |
+
functions = [node for node in tree.body if isinstance(node, ast.FunctionDef)]
|
| 203 |
+
if functions and not ast.get_docstring(functions[0]):
|
| 204 |
+
score -= 0.08
|
| 205 |
+
except SyntaxError:
|
| 206 |
+
score -= 0.45
|
| 207 |
+
return _clamp_unit(score)
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def _complexity_penalty(code: str) -> float:
|
| 211 |
+
try:
|
| 212 |
+
tree = ast.parse(code)
|
| 213 |
+
except SyntaxError:
|
| 214 |
+
return 0.95
|
| 215 |
+
branch_nodes = sum(isinstance(node, (ast.If, ast.For, ast.While, ast.Try, ast.Match)) for node in ast.walk(tree))
|
| 216 |
+
loop_depth = _loop_depth(code)
|
| 217 |
+
penalty = 0.1 + min(branch_nodes, 8) * 0.07 + min(loop_depth, 4) * 0.12
|
| 218 |
+
return _clamp_unit(penalty)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
class CodeTriageEngine:
    """Combine static signals with PyTorch embeddings to classify code issues.

    A candidate document (code, traceback, context) is embedded and compared
    against the embedded prototypes; the per-label similarity is blended with
    heuristic scores from ``_extract_signals`` to pick an issue label.
    Prototype embeddings are computed lazily and cached on the instance.
    """

    def __init__(
        self,
        *,
        backend: TransformersEmbeddingBackend | HashingEmbeddingBackend | None = None,
        prototypes: Sequence[TriagePrototype] | None = None,
        examples: Sequence[TriageExample] | None = None,
    ) -> None:
        # Falsy arguments (including empty sequences) select the defaults.
        self.backend = backend or TransformersEmbeddingBackend()
        self.prototypes = list(prototypes or build_prototypes())
        self.examples = list(examples or build_examples())
        self._prototype_matrix: torch.Tensor | None = None
        self._reference_code_matrix: torch.Tensor | None = None

    def example_map(self) -> dict[str, TriageExample]:
        """Return UI examples keyed by task id."""

        return {example.key: example for example in self.examples}

    def _build_document(self, code: str, traceback_text: str) -> str:
        """Render sanitized code and traceback into the text that gets embedded."""
        trace = _sanitize_text(traceback_text) or "No traceback supplied."
        snippet = _sanitize_text(code) or "# No code supplied."
        return f"Candidate code:\n{snippet}\n\nObserved failure:\n{trace}\n"

    def _build_review_document(self, code: str, traceback_text: str, context_window: str) -> str:
        """Extend the base document with the sanitized context window."""
        context = _sanitize_text(context_window) or "No additional context window supplied."
        return (
            f"{self._build_document(code, traceback_text)}\n"
            f"Context window:\n{context}\n"
        )

    def _prototype_embeddings(self) -> torch.Tensor:
        """Return cached embeddings of every prototype's ``reference_text``."""
        if self._prototype_matrix is None:
            reference_texts = [prototype.reference_text for prototype in self.prototypes]
            self._prototype_matrix = self.backend.embed_texts(reference_texts)
        return self._prototype_matrix

    def _reference_code_embeddings(self) -> torch.Tensor:
        """Return cached embeddings of every prototype's ``reference_code``."""
        if self._reference_code_matrix is None:
            reference_codes = [prototype.reference_code for prototype in self.prototypes]
            self._reference_code_matrix = self.backend.embed_texts(reference_codes)
        return self._reference_code_matrix

    def _extract_signals(self, code: str, traceback_text: str) -> tuple[list[TriageSignal], dict[IssueLabel, float], list[str]]:
        """Derive heuristic signals from the code and the traceback text.

        Returns the extracted signals (the parse signal always comes first),
        a per-label heuristic score starting from 0.15 for every label, and
        free-form notes for the result.
        """
        trace = (traceback_text or "").lower()
        heuristic_scores: dict[IssueLabel, float] = {label: 0.15 for label in LABELS}
        signals: list[TriageSignal] = []
        notes: list[str] = []

        try:
            ast.parse(code)
        except (SyntaxError, ValueError) as exc:
            # ValueError is raised for NUL bytes on Python < 3.12 and carries
            # no ``msg``/``lineno`` attributes, so it is rendered with str().
            if isinstance(exc, SyntaxError):
                evidence = f"{exc.msg} at line {exc.lineno}"
            else:
                evidence = str(exc)
            signals.append(
                TriageSignal(
                    name="syntax_parse",
                    value="fails",
                    impact="syntax",
                    weight=0.95,
                    evidence=evidence,
                )
            )
            heuristic_scores["syntax"] += 0.85
            notes.append(f"Parser failure detected: {evidence}")
        else:
            signals.append(
                TriageSignal(
                    name="syntax_parse",
                    value="passes",
                    impact="syntax",
                    weight=0.1,
                    evidence="Python AST parsing succeeded.",
                )
            )
            heuristic_scores["logic"] += 0.05

        if any(token in trace for token in ("syntaxerror", "indentationerror", "expected ':'")):
            signals.append(
                TriageSignal(
                    name="traceback_keyword",
                    value="syntaxerror",
                    impact="syntax",
                    weight=0.8,
                    evidence="Traceback contains a parser error.",
                )
            )
            heuristic_scores["syntax"] += 0.55

        if any(token in trace for token in ("assertionerror", "expected:", "actual:", "boundary", "missing", "incorrect")):
            signals.append(
                TriageSignal(
                    name="test_failure_signal",
                    value="assertion-style failure",
                    impact="logic",
                    weight=0.7,
                    evidence="Failure text points to behavioral mismatch instead of parser issues.",
                )
            )
            heuristic_scores["logic"] += 0.55

        if any(token in trace for token in ("timeout", "benchmark", "slow", "latency", "performance", "profiler")):
            signals.append(
                TriageSignal(
                    name="performance_trace",
                    value="latency regression",
                    impact="performance",
                    weight=0.85,
                    evidence="Traceback mentions benchmark or latency pressure.",
                )
            )
            heuristic_scores["performance"] += 0.7

        loop_depth = _loop_depth(code)
        if loop_depth >= 2:
            signals.append(
                TriageSignal(
                    name="loop_depth",
                    value=str(loop_depth),
                    impact="performance",
                    weight=0.65,
                    evidence="Nested iteration increases runtime risk on larger fixtures.",
                )
            )
            heuristic_scores["performance"] += 0.35

        if "Counter(" in code or "defaultdict(" in code or "set(" in code:
            heuristic_scores["performance"] += 0.05

        if "return sessions" in code and "sessions.append" not in code:
            signals.append(
                TriageSignal(
                    name="state_update_gap",
                    value="possible missing final append",
                    impact="logic",
                    weight=0.45,
                    evidence="A collection is returned without an obvious final state flush.",
                )
            )
            heuristic_scores["logic"] += 0.18

        return signals, heuristic_scores, notes

    def _nearest_match(self, embedding: torch.Tensor) -> tuple[TriagePrototype, float, dict[str, float]]:
        """Find the prototype most similar to ``embedding``.

        Returns the best prototype, its similarity, and a mapping of task id
        to similarity for all prototypes. Dot products are mapped with
        ``(s + 1) / 2``; only the first row of ``embedding`` is used.
        """
        similarities = torch.matmul(embedding, self._prototype_embeddings().T)[0]
        indexed_scores = {
            self.prototypes[index].task_id: round(float((similarities[index] + 1.0) / 2.0), 4)
            for index in range(len(self.prototypes))
        }
        best_index = int(torch.argmax(similarities).item())
        best_prototype = self.prototypes[best_index]
        best_similarity = float((similarities[best_index] + 1.0) / 2.0)
        return best_prototype, best_similarity, indexed_scores

    def _repair_plan(self, label: IssueLabel, matched: TriagePrototype, context_window: str) -> list[str]:
        """Build the three-step repair plan for the predicted label."""
        context = _sanitize_text(context_window)
        step_one = {
            "syntax": "Step 1 - Syntax checking and bug fixes: resolve the parser break before touching behavior, then align the function with the expected contract.",
            "logic": "Step 1 - Syntax checking and bug fixes: confirm the code parses cleanly, then patch the failing branch or state update causing the incorrect result.",
            "performance": "Step 1 - Syntax checking and bug fixes: keep the implementation correct first, then isolate the slow section without changing external behavior.",
        }[label]
        step_two = (
            "Step 2 - Edge case handling: verify empty input, boundary values, missing fields, and final-state flush behavior "
            f"against the known pattern `{matched.title}`."
        )
        step_three = (
            "Step 3 - Scalability of code: remove repeated full scans, prefer linear-time data structures, "
            "and benchmark the path on a production-like fixture."
        )
        if context:
            step_two = f"{step_two} Context window to preserve: {context}"
        return [step_one, step_two, step_three]

    def _reference_quality_score(self, code: str, matched: TriagePrototype) -> float:
        """Score ``code`` by embedding similarity to the matched prototype's reference code."""
        candidate = self.backend.embed_texts([_sanitize_text(code) or "# empty"])
        match_index = next(index for index, prototype in enumerate(self.prototypes) if prototype.task_id == matched.task_id)
        reference = self._reference_code_embeddings()[match_index : match_index + 1]
        score = float(torch.matmul(candidate, reference.T)[0][0].item())
        return _clamp_unit((score + 1.0) / 2.0)

    def triage(self, code: str, traceback_text: str = "", context_window: str = "") -> TriageResult:
        """Run the full triage pipeline on code plus optional failure context.

        The label is chosen from a 0.72/0.28 blend of embedding similarity
        and heuristic scores passed through ``_safe_softmax``. The reward is
        ``0.5 * ml_quality + 0.3 * lint - 0.2 * complexity``, clamped to
        ``[0, 1]``.
        """

        started = time.perf_counter()
        document = self._build_review_document(code, traceback_text, context_window)
        signals, heuristic_scores, notes = self._extract_signals(code, traceback_text)

        candidate_embedding = self.backend.embed_texts([document])
        matched, matched_similarity, prototype_scores = self._nearest_match(candidate_embedding)

        # Every label keeps a 0.18 similarity floor, raised to the best
        # similarity among prototypes carrying that label.
        label_similarity = {label: 0.18 for label in LABELS}
        for prototype in self.prototypes:
            label_similarity[prototype.label] = max(
                label_similarity[prototype.label],
                prototype_scores[prototype.task_id],
            )

        combined_scores = {
            label: 0.72 * label_similarity[label] + 0.28 * heuristic_scores[label]
            for label in LABELS
        }
        confidence_scores = _safe_softmax(combined_scores)
        issue_label = max(LABELS, key=lambda label: confidence_scores[label])
        top_confidence = confidence_scores[issue_label]

        # The parse signal is always first in ``signals``, so taking index 0
        # would report parser status regardless of the decision. Prefer the
        # heaviest signal that supports the predicted label, then the
        # heaviest signal overall.
        supporting = [signal for signal in signals if signal.impact == issue_label] or signals
        if supporting:
            top_signal = max(supporting, key=lambda signal: signal.weight).evidence
        else:
            top_signal = "Model similarity dominated the decision."

        ml_quality_score = self._reference_quality_score(code, matched)
        lint_score = _lint_score(code)
        complexity_penalty = _complexity_penalty(code)
        reward_score = _clamp_unit((0.5 * ml_quality_score) + (0.3 * lint_score) - (0.2 * complexity_penalty))
        # Drop trailing periods so the sentence below does not end in "..".
        pattern_summary = matched.summary.lower().rstrip(".")
        summary = (
            f"Detected a {issue_label} issue with {top_confidence:.0%} confidence. "
            f"The closest known failure pattern is `{matched.title}`, which indicates {pattern_summary}. "
            f"Predicted quality score is {ml_quality_score:.0%} with an RL-ready reward of {reward_score:.0%}."
        )
        suggested_next_action = {
            "syntax": "Fix the parser error first, then rerun validation before changing behavior.",
            "logic": "Step through the smallest failing case and confirm the final branch/update behavior.",
            "performance": "Replace repeated full-list scans with a linear-time aggregation strategy, then benchmark it.",
        }[issue_label]

        return TriageResult(
            issue_label=issue_label,
            confidence_scores=confidence_scores,
            repair_risk=_repair_risk(issue_label, top_confidence, len(signals)),
            ml_quality_score=ml_quality_score,
            lint_score=lint_score,
            complexity_penalty=complexity_penalty,
            reward_score=reward_score,
            summary=summary,
            matched_pattern=PrototypeMatch(
                task_id=matched.task_id,
                title=matched.title,
                label=matched.label,
                similarity=round(matched_similarity, 4),
                summary=matched.summary,
                rationale=top_signal,
            ),
            repair_plan=self._repair_plan(issue_label, matched, context_window),
            suggested_next_action=suggested_next_action,
            extracted_signals=signals,
            model_backend=self.backend.backend_name,
            model_id=self.backend.model_id,
            inference_notes=list(self.backend.notes) + notes,
            analysis_time_ms=round((time.perf_counter() - started) * 1000.0, 2),
        )
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
@lru_cache(maxsize=1)
def get_default_engine() -> CodeTriageEngine:
    """Build the default triage engine once and reuse it for later calls.

    The single-entry ``lru_cache`` keeps the first instance alive for the
    lifetime of the process.
    """
    engine = CodeTriageEngine()
    return engine
|
triage_catalog.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Curated prototypes and example inputs for TorchReview Copilot."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Dict, List
|
| 6 |
+
|
| 7 |
+
try:
|
| 8 |
+
from .triage_models import IssueLabel, TriageExample, TriagePrototype
|
| 9 |
+
from .tasks import list_tasks
|
| 10 |
+
except ImportError:
|
| 11 |
+
from triage_models import IssueLabel, TriageExample, TriagePrototype
|
| 12 |
+
from tasks import list_tasks
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# Maps a task's ``task_kind`` to the issue label used by the triage models.
TASK_KIND_TO_LABEL: Dict[str, IssueLabel] = {
    "syntax_fix": "syntax",
    "bug_fix": "logic",
    "optimization": "performance",
}

# Representative failure output for each task, keyed by task id. Used as the
# example traceback and embedded into the prototype reference text.
TRACEBACK_BY_TASK_ID: Dict[str, str] = {
    "syntax_fix_invoice_totals": (
        "Traceback (most recent call last):\n"
        " File \"services/billing/reconciliation.py\", line 3\n"
        " for record in records\n"
        " ^\n"
        "SyntaxError: expected ':'"
    ),
    "bug_fix_session_windows": (
        "AssertionError: collapse_sessions([{'minute': 1}, {'minute': 3}, {'minute': 8}], 4)\n"
        "Expected: [(1, 3), (8, 8)]\n"
        "Actual: [(1, 8)]\n"
        "Boundary handling merges the final session instead of starting a new one."
    ),
    "optimization_rank_active_users": (
        "BenchmarkWarning: rank_active_users exceeded the 450ms budget on a nightly export fixture.\n"
        "Profiler hint: repeated scans over the full event list and nested loops dominate runtime."
    ),
}

# One-sentence description of each task's failure pattern, keyed by task id.
SUMMARY_BY_TASK_ID: Dict[str, str] = {
    "syntax_fix_invoice_totals": "Broken parser state in a billing helper blocks reconciliation jobs.",
    "bug_fix_session_windows": "Session-boundary logic fails on inclusive idle-timeout edges.",
    "optimization_rank_active_users": "A nightly ranking job is correct on small fixtures but too slow at production scale.",
}

# Context-window text attached to each example, keyed by task id.
CONTEXT_BY_TASK_ID: Dict[str, str] = {
    "syntax_fix_invoice_totals": (
        "Context window: this helper runs in an end-of-day billing reconciliation job. "
        "Keep the public function signature intact and restore correct totals for mixed integer/string inputs."
    ),
    "bug_fix_session_windows": (
        "Context window: this function groups sorted product analytics events into sessions for retention dashboards. "
        "Boundary behavior must stay deterministic because downstream reports depend on it."
    ),
    "optimization_rank_active_users": (
        "Context window: this pipeline feeds a nightly export on a small CPU instance. "
        "Maintain identical output ordering while improving scalability on larger event volumes."
    ),
}
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _prototype_text(
|
| 64 |
+
task_id: str,
|
| 65 |
+
title: str,
|
| 66 |
+
description: str,
|
| 67 |
+
repo_summary: str,
|
| 68 |
+
goal: str,
|
| 69 |
+
visible_tests: List[str],
|
| 70 |
+
starter_code: str,
|
| 71 |
+
traceback_text: str,
|
| 72 |
+
) -> str:
|
| 73 |
+
visible = "\n".join(f"- {item}" for item in visible_tests) or "- none"
|
| 74 |
+
return (
|
| 75 |
+
f"Title: {title}\n"
|
| 76 |
+
f"Problem: {description}\n"
|
| 77 |
+
f"Repo context: {repo_summary}\n"
|
| 78 |
+
f"Goal: {goal}\n"
|
| 79 |
+
f"Observed failure:\n{traceback_text}\n"
|
| 80 |
+
f"Visible checks:\n{visible}\n"
|
| 81 |
+
f"Candidate code:\n{starter_code}\n"
|
| 82 |
+
f"Task id: {task_id}\n"
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def build_examples() -> List[TriageExample]:
    """Create stable UI examples from the task catalog.

    One example is produced per task returned by ``list_tasks()``, in the
    same order. A task whose kind or id is missing from the lookup tables
    raises ``KeyError``.
    """

    return [
        TriageExample(
            key=task.task_id,
            title=task.title,
            label=TASK_KIND_TO_LABEL[task.task_kind],
            summary=SUMMARY_BY_TASK_ID[task.task_id],
            code=task.starter_code,
            traceback_text=TRACEBACK_BY_TASK_ID[task.task_id],
            context_window=CONTEXT_BY_TASK_ID[task.task_id],
            task_id=task.task_id,
        )
        for task in list_tasks()
    ]
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def build_prototypes() -> List[TriagePrototype]:
    """Build canonical triage prototypes from the OpenEnv tasks.

    One prototype is produced per task returned by ``list_tasks()``, in the
    same order. A task whose kind or id is missing from the lookup tables
    raises ``KeyError``.
    """

    def _as_prototype(task) -> TriagePrototype:
        failure_text = TRACEBACK_BY_TASK_ID[task.task_id]
        # NOTE(review): the reference text is rendered with
        # ``task.reference_code`` in the ``starter_code`` slot of
        # ``_prototype_text`` -- confirm this is intentional and not meant
        # to be ``task.starter_code``.
        embedded_text = _prototype_text(
            task.task_id,
            task.title,
            task.task_description,
            task.repo_summary,
            task.goal,
            list(task.visible_tests),
            task.reference_code,
            failure_text,
        )
        return TriagePrototype(
            task_id=task.task_id,
            title=task.title,
            label=TASK_KIND_TO_LABEL[task.task_kind],
            summary=SUMMARY_BY_TASK_ID[task.task_id],
            reference_text=embedded_text,
            starter_code=task.starter_code,
            reference_code=task.reference_code,
            traceback_text=failure_text,
        )

    return [_as_prototype(task) for task in list_tasks()]
|
triage_models.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Typed models for TorchReview Copilot outputs and examples."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Dict, List, Literal
|
| 6 |
+
|
| 7 |
+
from pydantic import BaseModel, Field
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# Closed set of issue categories the triage pipeline can predict.
IssueLabel = Literal["syntax", "logic", "performance"]
# Closed set of repair-risk buckets reported with a triage result.
RiskLevel = Literal["low", "medium", "high"]
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class TriageSignal(BaseModel):
    """One extracted signal used during issue classification."""

    # Identifier of the signal.
    name: str
    # Observed value, rendered as text.
    value: str
    # Issue category the signal relates to; "mixed" when not specified.
    impact: Literal["syntax", "logic", "performance", "mixed"] = "mixed"
    # Required strength of the signal, validated to lie in [0.0, 1.0].
    weight: float = Field(..., ge=0.0, le=1.0)
    # Human-readable justification; empty when none is given.
    evidence: str = ""
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class PrototypeMatch(BaseModel):
    """Nearest known bug pattern from the built-in task catalog."""

    # Id of the catalog task the pattern comes from.
    task_id: str
    # Title of the matched pattern.
    title: str
    # Issue category of the matched pattern.
    label: IssueLabel
    # Required similarity to the candidate, validated to lie in [0.0, 1.0].
    similarity: float = Field(..., ge=0.0, le=1.0)
    # Short description of the matched pattern.
    summary: str
    # Explanation attached to the match.
    rationale: str
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class TriageExample(BaseModel):
    """Example payload exposed in the demo UI."""

    # Lookup key for the example.
    key: str
    # Title of the example.
    title: str
    # Issue category the example illustrates.
    label: IssueLabel
    # Short description of the example.
    summary: str
    # Source code of the example.
    code: str
    # Failure output paired with the code.
    traceback_text: str
    # Additional context text paired with the code.
    context_window: str
    # Id of the catalog task the example is derived from.
    task_id: str
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class TriagePrototype(BaseModel):
    """Canonical issue-pattern representation embedded by the triage engine."""

    # Id of the catalog task this prototype represents.
    task_id: str
    # Title of the pattern.
    title: str
    # Issue category of the pattern.
    label: IssueLabel
    # Short description of the pattern.
    summary: str
    # Text describing the pattern.
    reference_text: str
    # Starting code of the task.
    starter_code: str
    # Reference code of the task.
    reference_code: str
    # Failure output associated with the task.
    traceback_text: str
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class TriageResult(BaseModel):
    """Structured output produced by the triage pipeline."""

    # Predicted issue category.
    issue_label: IssueLabel
    # Confidence values keyed by string.
    confidence_scores: Dict[str, float]
    # Repair-risk bucket.
    repair_risk: RiskLevel
    # The four scores below are required and validated to lie in [0.0, 1.0].
    ml_quality_score: float = Field(..., ge=0.0, le=1.0)
    lint_score: float = Field(..., ge=0.0, le=1.0)
    complexity_penalty: float = Field(..., ge=0.0, le=1.0)
    reward_score: float = Field(..., ge=0.0, le=1.0)
    # Human-readable summary of the result.
    summary: str
    # Nearest catalog pattern for the candidate.
    matched_pattern: PrototypeMatch
    # Repair steps, as a list of strings.
    repair_plan: List[str]
    # Suggested follow-up action.
    suggested_next_action: str
    # Signals that were extracted; empty list by default.
    extracted_signals: List[TriageSignal] = Field(default_factory=list)
    # Name of the embedding backend.
    model_backend: str
    # Identifier of the model.
    model_id: str
    # Notes about the inference run; empty list by default.
    inference_notes: List[str] = Field(default_factory=list)
    # Required, non-negative analysis duration in milliseconds.
    analysis_time_ms: float = Field(..., ge=0.0)
|