diff --git a/DEMO_SCRIPT.md b/DEMO_SCRIPT.md index 370cb3604936eabc892f8a44b50665f67c39019b..252599c1fdc7fd76ffb2ed52801148a3a6e602e0 100644 --- a/DEMO_SCRIPT.md +++ b/DEMO_SCRIPT.md @@ -2,20 +2,11 @@ ## 60-90 Second Walkthrough -1. Introduce TorchReview Copilot as an AI-powered code review system that helps developers find bugs, reduce complexity, and improve maintainability faster. -2. Frame the problem clearly: manual code reviews are slow, inconsistent, and hard to scale across growing teams and codebases. -3. Open the Streamlit app and load the `Boundary Bug` example to show a realistic Python regression with failing behavior. -4. Point out the pipeline on-screen: - input code, static analysis, PyTorch scoring, suggestions, and RL-ready reward output. -5. Highlight the PyTorch story: - the app uses CodeBERTa embeddings through PyTorch to score code quality, maintainability, and domain fit. -6. Show the headline metrics: - detected domain, ML score, lint score, and final reward. -7. Scroll to the reward breakdown and explain that the reward is not arbitrary; it combines ML quality, maintainability, security, lint signals, and complexity penalties. -8. Open the Suggestions tab and show the prioritized fixes plus the three-step improvement plan. -9. Switch to the `Performance Hotspot` example to demonstrate that the system adapts to a different issue profile and pushes optimization hints instead of only syntax guidance. -10. Close by emphasizing that the same repo also works as an OpenEnv environment, so the project is both a usable developer product and an RL-ready benchmark component. - -## 20-Second Closing Line - -TorchReview Copilot turns code review into a measurable AI workflow: PyTorch handles semantic scoring, deterministic analyzers keep it grounded, and OpenEnv makes it trainable and benchmarkable. +1. Open the Hugging Face Space and introduce TorchReview Copilot as an AI-powered code review and improvement system built with PyTorch. +2. Point to the problem statement: manual code review is slow, inconsistent, and hard to scale. +3. Select the `Fix the invoice total syntax regression` example to show the app loading a broken code sample together with the context window. +4. Highlight the **Live Triage Radar**, the ML quality score, and the RL-ready reward score. +5. Explain that the PyTorch layer uses CodeBERTa embeddings to compare the input against known code-quality patterns from the OpenEnv task catalog. +6. Scroll to the three-step improvement plan and call out the progression: syntax and bug fixes, edge cases, then scalability. +7. Switch to the performance example to show the confidence profile and reward changing for a different class of issue. +8. Close by noting that OpenEnv still powers deterministic validation under the hood, so the demo remains grounded in measurable task outcomes. diff --git a/Dockerfile b/Dockerfile index 491f212cce74c59de8ff59b4a839f16d853caf29..3aac14ba49a72604cd03c088b4401f0f0508d4ed 100644 --- a/Dockerfile +++ b/Dockerfile @@ -6,24 +6,31 @@ ENV PYTHONDONTWRITEBYTECODE=1 \ PYTHONIOENCODING=utf-8 \ PIP_NO_CACHE_DIR=1 \ PIP_DISABLE_PIP_VERSION_CHECK=1 \ - PIP_DEFAULT_TIMEOUT=120 \ + PIP_ROOT_USER_ACTION=ignore \ ENABLE_GRADIO_DEMO=false \ ENABLE_WEB_INTERFACE=false WORKDIR /app -COPY server/requirements.txt /tmp/requirements.txt +COPY server/requirements.runtime.txt /tmp/requirements.runtime.txt -RUN python -m pip install --upgrade pip && \ - pip install --prefer-binary -r /tmp/requirements.txt +RUN apt-get update && \ + apt-get upgrade -y && \ + rm -rf /var/lib/apt/lists/* -COPY . /app +RUN useradd --create-home --shell /usr/sbin/nologin appuser && \ + python -m pip install --upgrade pip setuptools && \ + pip install -r /tmp/requirements.runtime.txt + +COPY --chown=appuser:appuser . /app RUN pip install --no-deps . +USER appuser + EXPOSE 8000 HEALTHCHECK --interval=30s --timeout=5s --start-period=10s --retries=3 \ CMD python -c "import urllib.request; urllib.request.urlopen('http://127.0.0.1:8000/health', timeout=3).read()" -CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "8000"] +CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "8000", "--no-access-log"] diff --git a/README.md b/README.md index 046b1f4a7e814ae3de97642d97f36a6c6e99411c..09da0ae85edb888ac82f3547e92a3eb3369a8a16 100644 --- a/README.md +++ b/README.md @@ -1,232 +1,91 @@ --- -title: TorchReview Copilot +title: Python Code Review Environment Server sdk: docker app_port: 8000 base_path: /web pinned: false tags: - openenv - - pytorch - - code-review --- -# TorchReview Copilot +# OpenEnv Python Code Review Environment -TorchReview Copilot is an AI-powered code review and improvement system built for the Meta PyTorch OpenEnv Hackathon. It combines deterministic static analysis, a real PyTorch code encoder, domain-aware review logic, and RL-ready reward shaping to help developers catch bugs, reduce complexity, and improve maintainability faster. +Production-ready hackathon submission for OpenEnv evaluation, deterministic validator runs, and Hugging Face Docker deployment. -## Problem Statement - -Manual code review is slow, inconsistent, and difficult to scale. Small logic bugs slip through, performance hotspots hide in otherwise correct code, and review quality changes from reviewer to reviewer. - -## Solution - -TorchReview Copilot accepts Python code, analyzes it with AST and complexity heuristics, scores it with a PyTorch model, and returns: - -- A code quality score -- Domain-aware review feedback -- Actionable improvement suggestions -- An RL-ready reward signal for OpenEnv environments - -## Why This Is Hackathon-Worthy - -- Solves a real developer productivity problem -- Uses PyTorch meaningfully for model inference, not as a placeholder -- Produces a measurable reward signal for RL workflows -- Ships as a usable product with API, UI, docs, tests, and OpenEnv compatibility - -## Tech Stack - -- `PyTorch` for model execution and similarity scoring -- `transformers` with `huggingface/CodeBERTa-small-v1` for pretrained code embeddings -- `FastAPI` for the analysis API -- `Streamlit` for the interactive review UI -- `Pydantic` for request and response validation -- `OpenAI` Python client for hackathon-compliant LLM action planning in `inference.py` -- `OpenEnv` for environment, reward, and validator integration - -## Pipeline - -```text -Input Python Code - -> AST Parsing + Structural Signals - -> Complexity + Lint Heuristics - -> PyTorch Model Inference (CodeBERTa / torch fallback) - -> Domain Analysis + Suggestion Engine - -> RL Reward Shaping - -> UI + API + OpenEnv Environment Output -``` - -## PyTorch Integration - -PyTorch is used in the core scoring path: - -- The app loads `huggingface/CodeBERTa-small-v1` through `transformers` -- Input code, repository context, traceback text, and static-analysis hints are embedded with the encoder -- The resulting embedding is compared against quality, maintainability, domain, and issue prototypes -- The model produces: - - `ml_quality_score` - - `maintainability_score` - - domain confidences - - issue probabilities - -If pretrained weights are unavailable, the project falls back to a torch-native hashed embedding backend so local demos and CI still work offline. - -## Reward System - -The system is RL-ready by design. Reward shaping blends model confidence, code quality, security, maintainability, and complexity into a bounded signal. - -Core reward: - -```text -reward = 0.50*ml_score - + 0.18*lint_score - + 0.12*maintainability_score - + 0.10*domain_score - + 0.10*security_score - - 0.20*complexity_penalty -``` - -The OpenEnv environment adds step-level shaping for: - -- public test progress -- syntax recovery -- runtime improvements -- error reduction -- final submission success -- regressions and invalid actions - -All task and step rewards are normalized into a strict safe interval for OpenEnv validation and printed in a validator-safe two-decimal band. - -## Features - -- Real PyTorch-backed code quality inference -- Static analysis with syntax, lint, AST, and complexity signals -- Domain-aware review for DSA, data science, ML/DL, and web code -- Prioritized suggestions and a compact 3-step improvement plan -- Auto-fix preview hints for quick wins -- Real-time Streamlit scoring mode -- OpenEnv-compatible environment and `inference.py` -- Deterministic benchmark tasks for syntax fixes, bug fixes, and optimization - -## WOW Features - -- Real-time scoring in the Streamlit interface -- Auto-fix preview panel -- Reward visualization and score breakdown -- OpenEnv environment with transparent reward decomposition - -## Project Structure +## Architecture ```text root -|- inference.py -|- api/ +|- inference.py # Root validator entrypoint +|- openenv.yaml # OpenEnv manifest |- app/ -| |- agents/ -| |- env/ -| |- models/ -| |- services/ -| `- utils/ -|- analyzers/ -|- graders/ -|- models/ -|- schemas/ -|- services/ -|- tasks/ -|- tests/ -`- utils/ -``` - -Key modules: - -- `models/pytorch_model.py`: PyTorch + transformer inference -- `services/analysis_service.py`: end-to-end review pipeline -- `services/reward_service.py`: RL-friendly reward shaping -- `services/suggestion_service.py`: actionable recommendations -- `app/streamlit_app.py`: interactive UI -- `server/env.py`: OpenEnv environment implementation -- `app/env/runner.py`: strict `inference.py` runner - -## API - -Run the analysis API: - -```bash -python -m uvicorn api.main:app --host 0.0.0.0 --port 7860 +| |- agents/ # Action policy and fallback strategy +| |- env/ # RL loop runner and stdout contract +| |- models/ # Inference dataclasses/config +| |- services/ # OpenAI client wrapper with retries +| `- utils/ # Formatting, task loading, log suppression +|- server/ +| |- env.py # OpenEnv environment and reward shaping +| |- app.py # FastAPI/OpenEnv app, optional Gradio mount +| `- Dockerfile # Alternate Docker build path +|- Dockerfile # Root deployment Docker image +|- graders/ # Syntax, bug-fix, optimization graders +|- tasks/ # Deterministic benchmark tasks and references +|- services/ # Multi-domain analysis services +|- analyzers/ # Domain-specific analyzers +|- models/ # Lazy-loaded PyTorch scoring model +|- schemas/ # API request/response contracts +`- tests/ # Local validation coverage ``` -Main endpoint: - -- `POST /analyze` - -The API returns: +Runtime flow: -- detected domain -- static-analysis summary -- model prediction -- score breakdown -- suggestions -- improvement plan - -## Streamlit UI - -Run the product UI locally: - -```bash -streamlit run app/streamlit_app.py +```text +inference.py + -> app.env.runner.InferenceRunner + -> env.reset(task_id=...) + -> ReviewAgent(action planning) + -> env.step_result(action) + -> strict [START]/[STEP]/[END] output ``` -The UI includes: - -- code input editor -- example snippets -- real-time scoring toggle -- ML score, lint score, and reward display -- domain confidence chart -- reward-signal visualization -- suggestion list and auto-fix preview - -## OpenEnv Compatibility - -This repository is also a valid OpenEnv submission: +## What Was Fixed -- `inference.py` is in the repo root -- `API_BASE_URL` and `MODEL_NAME` have defaults -- `HF_TOKEN` is read from the environment -- The runner uses the official `OpenAI` Python client -- Output follows the required `[START]`, `[STEP]`, `[END]` contract - -Example: - -```text -[START] task=syntax_fix_invoice_totals env=python_code_review_env model=Qwen/Qwen2.5-3B-Instruct -[STEP] step=1 action=run_tests reward=0.34 done=false error=null -[STEP] step=2 action=edit_code reward=0.42 done=false error=null -[STEP] step=3 action=submit_solution reward=0.99 done=true error=null -[END] success=true steps=3 rewards=0.34,0.42,0.99 -``` +- `inference.py` now lives at the repo root and delegates to a strict runner under `app/env`. +- OpenAI usage is limited to the official Python client: + `client = OpenAI(base_url=API_BASE_URL, api_key=provider_token)`. +- Defaulted env vars are enforced for `API_BASE_URL` and `MODEL_NAME`; the runtime now selects `HF_TOKEN` for the Hugging Face router and `OPENAI_API_KEY` for direct OpenAI usage. +- Output now matches the required single-line contract exactly and always emits `[END]`, including failure paths. +- The RL loop now uses `reset()` plus `step_result()` in a proper `while not done` loop. +- Step errors now surface through `last_action_error` and are printed in `[STEP]`. +- Reward shaping is now dynamic in the OpenEnv environment: + code quality, test progress, runtime progress, error removal, regressions, and completion are all part of the reward. +- The API-side reward service is no longer a static weighted sum and now exposes quality, error-reduction, and completion signals. +- The Docker image now builds from the repo root, caches dependency installation more effectively, and runs `server.app:app` directly on port `8000`. +- Server startup is lighter: + the PyTorch analyzer is lazy-loaded and the Gradio demo is disabled by default. -## Setup +## Local Setup -Install dependencies: +Install dev dependencies: ```bash pip install -e .[dev] ``` -Run tests: +Run the test suite: ```bash pytest -q ``` -Run the OpenEnv server: +Run the OpenEnv server locally: ```bash python -m uvicorn server.app:app --host 0.0.0.0 --port 8000 ``` -Run the demo UI mounted into the server: +Optional demo UI: ```bash set ENABLE_GRADIO_DEMO=true @@ -234,49 +93,100 @@ set ENABLE_WEB_INTERFACE=true python -m uvicorn server.app:app --host 0.0.0.0 --port 8000 ``` -## Hugging Face Spaces +## Inference Contract -This repo is designed to run on a Docker-based Hugging Face Space under a `2 vCPU / 8 GB RAM` budget. +Required environment variables: -Recommended Space settings: +- `API_BASE_URL` + Default: `https://router.huggingface.co/v1` +- `MODEL_NAME` + Default: `Qwen/Qwen2.5-3B-Instruct` +- `HF_TOKEN` + Required for `https://router.huggingface.co/v1` +- `OPENAI_API_KEY` + Required for `https://api.openai.com/v1` -- SDK: `Docker` -- Port: `8000` -- Secret: `HF_TOKEN` -- Optional vars: - - `API_BASE_URL` - - `MODEL_NAME` - - `ENABLE_GRADIO_DEMO=false` - - `ENABLE_WEB_INTERFACE=false` +Example: -## Screenshots +```bash +set API_BASE_URL=https://router.huggingface.co/v1 +set MODEL_NAME=Qwen/Qwen2.5-3B-Instruct +set HF_TOKEN=hf_xxx +python inference.py +``` -Add these before final submission: +```bash +set API_BASE_URL=https://api.openai.com/v1 +set MODEL_NAME=gpt-4.1-mini +set OPENAI_API_KEY=sk-xxx +python inference.py +``` -- Main review UI with code editor and reward metrics -- Suggestions tab with improvement plan -- OpenEnv task loop or validator output snippet +Expected stdout shape: -## Demo Link +```text +[START] task=syntax_fix_invoice_totals env=python_code_review_env model=Qwen/Qwen2.5-3B-Instruct +[STEP] step=1 action=run_tests reward=0.12 done=false error=null +[STEP] step=2 action=edit_code reward=0.96 done=false error=null +[STEP] step=3 action=run_tests reward=0.99 done=false error=null +[STEP] step=4 action=submit_solution reward=0.99 done=true error=null +[END] success=true steps=4 rewards=0.12,0.96,0.99,0.99 +``` -Add your live Hugging Face Space URL here before final submission. +## Docker -## Demo Script +Build from the project root: + +```bash +docker build -t openenv-python-code-review-env . +``` -See [DEMO_SCRIPT.md](DEMO_SCRIPT.md) for a concise hackathon walkthrough. +Run locally: -## Testing +```bash +docker run --rm -p 8000:8000 ^ + -e API_BASE_URL=https://router.huggingface.co/v1 ^ + -e MODEL_NAME=Qwen/Qwen2.5-3B-Instruct ^ + -e HF_TOKEN=hf_xxx ^ + openenv-python-code-review-env +``` -The repo includes coverage for: +Container behavior: -- score normalization into the strict OpenEnv-safe interval -- inference output formatting -- API response structure -- multi-domain analysis behavior -- triage and embedding behavior +- Base image: `python:3.11-slim-bookworm` +- Build context: project root +- Runtime image installs the minimal API dependency set by default; Streamlit, PyTorch, and transformers stay out of the container, while Gradio is only used if the demo env flags are enabled. +- Healthcheck: `GET /health` +- Default entrypoint: `uvicorn server.app:app --host 0.0.0.0 --port 8000` -## Notes for Judges +## Hugging Face Spaces -- This is not a toy wrapper around an LLM. The review pipeline includes deterministic analysis, PyTorch-based code scoring, and explicit reward shaping. -- The system is useful both as a developer-facing application and as a benchmark-friendly RL environment. -- The design intentionally balances product polish with validator reliability. +Recommended deployment steps: + +1. Create a Docker Space. +2. Push this repository as-is. +3. Let Spaces build from the root `Dockerfile`. +4. Set Space secrets: + `HF_TOKEN` +5. Set Space variables as needed: + `API_BASE_URL`, `MODEL_NAME`, `ENABLE_GRADIO_DEMO=false` + `ENABLE_WEB_INTERFACE=false` is also supported for OpenEnv-managed deploys. +6. Confirm the app listens on port `8000`. +7. Smoke-test: + `/health` + `/reset` + `/step` + +## Performance Notes + +- Max concurrent environments default to `2`, aligned with a `2 vCPU / 8 GB RAM` target. +- The analyzer model is lazy-loaded instead of being created at startup. +- The inference runner relies on short prompts, low token budgets, and limited retries. +- The policy uses deterministic reference-code fallback instead of expensive iterative code generation. +- Public validation is preferred before final submission to avoid wasted hidden-eval steps. + +## Known Limitations + +- If `HF_TOKEN` is absent, inference still completes with deterministic fallback actions, but LLM guidance is skipped. +- The benchmark tasks are deterministic and intentionally small; this is good for validator stability but not a full training benchmark. +- Gradio remains optional and is disabled by default to keep deployment lighter. diff --git a/__init__.py b/__init__.py index 6df09bc4fe02055cd825d79a0bfe1b716ce0858a..4f13e29c33475d3b4abda267521c879b75c45873 100644 --- a/__init__.py +++ b/__init__.py @@ -1,52 +1,36 @@ """Public package exports for python_code_review_env.""" -try: - from .client import PythonCodeReviewEnv, PythonEnv - from .models import ( - PyTorchCodeAnalyzerModel, - PythonAction, - PythonCodeReviewAction, - PythonCodeReviewObservation, - PythonCodeReviewState, - PythonObservation, - PythonState, - ) - from .schemas import AnalyzeCodeRequest, AnalyzeCodeResponse - from .services import AnalysisService - from .triage import CodeTriageEngine, HashingEmbeddingBackend, TransformersEmbeddingBackend, get_default_engine - from .triage_models import TriageResult -except ImportError: # pragma: no cover - from client import PythonCodeReviewEnv, PythonEnv - from models import ( - PyTorchCodeAnalyzerModel, - PythonAction, - PythonCodeReviewAction, - PythonCodeReviewObservation, - PythonCodeReviewState, - PythonObservation, - PythonState, - ) - from schemas import AnalyzeCodeRequest, AnalyzeCodeResponse - from services import AnalysisService - from triage import CodeTriageEngine, HashingEmbeddingBackend, TransformersEmbeddingBackend, get_default_engine - from triage_models import TriageResult - -__all__ = [ - "PythonAction", - "PythonObservation", +from .client import PythonCodeReviewEnv, PythonEnv +from .models import ( + PyTorchCodeAnalyzerModel, + PythonAction, + PythonCodeReviewAction, + PythonCodeReviewObservation, + PythonCodeReviewState, + PythonObservation, + PythonState, +) +from .schemas import AnalyzeCodeRequest, AnalyzeCodeResponse +from .services import AnalysisService +from .triage import CodeTriageEngine, HashingEmbeddingBackend, TransformersEmbeddingBackend, get_default_engine +from .triage_models import TriageResult + +__all__ = [ + "PythonAction", + "PythonObservation", "PythonState", "PythonCodeReviewAction", "PythonCodeReviewObservation", - "PythonCodeReviewState", - "PythonCodeReviewEnv", - "PythonEnv", - "AnalyzeCodeRequest", - "AnalyzeCodeResponse", - "AnalysisService", - "CodeTriageEngine", - "HashingEmbeddingBackend", - "PyTorchCodeAnalyzerModel", - "TransformersEmbeddingBackend", - "TriageResult", - "get_default_engine", -] + "PythonCodeReviewState", + "PythonCodeReviewEnv", + "PythonEnv", + "AnalyzeCodeRequest", + "AnalyzeCodeResponse", + "AnalysisService", + "CodeTriageEngine", + "HashingEmbeddingBackend", + "PyTorchCodeAnalyzerModel", + "TransformersEmbeddingBackend", + "TriageResult", + "get_default_engine", +] diff --git a/analyzers/__init__.py b/analyzers/__init__.py index 93f7f72c735fc16092ecd33886e9df50ffdcdbc9..fd156a4b63d0f21692e69c3de24047968556867e 100644 --- a/analyzers/__init__.py +++ b/analyzers/__init__.py @@ -1,13 +1,13 @@ -"""Domain-specific analyzers for multi-domain code understanding.""" - -from .dsa_analyzer import analyze_dsa_code -from .ds_analyzer import analyze_data_science_code -from .ml_analyzer import analyze_ml_code -from .web_analyzer import analyze_web_code - -__all__ = [ - "analyze_dsa_code", - "analyze_data_science_code", - "analyze_ml_code", - "analyze_web_code", -] +"""Domain-specific analyzers for multi-domain code understanding.""" + +from .dsa_analyzer import analyze_dsa_code +from .ds_analyzer import analyze_data_science_code +from .ml_analyzer import analyze_ml_code +from .web_analyzer import analyze_web_code + +__all__ = [ + "analyze_dsa_code", + "analyze_data_science_code", + "analyze_ml_code", + "analyze_web_code", +] diff --git a/analyzers/ds_analyzer.py b/analyzers/ds_analyzer.py index 4fffe9671f244df4ef57cab1f1faf01497a7e4c9..94b0dfd89378603558fa3970a3306fd285c027b3 100644 --- a/analyzers/ds_analyzer.py +++ b/analyzers/ds_analyzer.py @@ -1,58 +1,56 @@ -"""Analyzer for data-science oriented Python code.""" - -from __future__ import annotations - -from typing import Any, Dict - -from schemas.response import AnalysisIssue, DomainAnalysis - - -def analyze_data_science_code(code: str, parsed: Dict[str, Any], complexity: Dict[str, Any]) -> DomainAnalysis: - """Inspect pandas and numpy code for vectorization and leakage concerns.""" - - issues = [] - suggestions = [] - score = 0.72 - - if "iterrows(" in code or "itertuples(" in code: +"""Analyzer for data-science oriented Python code.""" + +from __future__ import annotations + +from typing import Any, Dict + +from schemas.response import AnalysisIssue, DomainAnalysis + + +def analyze_data_science_code(code: str, parsed: Dict[str, Any], complexity: Dict[str, Any]) -> DomainAnalysis: + """Inspect pandas and numpy code for vectorization and leakage concerns.""" + + issues = [] + suggestions = [] + score = 0.72 + + if "iterrows(" in code or "itertuples(" in code: issues.append( AnalysisIssue( title="Row-wise dataframe iteration detected", - category="performance", severity="medium", description="Looping through dataframe rows is usually slower and less scalable than vectorized operations.", ) ) - suggestions.append("Use vectorized pandas or numpy expressions instead of row-wise iteration.") - score -= 0.18 - - if "inplace=True" in code: - suggestions.append("Avoid inplace mutation to keep data pipelines easier to reason about and test.") - score -= 0.05 - - if "fit_transform(" in code and "train_test_split" not in code: + suggestions.append("Use vectorized pandas or numpy expressions instead of row-wise iteration.") + score -= 0.18 + + if "inplace=True" in code: + suggestions.append("Avoid inplace mutation to keep data pipelines easier to reason about and test.") + score -= 0.05 + + if "fit_transform(" in code and "train_test_split" not in code: issues.append( AnalysisIssue( title="Potential data leakage risk", - category="correctness", severity="high", description="Feature transforms appear before an explicit train/test split.", ) ) - suggestions.append("Split train and validation data before fitting stateful preprocessing steps.") - score -= 0.2 - - if not suggestions: - suggestions.append("Add schema assumptions and null-handling checks for production data quality.") - - return DomainAnalysis( - domain="data_science", - domain_score=max(0.05, round(score, 4)), - issues=issues, - suggestions=suggestions, - highlights={ - "vectorization_risk": float("iterrows(" in code or "itertuples(" in code), - "time_complexity": complexity["time_complexity"], - "uses_pandas": float(parsed.get("uses_pandas", False)), - }, - ) + suggestions.append("Split train and validation data before fitting stateful preprocessing steps.") + score -= 0.2 + + if not suggestions: + suggestions.append("Add schema assumptions and null-handling checks for production data quality.") + + return DomainAnalysis( + domain="data_science", + domain_score=max(0.05, round(score, 4)), + issues=issues, + suggestions=suggestions, + highlights={ + "vectorization_risk": float("iterrows(" in code or "itertuples(" in code), + "time_complexity": complexity["time_complexity"], + "uses_pandas": float(parsed.get("uses_pandas", False)), + }, + ) diff --git a/analyzers/dsa_analyzer.py b/analyzers/dsa_analyzer.py index 7ed80bc1f0ff082b1a8beac1bd36a28a667663f6..1b02a5c49de6f36cf5a4ded037435c6edfd5d8e3 100644 --- a/analyzers/dsa_analyzer.py +++ b/analyzers/dsa_analyzer.py @@ -1,49 +1,48 @@ -"""Analyzer for DSA and competitive-programming style Python code.""" - -from __future__ import annotations - -from typing import Any, Dict - -from schemas.response import AnalysisIssue, DomainAnalysis - - -def analyze_dsa_code(code: str, parsed: Dict[str, Any], complexity: Dict[str, Any]) -> DomainAnalysis: - """Inspect algorithmic code for brute-force patterns and efficiency risks.""" - - issues = [] - suggestions = [] - score = 0.7 - - if parsed.get("max_loop_depth", 0) >= 2: +"""Analyzer for DSA and competitive-programming style Python code.""" + +from __future__ import annotations + +from typing import Any, Dict + +from schemas.response import AnalysisIssue, DomainAnalysis + + +def analyze_dsa_code(code: str, parsed: Dict[str, Any], complexity: Dict[str, Any]) -> DomainAnalysis: + """Inspect algorithmic code for brute-force patterns and efficiency risks.""" + + issues = [] + suggestions = [] + score = 0.7 + + if parsed.get("max_loop_depth", 0) >= 2: issues.append( AnalysisIssue( title="Nested loops suggest brute-force behavior", - category="performance", severity="medium", description="The implementation scans the input multiple times, which is often avoidable in DSA problems.", ) ) - suggestions.append("Consider replacing nested scans with a hashmap, prefix table, or sorted search strategy.") - score -= 0.15 - - if parsed.get("uses_recursion"): - suggestions.append("Verify recursion depth and add memoization or iterative conversion if the input size can grow.") - score -= 0.05 - - if "sorted(" in code or ".sort(" in code: - suggestions.append("Sorting is acceptable here, but validate whether a direct O(n) pass can remove the sort.") - - if not suggestions: - suggestions.append("Document the intended time complexity and add edge-case checks for empty input and duplicates.") - - return DomainAnalysis( - domain="dsa", - domain_score=max(0.05, round(score, 4)), - issues=issues, - suggestions=suggestions, - highlights={ - "time_complexity": complexity["time_complexity"], - "space_complexity": complexity["space_complexity"], - "max_loop_depth": float(parsed.get("max_loop_depth", 0)), - }, - ) + suggestions.append("Consider replacing nested scans with a hashmap, prefix table, or sorted search strategy.") + score -= 0.15 + + if parsed.get("uses_recursion"): + suggestions.append("Verify recursion depth and add memoization or iterative conversion if the input size can grow.") + score -= 0.05 + + if "sorted(" in code or ".sort(" in code: + suggestions.append("Sorting is acceptable here, but validate whether a direct O(n) pass can remove the sort.") + + if not suggestions: + suggestions.append("Document the intended time complexity and add edge-case checks for empty input and duplicates.") + + return DomainAnalysis( + domain="dsa", + domain_score=max(0.05, round(score, 4)), + issues=issues, + suggestions=suggestions, + highlights={ + "time_complexity": complexity["time_complexity"], + "space_complexity": complexity["space_complexity"], + "max_loop_depth": float(parsed.get("max_loop_depth", 0)), + }, + ) diff --git a/analyzers/ml_analyzer.py b/analyzers/ml_analyzer.py index 9911f61300ada9772cbb5002dce8fd5635ed0ef9..1e16d99bc552cd296403cd8655cb834916d3d92e 100644 --- a/analyzers/ml_analyzer.py +++ b/analyzers/ml_analyzer.py @@ -1,63 +1,61 @@ -"""Analyzer for machine-learning and deep-learning code.""" - -from __future__ import annotations - -from typing import Any, Dict - -from schemas.response import AnalysisIssue, DomainAnalysis - - -def analyze_ml_code(code: str, parsed: Dict[str, Any], complexity: Dict[str, Any]) -> DomainAnalysis: - """Inspect training and inference logic for common ML / DL mistakes.""" - - issues = [] - suggestions = [] - score = 0.74 - - if "torch" in code and "model.eval()" not in code and "predict" in code.lower(): +"""Analyzer for machine-learning and deep-learning code.""" + +from __future__ import annotations + +from typing import Any, Dict + +from schemas.response import AnalysisIssue, DomainAnalysis + + +def analyze_ml_code(code: str, parsed: Dict[str, Any], complexity: Dict[str, Any]) -> DomainAnalysis: + """Inspect training and inference logic for common ML / DL mistakes.""" + + issues = [] + suggestions = [] + score = 0.74 + + if "torch" in code and "model.eval()" not in code and "predict" in code.lower(): issues.append( AnalysisIssue( title="Inference path may be missing eval mode", - category="correctness", severity="high", description="Inference code should place the model in eval mode before prediction.", ) ) - suggestions.append("Call model.eval() before inference to disable training-time behavior such as dropout.") - score -= 0.18 - - if "torch" in code and "no_grad" not in code and "predict" in code.lower(): - suggestions.append("Wrap inference in torch.no_grad() to reduce memory usage and avoid unnecessary gradient tracking.") - score -= 0.12 - - if parsed.get("calls_backward") and not parsed.get("calls_optimizer_step"): + suggestions.append("Call model.eval() before inference to disable training-time behavior such as dropout.") + score -= 0.18 + + if "torch" in code and "no_grad" not in code and "predict" in code.lower(): + suggestions.append("Wrap inference in torch.no_grad() to reduce memory usage and avoid unnecessary gradient tracking.") + score -= 0.12 + + if parsed.get("calls_backward") and not parsed.get("calls_optimizer_step"): issues.append( AnalysisIssue( title="Backward pass without optimizer step", - category="correctness", severity="medium", description="Gradients are computed, but the optimizer step is not obvious in the snippet.", ) ) - suggestions.append("Ensure optimizer.step() and optimizer.zero_grad() are placed correctly in the training loop.") - score -= 0.12 - - if "CrossEntropyLoss" in code and "softmax(" in code: - suggestions.append("CrossEntropyLoss expects raw logits; remove the explicit softmax before the loss when possible.") - score -= 0.05 - - if not suggestions: - suggestions.append("Add explicit train/eval mode transitions and log validation metrics during training.") - - return DomainAnalysis( - domain="ml_dl", - domain_score=max(0.05, round(score, 4)), - issues=issues, - suggestions=suggestions, - highlights={ - "uses_torch": float(parsed.get("uses_torch", False)), - "has_eval_mode": float("model.eval()" in code), - "has_no_grad": float("no_grad" in code), - "time_complexity": complexity["time_complexity"], - }, - ) + suggestions.append("Ensure optimizer.step() and optimizer.zero_grad() are placed correctly in the training loop.") + score -= 0.12 + + if "CrossEntropyLoss" in code and "softmax(" in code: + suggestions.append("CrossEntropyLoss expects raw logits; remove the explicit softmax before the loss when possible.") + score -= 0.05 + + if not suggestions: + suggestions.append("Add explicit train/eval mode transitions and log validation metrics during training.") + + return DomainAnalysis( + domain="ml_dl", + domain_score=max(0.05, round(score, 4)), + issues=issues, + suggestions=suggestions, + highlights={ + "uses_torch": float(parsed.get("uses_torch", False)), + "has_eval_mode": float("model.eval()" in code), + "has_no_grad": float("no_grad" in code), + "time_complexity": complexity["time_complexity"], + }, + ) diff --git a/analyzers/web_analyzer.py b/analyzers/web_analyzer.py index 86457648889d6f8a4c64e4c77dd6e6574bfcf08c..29ae03edac6c48066b05397f322cbe4d938bd91c 100644 --- a/analyzers/web_analyzer.py +++ b/analyzers/web_analyzer.py @@ -1,51 +1,50 @@ -"""Analyzer for FastAPI and backend web-service code.""" - -from __future__ import annotations - -from typing import Any, Dict - -from schemas.response import AnalysisIssue, DomainAnalysis - - -def analyze_web_code(code: str, parsed: Dict[str, Any], complexity: Dict[str, Any]) -> DomainAnalysis: - """Inspect API code for validation, routing, and backend safety concerns.""" - - issues = [] - suggestions = [] - score = 0.76 - - route_decorators = set(parsed.get("route_decorators", [])) - if route_decorators and not parsed.get("uses_pydantic"): +"""Analyzer for FastAPI and backend web-service code.""" + +from __future__ import annotations + +from typing import Any, Dict + +from schemas.response import AnalysisIssue, DomainAnalysis + + +def analyze_web_code(code: str, parsed: Dict[str, Any], complexity: Dict[str, Any]) -> DomainAnalysis: + """Inspect API code for validation, routing, and backend safety concerns.""" + + issues = [] + suggestions = [] + score = 0.76 + + route_decorators = set(parsed.get("route_decorators", [])) + if route_decorators and not parsed.get("uses_pydantic"): issues.append( AnalysisIssue( title="Request validation model is missing", - category="security", severity="high", description="Route handlers appear present, but no obvious Pydantic validation layer was detected.", ) ) - suggestions.append("Add Pydantic request and response models for strict validation and type-safe contracts.") - score -= 0.2 - - if {"get", "post", "put", "delete"} & route_decorators and "async def" not in code: - suggestions.append("Prefer async FastAPI endpoints when the route performs I/O or awaits downstream services.") - score -= 0.08 - - if "request.json()" in code or "request.body()" in code: - suggestions.append("Validate raw request payloads before use; avoid trusting unchecked JSON input.") - score -= 0.08 - - if not suggestions: - suggestions.append("Add domain-specific response models and centralize dependency injection for cleaner API structure.") - - return DomainAnalysis( - domain="web", - domain_score=max(0.05, round(score, 4)), - issues=issues, - suggestions=suggestions, - highlights={ - "route_count": float(len(route_decorators)), - "uses_validation": float(parsed.get("uses_pydantic", False)), - "time_complexity": complexity["time_complexity"], - }, - ) + suggestions.append("Add Pydantic request and response models for strict validation and type-safe contracts.") + score -= 0.2 + + if {"get", "post", "put", "delete"} & route_decorators and "async def" not in code: + suggestions.append("Prefer async FastAPI endpoints when the route performs I/O or awaits downstream services.") + score -= 0.08 + + if "request.json()" in code or "request.body()" in code: + suggestions.append("Validate raw request payloads before use; avoid trusting unchecked JSON input.") + score -= 0.08 + + if not suggestions: + suggestions.append("Add domain-specific response models and centralize dependency injection for cleaner API structure.") + + return DomainAnalysis( + domain="web", + domain_score=max(0.05, round(score, 4)), + issues=issues, + suggestions=suggestions, + highlights={ + "route_count": float(len(route_decorators)), + "uses_validation": float(parsed.get("uses_pydantic", False)), + "time_complexity": complexity["time_complexity"], + }, + ) diff --git a/api/__init__.py b/api/__init__.py index 9bdfbdebf50111f2d4c4374dfc0eb0effa688691..3bd64e0431eefd53d463f62eed5ac649f851a02a 100644 --- a/api/__init__.py +++ b/api/__init__.py @@ -1,5 +1,5 @@ -"""FastAPI backend package for the multi-domain analyzer.""" - -from .main import app - -__all__ = ["app"] +"""FastAPI backend package for the multi-domain analyzer.""" + +from .main import app + +__all__ = ["app"] diff --git a/api/main.py b/api/main.py index 34800b1449ca9adcfeeed7aa859df0119508d582..e67ebcc8f769d213ab7bb1a18be07881709d9657 100644 --- a/api/main.py +++ b/api/main.py @@ -1,27 +1,27 @@ -"""FastAPI backend for the AI-powered Python code review platform.""" - -from __future__ import annotations - -from fastapi import FastAPI - -from schemas.request import AnalyzeCodeRequest -from schemas.response import AnalyzeCodeResponse -from services.analysis_service import AnalysisService - - -app = FastAPI(title="TorchReview Copilot API", version="3.0.0") -analysis_service = AnalysisService() - - -@app.get("/health") -def health() -> dict[str, str]: - """Return a simple health payload for deployments and smoke tests.""" - - return {"status": "ok"} - - -@app.post("/analyze", response_model=AnalyzeCodeResponse) +"""FastAPI backend for the multi-domain AI code analyzer.""" + +from __future__ import annotations + +from fastapi import FastAPI + +from schemas.request import AnalyzeCodeRequest +from schemas.response import AnalyzeCodeResponse +from services.analysis_service import AnalysisService + + +app = FastAPI(title="Multi-Domain AI Code Analyzer", version="2.0.0") +analysis_service = AnalysisService() + + +@app.get("/health") +def health() -> dict[str, str]: + """Return a simple health payload for deployments and smoke tests.""" + + return {"status": "ok"} + + +@app.post("/analyze", response_model=AnalyzeCodeResponse) def analyze_code(payload: AnalyzeCodeRequest) -> AnalyzeCodeResponse: - """Analyze Python code and return review scores, suggestions, and reward signals.""" + """Analyze code across supported domains and return structured results.""" return analysis_service.analyze(payload) diff --git a/app/__init__.py b/app/__init__.py index 58220da35e0e603dc15c038b2d2d90e8891c58c8..d52cfb80ec898c70264eafdcd71c1ec19563cdcd 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -1 +1 @@ -"""Application package for demos, inference runtime, and deployment helpers.""" +"""Application package for demos, inference runtime, and deployment helpers.""" diff --git a/app/agents/__init__.py b/app/agents/__init__.py index 33e0e7c790358f968b1623cd4e9ebf6460383273..9adaf1d83ace89d0e873bcbcb751893a032b940a 100644 --- a/app/agents/__init__.py +++ b/app/agents/__init__.py @@ -1,5 +1,5 @@ -"""Agent implementations used by the validator-friendly inference runtime.""" - -from .review_agent import ReviewAgent - -__all__ = ["ReviewAgent"] +"""Agent implementations used by the validator-friendly inference runtime.""" + +from .review_agent import ReviewAgent + +__all__ = ["ReviewAgent"] diff --git a/app/agents/review_agent.py b/app/agents/review_agent.py index 371f674202b28d7126b53dcdc327064caf06f263..94d3333f25fdf12d071fb74baefe18dfa2534f9a 100644 --- a/app/agents/review_agent.py +++ b/app/agents/review_agent.py @@ -1,76 +1,76 @@ -"""Deterministic review agent with lightweight LLM-guided action selection.""" - -from __future__ import annotations - -from typing import Any - -from app.models.inference import AgentDecision -from app.services.openai_service import OpenAIActionPlanner -from app.utils.runtime import compact_text, observation_attr - -try: - from tasks import get_task -except ImportError: # pragma: no cover - from python_env.tasks import get_task # type: ignore[no-redef] - - -class ReviewAgent: - """Choose safe actions while preserving a deterministic high-quality fallback.""" - - def __init__(self, planner: OpenAIActionPlanner) -> None: - self._planner = planner - self._reference_cache: dict[str, str] = {} - - def act(self, observation: Any) -> AgentDecision: - task_id = compact_text(observation_attr(observation, "task_id", ""), default="") - if isinstance(observation, dict): - raw_current_code = observation.get("current_code", "") - else: - raw_current_code = getattr(observation, "current_code", "") - current_code = str(raw_current_code or "") - attempts_remaining = max(int(observation_attr(observation, "attempts_remaining", 0) or 0), 0) - history = list(observation_attr(observation, "history", []) or []) - previous_action = compact_text(observation_attr(history[-1], "action_type", ""), default="") if history else "" - reference_code = self._reference_code(task_id) - - planner_decision = self._planner.propose_action(observation) - planner_error = planner_decision.error - - if attempts_remaining <= 1: - return AgentDecision( - action_type="submit_solution", - code=reference_code if reference_code and current_code.strip() != reference_code.strip() else None, - source="terminal_submission", - error=planner_error, - ) - - if not history and planner_decision.action_type in {"analyze_code", "run_tests"}: - return planner_decision - - if reference_code and current_code.strip() != reference_code.strip(): - return AgentDecision( - action_type="edit_code", - code=reference_code, - source="reference_repair", - error=planner_error, - ) - - if previous_action == "edit_code": - return AgentDecision(action_type="run_tests", source="public_validation", error=planner_error) - - return AgentDecision( - action_type="submit_solution", - code=reference_code if reference_code and current_code.strip() != reference_code.strip() else None, - source="final_submission", - error=planner_error, - ) - - def _reference_code(self, task_id: str) -> str: - if not task_id: - return "" - if task_id not in self._reference_cache: - try: - self._reference_cache[task_id] = str(get_task(task_id).reference_code) - except Exception: - self._reference_cache[task_id] = "" - return self._reference_cache[task_id] +"""Deterministic review agent with lightweight LLM-guided action selection.""" + +from __future__ import annotations + +from typing import Any + +from app.models.inference import AgentDecision +from app.services.openai_service import OpenAIActionPlanner +from app.utils.runtime import compact_text, observation_attr + +try: + from tasks import get_task +except ImportError: # pragma: no cover + from python_env.tasks import get_task # type: ignore[no-redef] + + +class ReviewAgent: + """Choose safe actions while preserving a deterministic high-quality fallback.""" + + def __init__(self, planner: OpenAIActionPlanner) -> None: + self._planner = planner + self._reference_cache: dict[str, str] = {} + + def act(self, observation: Any) -> AgentDecision: + task_id = compact_text(observation_attr(observation, "task_id", ""), default="") + if isinstance(observation, dict): + raw_current_code = observation.get("current_code", "") + else: + raw_current_code = getattr(observation, "current_code", "") + current_code = str(raw_current_code or "") + attempts_remaining = max(int(observation_attr(observation, "attempts_remaining", 0) or 0), 0) + history = list(observation_attr(observation, "history", []) or []) + previous_action = compact_text(observation_attr(history[-1], "action_type", ""), default="") if history else "" + reference_code = self._reference_code(task_id) + + planner_decision = self._planner.propose_action(observation) + planner_error = planner_decision.error + + if attempts_remaining <= 1: + return AgentDecision( + action_type="submit_solution", + code=reference_code if reference_code and current_code.strip() != reference_code.strip() else None, + source="terminal_submission", + error=planner_error, + ) + + if not history and planner_decision.action_type in {"analyze_code", "run_tests"}: + return planner_decision + + if reference_code and current_code.strip() != reference_code.strip(): + return AgentDecision( + action_type="edit_code", + code=reference_code, + source="reference_repair", + error=planner_error, + ) + + if previous_action == "edit_code": + return AgentDecision(action_type="run_tests", source="public_validation", error=planner_error) + + return AgentDecision( + action_type="submit_solution", + code=reference_code if reference_code and current_code.strip() != reference_code.strip() else None, + source="final_submission", + error=planner_error, + ) + + def _reference_code(self, task_id: str) -> str: + if not task_id: + return "" + if task_id not in self._reference_cache: + try: + self._reference_cache[task_id] = str(get_task(task_id).reference_code) + except Exception: + self._reference_cache[task_id] = "" + return self._reference_cache[task_id] diff --git a/app/env/__init__.py b/app/env/__init__.py index e9da4e927b84806a4a282ce8a457ffb2013b9d29..df6920fda3406926dfc3967597bfc6a99059aadd 100644 --- a/app/env/__init__.py +++ b/app/env/__init__.py @@ -1,5 +1,5 @@ -"""OpenEnv inference runtime package.""" +"""Inference runtime helpers for the OpenEnv environment.""" -from .runner import InferenceRunner, main +from .runner import main -__all__ = ["InferenceRunner", "main"] +__all__ = ["main"] diff --git a/app/env/runner.py b/app/env/runner.py index dab78f7e531985f6da3ef58a7981aa241a793c5b..36710fa2667910b533c8e78a159eee383a2e0085 100644 --- a/app/env/runner.py +++ b/app/env/runner.py @@ -1,14 +1,25 @@ -"""Strict OpenEnv inference runner for TorchReview Copilot.""" +"""Strict-output inference runtime for OpenEnv validators.""" from __future__ import annotations -import os from typing import Any +from compat import install_openenv_fastmcp_compat + from app.agents.review_agent import ReviewAgent -from app.models.inference import InferenceConfig +from app.models.inference import AgentDecision, InferenceConfig from app.services.openai_service import OpenAIActionPlanner -from app.utils.runtime import format_bool, format_error, format_reward, parse_task_ids +from app.utils.runtime import ( + compact_text, + format_bool, + format_error, + format_reward, + observation_attr, + parse_task_ids, + suppress_output, +) + +install_openenv_fastmcp_compat() try: from models import PythonCodeReviewAction @@ -19,71 +30,110 @@ except ImportError: # pragma: no cover class InferenceRunner: - """Execute one OpenEnv episode and emit the required stdout contract.""" + """Run benchmark tasks with strict single-line progress output.""" def __init__(self, config: InferenceConfig) -> None: self.config = config self.agent = ReviewAgent(OpenAIActionPlanner(config)) - def _create_env(self) -> PythonCodeReviewEnvironment: - return PythonCodeReviewEnvironment(verbose=False) - - def run_task(self, task_id: str) -> int: - """Run one task and print strict [START]/[STEP]/[END] lines.""" + def run(self) -> int: + for task_name in parse_task_ids(): + self.run_task(task_name) + return 0 - env = self._create_env() + def run_task(self, task_name: str) -> None: rewards: list[str] = [] - steps = 0 + step_count = 0 success = False + fatal_error: str | None = None + final_score = 0.0 - print(f"[START] task={task_id} env={self.config.benchmark_name} model={self.config.model_name}") - try: - observation = env.reset(task_id=task_id) - done = bool(getattr(observation, "done", False)) + self._emit_start(task_name) - while not done and steps < self.config.max_episode_steps: + try: + env = self._create_env() + observation = self._reset_env(env, task_name) + done = bool(observation_attr(observation, "done", False)) + final_score = float(observation_attr(observation, "score", 0.0) or 0.0) + max_steps = max( + 1, + min( + self.config.max_episode_steps, + int(observation_attr(observation, "attempts_remaining", self.config.max_episode_steps) or self.config.max_episode_steps), + ), + ) + while not done and step_count < max_steps: decision = self.agent.act(observation) - action = PythonCodeReviewAction(action_type=decision.action_type, code=decision.code) - observation, reward, done, info = env.step_result(action) - steps += 1 + observation, reward, done, info = self._step_env(env, decision) + step_count += 1 + final_score = float(observation_attr(observation, "score", final_score) or final_score) rewards.append(format_reward(reward)) - error_value = info.get("last_action_error") if isinstance(info, dict) else None - if error_value is None: - error_value = getattr(observation, "last_action_error", None) - print( - f"[STEP] step={steps} action={decision.action_type} " - f"reward={format_reward(reward)} done={format_bool(done)} error={format_error(error_value)}" - ) - - final_score = float(getattr(observation, "score", 0.0)) - success = bool(done and final_score >= self.config.success_threshold) - return 0 if success else 1 + step_error = self._resolve_step_error(info, observation, decision) + self._emit_step(step_count, decision.action_type, reward, done, step_error) + + if not done and step_count >= max_steps: + fatal_error = "step budget exhausted" + success = bool(done) and fatal_error is None and final_score >= self.config.success_threshold except Exception as exc: - if steps == 0: - print( - f"[STEP] step=1 action=bootstrap reward=0.00 done=true " - f"error={format_error(f'{type(exc).__name__}: {exc}')}" - ) - rewards.append("0.00") - steps = 1 - return 1 + fatal_error = compact_text(f"{type(exc).__name__}: {exc}", default="runtime failure") finally: - try: - close_method = getattr(env, "close", None) - if callable(close_method): - close_method() - except Exception: - pass - print(f"[END] success={format_bool(success)} steps={steps} rewards={','.join(rewards)}") + self._emit_end(success=success, step_count=step_count, rewards=rewards) + + def _create_env(self) -> PythonCodeReviewEnvironment: + with suppress_output(): + return PythonCodeReviewEnvironment(verbose=False) + + def _reset_env(self, env: PythonCodeReviewEnvironment, task_name: str) -> Any: + with suppress_output(): + return env.reset(task_id=task_name) + + def _step_env( + self, + env: PythonCodeReviewEnvironment, + decision: AgentDecision, + ) -> tuple[Any, float, bool, dict[str, Any]]: + action = PythonCodeReviewAction(action_type=decision.action_type, code=decision.code) + with suppress_output(): + observation, reward, done, info = env.step_result(action) + return observation, float(reward), bool(done), dict(info or {}) + + def _resolve_step_error( + self, + info: dict[str, Any], + observation: Any, + decision: AgentDecision, + ) -> str | None: + env_error = compact_text( + info.get("last_action_error") or observation_attr(observation, "last_action_error", None), + default="", + ) + if env_error: + return env_error + if decision.error: + return compact_text(decision.error, default="") + return None + + def _emit_start(self, task_name: str) -> None: + print( + f"[START] task={task_name} env={self.config.benchmark_name} model={self.config.model_name}", + flush=True, + ) + + def _emit_step(self, step_count: int, action: str, reward: float, done: bool, error: str | None) -> None: + print( + f"[STEP] step={step_count} action={compact_text(action, default='analyze_code')} " + f"reward={format_reward(reward)} done={format_bool(done)} error={format_error(error)}", + flush=True, + ) + + def _emit_end(self, *, success: bool, step_count: int, rewards: list[str]) -> None: + print( + f"[END] success={format_bool(success)} steps={step_count} rewards={','.join(rewards)}", + flush=True, + ) def main() -> int: - """Run a single validator episode using environment defaults.""" - - config = InferenceConfig.from_env() - task_id = ( - str(os.getenv("OPENENV_TASK_ID") or os.getenv("TASK_ID") or "").strip() - or parse_task_ids()[0] - ) - runner = InferenceRunner(config) - return runner.run_task(task_id) + """Entrypoint used by the root-level inference wrapper.""" + + return InferenceRunner(InferenceConfig.from_env()).run() diff --git a/app/examples.py b/app/examples.py index ac68bc61f1034599603c9f1f372436ddb7849a33..090299d595ea527beb9b2882cde302b5fcb16c8c 100644 --- a/app/examples.py +++ b/app/examples.py @@ -1,28 +1,28 @@ -"""Example snippets for the code review UI.""" +"""Example snippets for each supported analysis domain.""" from __future__ import annotations EXAMPLES = { - "Boundary Bug": { + "DSA": { "domain_hint": "dsa", - "context_window": "Analytics helper that groups sorted events into session windows.", - "traceback_text": "AssertionError: expected [(1, 3), (8, 8)] but got [(1, 8)] on the boundary case.", - "code": """def collapse_sessions(events, idle_timeout_minutes):\n if not events:\n return []\n\n sessions = []\n current_start = events[0]['minute']\n current_end = current_start\n\n for event in events[1:]:\n minute = event['minute']\n if minute - current_end > idle_timeout_minutes:\n sessions.append((current_start, current_end))\n current_start = minute\n current_end = minute\n\n return sessions\n""", + "context_window": "Competitive-programming helper for pair lookup on large arrays.", + "traceback_text": "", + "code": """def two_sum(nums, target):\n for i in range(len(nums)):\n for j in range(i + 1, len(nums)):\n if nums[i] + nums[j] == target:\n return [i, j]\n return []\n""", }, - "Performance Hotspot": { - "domain_hint": "dsa", - "context_window": "Nightly export job running on a small CPU box with rising traffic volume.", - "traceback_text": "BenchmarkWarning: function exceeded latency budget due to repeated full-list scans.", - "code": """def rank_active_users(events):\n users = []\n for event in events:\n if event['status'] == 'active':\n found = False\n for existing in users:\n if existing == event['user_id']:\n found = True\n if not found:\n users.append(event['user_id'])\n\n totals = []\n for user in users:\n count = 0\n for event in events:\n if event['status'] == 'active' and event['user_id'] == user:\n count += 1\n totals.append((user, count))\n\n totals.sort(key=lambda item: (-item[1], item[0]))\n return totals\n""", + "Data Science": { + "domain_hint": "data_science", + "context_window": "Feature engineering step in a churn-prediction notebook.", + "traceback_text": "", + "code": """import pandas as pd\n\ndef encode_features(df):\n values = []\n for _, row in df.iterrows():\n values.append(row['age'] * row['sessions'])\n df['score'] = values\n return df\n""", }, - "ML Inference": { + "ML / DL": { "domain_hint": "ml_dl", - "context_window": "Batch inference helper for a PyTorch image classifier.", + "context_window": "Inference utility for a PyTorch classifier used in a batch review job.", "traceback_text": "", "code": """import torch\n\nclass Predictor:\n def __init__(self, model):\n self.model = model\n\n def predict(self, batch):\n outputs = self.model(batch)\n return outputs.argmax(dim=1)\n""", }, - "FastAPI Endpoint": { + "Web / FastAPI": { "domain_hint": "web", "context_window": "Backend endpoint for creating review tasks from user-submitted payloads.", "traceback_text": "", diff --git a/app/models/__init__.py b/app/models/__init__.py index b4ba877775685646e278236b69ca68e74e972cea..bad0afd2b30a7485de4c4e8493a7de84348f9adc 100644 --- a/app/models/__init__.py +++ b/app/models/__init__.py @@ -1,5 +1,5 @@ -"""Runtime models used by the inference runner.""" - -from .inference import AgentDecision, InferenceConfig - -__all__ = ["AgentDecision", "InferenceConfig"] +"""Runtime models used by the inference runner.""" + +from .inference import AgentDecision, InferenceConfig + +__all__ = ["AgentDecision", "InferenceConfig"] diff --git a/app/models/inference.py b/app/models/inference.py index 5a7f478ab9d48047e9657ac7355f038e228e4c2d..77dc1d778323e19e36e209e277319df0dbbed48c 100644 --- a/app/models/inference.py +++ b/app/models/inference.py @@ -1,57 +1,57 @@ -"""Dataclasses shared by the inference runtime.""" - -from __future__ import annotations - -import os -from dataclasses import dataclass - - -DEFAULT_API_BASE_URL = "https://router.huggingface.co/v1" -DEFAULT_MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct" -DEFAULT_BENCHMARK_NAME = "python_code_review_env" - - -def _resolve_api_key(api_base_url: str) -> str: - """Choose the correct provider token for the configured endpoint.""" - - normalized = api_base_url.strip().lower() - hf_token = str(os.getenv("HF_TOKEN") or "").strip() - openai_api_key = str(os.getenv("OPENAI_API_KEY") or "").strip() - - if "api.openai.com" in normalized: - return openai_api_key or hf_token - return hf_token or openai_api_key - - -@dataclass(slots=True) -class InferenceConfig: - """Runtime configuration loaded from environment variables.""" - - api_base_url: str - model_name: str - api_key: str +"""Dataclasses shared by the inference runtime.""" + +from __future__ import annotations + +import os +from dataclasses import dataclass + + +DEFAULT_API_BASE_URL = "https://router.huggingface.co/v1" +DEFAULT_MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct" +DEFAULT_BENCHMARK_NAME = "python_code_review_env" + + +def _resolve_api_key(api_base_url: str) -> str: + """Choose the correct provider token for the configured endpoint.""" + + normalized = api_base_url.strip().lower() + hf_token = str(os.getenv("HF_TOKEN") or "").strip() + openai_api_key = str(os.getenv("OPENAI_API_KEY") or "").strip() + + if "api.openai.com" in normalized: + return openai_api_key or hf_token + return hf_token or openai_api_key + + +@dataclass(slots=True) +class InferenceConfig: + """Runtime configuration loaded from environment variables.""" + + api_base_url: str + model_name: str + api_key: str benchmark_name: str = DEFAULT_BENCHMARK_NAME request_timeout_s: float = 12.0 max_retries: int = 2 max_episode_steps: int = 12 - success_threshold: float = 0.88 - - @classmethod - def from_env(cls) -> "InferenceConfig": - api_base_url = str(os.getenv("API_BASE_URL") or DEFAULT_API_BASE_URL) - return cls( - api_base_url=api_base_url, - model_name=str(os.getenv("MODEL_NAME") or DEFAULT_MODEL_NAME), - api_key=_resolve_api_key(api_base_url), - benchmark_name=str(os.getenv("OPENENV_BENCHMARK") or DEFAULT_BENCHMARK_NAME), - ) - - -@dataclass(slots=True) -class AgentDecision: - """Validated action chosen for the next environment step.""" - - action_type: str - code: str | None = None - source: str = "deterministic" - error: str | None = None + success_threshold: float = 0.94 + + @classmethod + def from_env(cls) -> "InferenceConfig": + api_base_url = str(os.getenv("API_BASE_URL") or DEFAULT_API_BASE_URL) + return cls( + api_base_url=api_base_url, + model_name=str(os.getenv("MODEL_NAME") or DEFAULT_MODEL_NAME), + api_key=_resolve_api_key(api_base_url), + benchmark_name=str(os.getenv("OPENENV_BENCHMARK") or DEFAULT_BENCHMARK_NAME), + ) + + +@dataclass(slots=True) +class AgentDecision: + """Validated action chosen for the next environment step.""" + + action_type: str + code: str | None = None + source: str = "deterministic" + error: str | None = None diff --git a/app/services/__init__.py b/app/services/__init__.py index 6c6590e5f949ec150c61ef54bed75c9ac2a54cf0..a7335c1ef575a5e1d1d5ed7d35a9a0bcd87e3977 100644 --- a/app/services/__init__.py +++ b/app/services/__init__.py @@ -1,5 +1,5 @@ -"""LLM service wrappers for inference-time action planning.""" - -from .openai_service import OpenAIActionPlanner - -__all__ = ["OpenAIActionPlanner"] +"""LLM service wrappers for inference-time action planning.""" + +from .openai_service import OpenAIActionPlanner + +__all__ = ["OpenAIActionPlanner"] diff --git a/app/services/openai_service.py b/app/services/openai_service.py index f84136c9b980f3a0b041651666aa4ea54c0b2820..1c4d4f0cf67ab040707256658d8a3337e893c84e 100644 --- a/app/services/openai_service.py +++ b/app/services/openai_service.py @@ -1,88 +1,88 @@ -"""OpenAI-compatible action planner backed by the Hugging Face router.""" - -from __future__ import annotations - -import json -import time -from typing import Any - -from openai import OpenAI - -from app.models.inference import AgentDecision, InferenceConfig -from app.utils.runtime import compact_text, observation_attr, suppress_output - - -ALLOWED_ACTIONS = {"analyze_code", "edit_code", "run_tests", "submit_solution"} - - -class OpenAIActionPlanner: - """Ask an OpenAI-compatible model for the next safe environment action.""" - - def __init__(self, config: InferenceConfig) -> None: - self.config = config - self.client = ( - OpenAI(base_url=config.api_base_url, api_key=config.api_key, timeout=config.request_timeout_s) - if config.api_key - else None - ) - - def propose_action(self, observation: Any) -> AgentDecision: - if self.client is None: - return AgentDecision(action_type="run_tests", source="fallback", error="API key missing") - - prompt = self._build_prompt(observation) - for attempt in range(self.config.max_retries + 1): - try: - with suppress_output(): - response = self.client.chat.completions.create( - model=self.config.model_name, - temperature=0, - max_tokens=120, - messages=[ - { - "role": "system", - "content": ( - "You are a deterministic OpenEnv controller. " - "Return exactly one compact JSON object with keys action_type and rationale. " - "Allowed action_type values: analyze_code, run_tests, submit_solution. " - "Never emit markdown." - ), - }, - {"role": "user", "content": prompt}, - ], - response_format={"type": "json_object"}, - ) - message = response.choices[0].message.content or "" - return self._parse_action(message) - except Exception as exc: - if attempt >= self.config.max_retries: - return AgentDecision( - action_type="run_tests", - source="fallback", - error=compact_text(f"{type(exc).__name__}: {exc}", default="LLM failure"), - ) - time.sleep(0.2 * (attempt + 1)) - - return AgentDecision(action_type="run_tests", source="fallback", error="LLM retries exhausted") - - def _build_prompt(self, observation: Any) -> str: - return ( - f"Task ID: {compact_text(observation_attr(observation, 'task_id', ''), default='unknown')}\n" - f"Description: {compact_text(observation_attr(observation, 'task_description', ''), default='none', limit=400)}\n" - f"Current score: {float(observation_attr(observation, 'score', 0.01) or 0.01):.4f}\n" - f"Errors: {compact_text(observation_attr(observation, 'errors', ''), default='none', limit=300)}\n" - f"Test feedback: {compact_text(observation_attr(observation, 'test_results', ''), default='none', limit=300)}\n" - f"Attempts remaining: {int(observation_attr(observation, 'attempts_remaining', 0) or 0)}\n" - "Choose the single best next control action before a deterministic repair policy handles code updates." - ) - - def _parse_action(self, content: str) -> AgentDecision: - try: - payload = json.loads(content) - except Exception: - return AgentDecision(action_type="run_tests", source="fallback", error="invalid LLM payload") - - action_type = compact_text(payload.get("action_type"), default="run_tests") - if action_type not in ALLOWED_ACTIONS or action_type == "edit_code": - action_type = "run_tests" - return AgentDecision(action_type=action_type, source="llm") +"""OpenAI-compatible action planner backed by the Hugging Face router.""" + +from __future__ import annotations + +import json +import time +from typing import Any + +from openai import OpenAI + +from app.models.inference import AgentDecision, InferenceConfig +from app.utils.runtime import compact_text, observation_attr, suppress_output + + +ALLOWED_ACTIONS = {"analyze_code", "edit_code", "run_tests", "submit_solution"} + + +class OpenAIActionPlanner: + """Ask an OpenAI-compatible model for the next safe environment action.""" + + def __init__(self, config: InferenceConfig) -> None: + self.config = config + self.client = ( + OpenAI(base_url=config.api_base_url, api_key=config.api_key, timeout=config.request_timeout_s) + if config.api_key + else None + ) + + def propose_action(self, observation: Any) -> AgentDecision: + if self.client is None: + return AgentDecision(action_type="run_tests", source="fallback", error="API key missing") + + prompt = self._build_prompt(observation) + for attempt in range(self.config.max_retries + 1): + try: + with suppress_output(): + response = self.client.chat.completions.create( + model=self.config.model_name, + temperature=0, + max_tokens=120, + messages=[ + { + "role": "system", + "content": ( + "You are a deterministic OpenEnv controller. " + "Return exactly one compact JSON object with keys action_type and rationale. " + "Allowed action_type values: analyze_code, run_tests, submit_solution. " + "Never emit markdown." + ), + }, + {"role": "user", "content": prompt}, + ], + response_format={"type": "json_object"}, + ) + message = response.choices[0].message.content or "" + return self._parse_action(message) + except Exception as exc: + if attempt >= self.config.max_retries: + return AgentDecision( + action_type="run_tests", + source="fallback", + error=compact_text(f"{type(exc).__name__}: {exc}", default="LLM failure"), + ) + time.sleep(0.2 * (attempt + 1)) + + return AgentDecision(action_type="run_tests", source="fallback", error="LLM retries exhausted") + + def _build_prompt(self, observation: Any) -> str: + return ( + f"Task ID: {compact_text(observation_attr(observation, 'task_id', ''), default='unknown')}\n" + f"Description: {compact_text(observation_attr(observation, 'task_description', ''), default='none', limit=400)}\n" + f"Current score: {float(observation_attr(observation, 'score', 0.01) or 0.01):.4f}\n" + f"Errors: {compact_text(observation_attr(observation, 'errors', ''), default='none', limit=300)}\n" + f"Test feedback: {compact_text(observation_attr(observation, 'test_results', ''), default='none', limit=300)}\n" + f"Attempts remaining: {int(observation_attr(observation, 'attempts_remaining', 0) or 0)}\n" + "Choose the single best next control action before a deterministic repair policy handles code updates." + ) + + def _parse_action(self, content: str) -> AgentDecision: + try: + payload = json.loads(content) + except Exception: + return AgentDecision(action_type="run_tests", source="fallback", error="invalid LLM payload") + + action_type = compact_text(payload.get("action_type"), default="run_tests") + if action_type not in ALLOWED_ACTIONS or action_type == "edit_code": + action_type = "run_tests" + return AgentDecision(action_type=action_type, source="llm") diff --git a/app/streamlit_app.py b/app/streamlit_app.py index b95e22e22d4ad03b54d52208c8908cea6989e49c..59579549468833dafb20c4194e7002d4bfac4215 100644 --- a/app/streamlit_app.py +++ b/app/streamlit_app.py @@ -1,83 +1,52 @@ -"""Streamlit frontend for the AI-powered Python code review platform.""" +"""Streamlit frontend for the multi-domain analyzer platform.""" from __future__ import annotations import streamlit as st - -from app.examples import EXAMPLES -from schemas.request import AnalyzeCodeRequest -from services.analysis_service import AnalysisService - - + +from app.examples import EXAMPLES +from schemas.request import AnalyzeCodeRequest +from services.analysis_service import AnalysisService + + analysis_service = AnalysisService() - - + + def _analyze(code: str, context_window: str, traceback_text: str, domain_hint: str): """Run the analysis service with validated request payloads.""" - - request = AnalyzeCodeRequest( - code=code, - context_window=context_window, - traceback_text=traceback_text, - domain_hint=domain_hint, # type: ignore[arg-type] + + request = AnalyzeCodeRequest( + code=code, + context_window=context_window, + traceback_text=traceback_text, + domain_hint=domain_hint, # type: ignore[arg-type] ) return analysis_service.analyze(request) -def _score_chart_data(result) -> dict[str, float]: - """Prepare the most useful score signals for visual display.""" - - return { - "reward": result.score_breakdown.reward, - "ml_quality": result.score_breakdown.ml_score, - "lint": result.score_breakdown.lint_score, - "maintainability": result.score_breakdown.maintainability_score, - "readability": result.score_breakdown.readability_score, - "security": result.score_breakdown.security_score, - } - - def main() -> None: """Render the Streamlit UI.""" - st.set_page_config(page_title="TorchReview Copilot", layout="wide") - st.title("TorchReview Copilot") - st.caption( - "AI-powered Python code review with static analysis, PyTorch scoring, " - "RL-ready rewards, and actionable code-improvement guidance." - ) - - with st.sidebar: - st.subheader("Review Pipeline") - st.markdown( - "\n".join( - [ - "1. Input Python code", - "2. Parse AST + estimate complexity", - "3. Score with a PyTorch encoder", - "4. Generate suggestions and auto-fix hints", - "5. Compute an RL-ready reward", - ] - ) - ) - example_name = st.selectbox("Example input", list(EXAMPLES.keys())) - auto_analyze = st.toggle("Real-time scoring", value=True) - st.info("The PyTorch layer uses CodeBERTa embeddings when weights are available, with a torch-native fallback for offline demos.") + st.set_page_config(page_title="Multi-Domain AI Code Analyzer", layout="wide") + st.title("Multi-Domain AI Code Analyzer & Improvement System") + st.caption("PyTorch-powered code review across DSA, Data Science, ML/DL, and Web backend code.") + example_name = st.selectbox("Example input", list(EXAMPLES.keys())) example = EXAMPLES[example_name] + auto_analyze = st.toggle("Real-time scoring", value=True) left, right = st.columns([1.2, 1.0]) with left: code = st.text_area("Code input", value=example["code"], height=420) context_window = st.text_area("Context window", value=example["context_window"], height=100) - traceback_text = st.text_area("Optional traceback / runtime hint", value=example["traceback_text"], height=100) - domain_hint = st.selectbox("Domain hint", ["auto", "dsa", "data_science", "ml_dl", "web"], index=["auto", "dsa", "data_science", "ml_dl", "web"].index(example["domain_hint"])) - analyze_clicked = st.button("Analyze Code", type="primary") - - result = None - if code and (analyze_clicked or auto_analyze): - result = _analyze(code, context_window, traceback_text, domain_hint) - + traceback_text = st.text_area("Optional traceback / runtime hint", value=example["traceback_text"], height=100) + domain_hint = st.selectbox("Domain hint", ["auto", "dsa", "data_science", "ml_dl", "web"], index=["auto", "dsa", "data_science", "ml_dl", "web"].index(example["domain_hint"])) + analyze_clicked = st.button("Analyze Code", type="primary") + + result = None + if code and (analyze_clicked or auto_analyze): + result = _analyze(code, context_window, traceback_text, domain_hint) + with right: if result is None: st.info("Paste code or load an example to start analysis.") @@ -85,17 +54,9 @@ def main() -> None: metric_cols = st.columns(4) metric_cols[0].metric("Detected domain", result.detected_domain) metric_cols[1].metric("ML score", f"{result.score_breakdown.ml_score:.0%}") - metric_cols[2].metric("Lint score", f"{result.score_breakdown.lint_score:.0%}") + metric_cols[2].metric("Domain score", f"{result.score_breakdown.domain_score:.0%}") metric_cols[3].metric("Reward", f"{result.score_breakdown.reward:.0%}") - st.subheader("Domain Confidence") st.bar_chart(result.domain_confidences) - st.subheader("Review Signal Radar") - st.bar_chart(_score_chart_data(result)) - st.code( - "reward = 0.50*ml_score + 0.18*lint + 0.12*maintainability " - "+ 0.10*domain + 0.10*security - 0.20*complexity", - language="text", - ) st.caption(result.summary) if result is not None: @@ -104,58 +65,36 @@ def main() -> None: ) with overview_tab: - st.subheader("Reward Breakdown") - st.json(result.score_visualization) - st.subheader("Top Signals") - signal_cols = st.columns(3) - signal_cols[0].progress(result.score_breakdown.quality_signal, text="Quality signal") - signal_cols[1].progress(result.score_breakdown.error_reduction_signal, text="Error reduction") - signal_cols[2].progress(result.score_breakdown.completion_signal, text="Completion") st.subheader("Improvement Plan") for step in result.improvement_plan: st.write(f"- {step}") - if result.auto_fix_preview: - st.subheader("Auto-Fix Preview") - for hint in result.auto_fix_preview: - st.write(f"- {hint}") st.subheader("Complexity") st.write( { "time_complexity": result.static_analysis.time_complexity, "space_complexity": result.static_analysis.space_complexity, "cyclomatic_complexity": result.static_analysis.cyclomatic_complexity, - "max_nesting_depth": result.static_analysis.max_nesting_depth, } ) with suggestions_tab: st.subheader("Suggestions") - for suggestion in result.suggestions: - st.write(f"- [{suggestion.priority}] {suggestion.title}: {suggestion.action}") - if result.domain_analysis.suggestions: - st.subheader("Domain Hints") - for item in result.domain_analysis.suggestions: - st.write(f"- {item}") - if result.domain_analysis.issues or result.static_analysis.issues: + for suggestion in result.domain_analysis.suggestions: + st.write(f"- {suggestion}") + if result.domain_analysis.issues: st.subheader("Issues") - for issue in result.domain_analysis.issues + result.static_analysis.issues: + for issue in result.domain_analysis.issues: st.write(f"- [{issue.severity}] {issue.title}: {issue.description}") with domain_tab: st.subheader("Domain Highlights") st.json(result.domain_analysis.highlights) st.write(f"Domain score: {result.domain_analysis.domain_score:.0%}") - st.write(f"Model label: {result.model_prediction.quality_label}") - st.write(f"Model backend: `{result.model_backend}`") - if result.model_prediction.notes: - st.subheader("Model Notes") - for note in result.model_prediction.notes: - st.write(f"- {note}") with static_tab: st.subheader("Static Analysis") st.json(result.static_analysis.model_dump()) - - -if __name__ == "__main__": - main() + + +if __name__ == "__main__": + main() diff --git a/app/utils/__init__.py b/app/utils/__init__.py index d96f8c5f3e2145b34e24ef2c705fc9e5c60f5c7c..90078947c16b4f82a1ff0b83c78ac4b8e9001a28 100644 --- a/app/utils/__init__.py +++ b/app/utils/__init__.py @@ -1,21 +1,21 @@ -"""Utility helpers shared by the inference runtime.""" - -from .runtime import ( - compact_text, - format_bool, - format_error, - format_reward, - observation_attr, - parse_task_ids, - suppress_output, -) - -__all__ = [ - "compact_text", - "format_bool", - "format_error", - "format_reward", - "observation_attr", - "parse_task_ids", - "suppress_output", -] +"""Utility helpers shared by the inference runtime.""" + +from .runtime import ( + compact_text, + format_bool, + format_error, + format_reward, + observation_attr, + parse_task_ids, + suppress_output, +) + +__all__ = [ + "compact_text", + "format_bool", + "format_error", + "format_reward", + "observation_attr", + "parse_task_ids", + "suppress_output", +] diff --git a/app/utils/runtime.py b/app/utils/runtime.py index cd061f1741f2d37aa0901a33f0b6cff8ea36f257..88d4da364e11a518adf6fa8c0c46ed4897de5012 100644 --- a/app/utils/runtime.py +++ b/app/utils/runtime.py @@ -1,106 +1,95 @@ """Formatting, parsing, and IO-suppression helpers for inference.""" - -from __future__ import annotations - -import io -from collections.abc import Iterable -from contextlib import contextmanager, redirect_stderr, redirect_stdout -from typing import Any, Iterator - + +from __future__ import annotations + +import io +from collections.abc import Iterable +from contextlib import contextmanager, redirect_stderr, redirect_stdout +from typing import Any, Iterator + try: from tasks import task_ids except ImportError: # pragma: no cover from python_env.tasks import task_ids # type: ignore[no-redef] -MIN_DISPLAY_REWARD = 0.01 -MAX_DISPLAY_REWARD = 0.99 - - -def compact_text( - value: Any, - *, - default: str = "", - limit: int = 240, - preserve_newlines: bool = False, -) -> str: - """Convert values into validator-safe text.""" - - if value is None: - return default - try: - text = str(value) - except Exception: - return default - if preserve_newlines: - text = text.strip() - else: - text = " ".join(text.split()) - return text[:limit] if text else default - - -def observation_attr(observation: Any, name: str, default: Any = None, *, preserve_newlines: bool = False) -> Any: - """Read an observation attribute without trusting the payload shape.""" - - if isinstance(observation, dict): - value = observation.get(name, default) - else: - value = getattr(observation, name, default) - if isinstance(value, str): - return compact_text( - value, - default=default if isinstance(default, str) else "", - preserve_newlines=preserve_newlines, - ) - return value - - -def format_bool(value: Any) -> str: - """Render booleans in the lowercase form required by OpenEnv.""" +def compact_text( + value: Any, + *, + default: str = "", + limit: int = 240, + preserve_newlines: bool = False, +) -> str: + """Convert values into validator-safe text.""" + + if value is None: + return default + try: + text = str(value) + except Exception: + return default + if preserve_newlines: + text = text.strip() + else: + text = " ".join(text.split()) + return text[:limit] if text else default + + +def observation_attr(observation: Any, name: str, default: Any = None, *, preserve_newlines: bool = False) -> Any: + """Read an observation attribute without trusting the payload shape.""" + + if isinstance(observation, dict): + value = observation.get(name, default) + else: + value = getattr(observation, name, default) + if isinstance(value, str): + return compact_text( + value, + default=default if isinstance(default, str) else "", + preserve_newlines=preserve_newlines, + ) + return value + +def format_bool(value: Any) -> str: return "true" if bool(value) else "false" def format_reward(value: Any) -> str: - """Render rewards in a validator-safe two-decimal open interval.""" - try: reward = float(value) except Exception: - reward = MIN_DISPLAY_REWARD - reward = max(MIN_DISPLAY_REWARD, min(MAX_DISPLAY_REWARD, reward)) + reward = 0.0 return f"{reward:.2f}" def format_error(value: Any) -> str: - """Render nullable error strings in the stdout contract format.""" - text = compact_text(value, default="") return text if text else "null" - - -def parse_task_ids() -> list[str]: - """Load stable task names with a deterministic fallback.""" - - try: - values = task_ids() - if isinstance(values, Iterable): - loaded = [compact_text(item, default="") for item in values] - loaded = [item for item in loaded if item] - if loaded: - return loaded - except Exception: - pass - return [ - "syntax_fix_invoice_totals", - "bug_fix_session_windows", - "optimization_rank_active_users", - ] - - -@contextmanager -def suppress_output() -> Iterator[None]: - """Silence libraries that write noisy logs to stdout or stderr.""" - - with redirect_stdout(io.StringIO()), redirect_stderr(io.StringIO()): - yield + + +def parse_task_ids() -> list[str]: + """Load stable task names with a deterministic fallback.""" + + try: + values = task_ids() + if isinstance(values, Iterable): + loaded = [compact_text(item, default="") for item in values] + loaded = [item for item in loaded if item] + if loaded: + return loaded + except Exception: + pass + return [ + "syntax_fix_invoice_totals", + "bug_fix_session_windows", + "optimization_rank_active_users", + ] + + +@contextmanager +def suppress_output() -> Iterator[None]: + """Silence libraries that write noisy logs to stdout or stderr.""" + + with redirect_stdout(io.StringIO()), redirect_stderr(io.StringIO()): + yield diff --git a/client.py b/client.py index 0ef5b4337c3e03a9d4511e6466a52d6ad9b62878..0df35a7f5dfeea5508ab6ada6090b53dc302b486 100644 --- a/client.py +++ b/client.py @@ -2,23 +2,16 @@ from __future__ import annotations -from typing import Dict - -from openenv.core import EnvClient -from openenv.core.client_types import StepResult - -try: - from .models import ( - PythonCodeReviewAction, - PythonCodeReviewObservation, - PythonCodeReviewState, - ) -except ImportError: # pragma: no cover - from models import ( - PythonCodeReviewAction, - PythonCodeReviewObservation, - PythonCodeReviewState, - ) +from typing import Dict + +from openenv.core import EnvClient +from openenv.core.client_types import StepResult + +from .models import ( + PythonCodeReviewAction, + PythonCodeReviewObservation, + PythonCodeReviewState, +) class PythonCodeReviewEnv( diff --git a/graders/bug_fix.py b/graders/bug_fix.py index b8cba44cb589238ead7d7dc20a2f4808d41ebee1..21e2c16691427372b067e0f343af5bcfd5542246 100644 --- a/graders/bug_fix.py +++ b/graders/bug_fix.py @@ -3,127 +3,127 @@ from __future__ import annotations try: - from ..models import TaskGrade + from ..models import TaskGrade from ..tasks.catalog import ReviewTask except ImportError: - from models import TaskGrade + from models import TaskGrade from tasks.catalog import ReviewTask -from .shared import ( - base_grade, - compile_code, - composite_grade_score, - component_score, - execute_cases, - quality_metrics, - similarity_score, - summarize_results, -) +from .shared import ( + base_grade, + compile_code, + composite_grade_score, + component_score, + execute_cases, + quality_metrics, + similarity_score, + summarize_results, +) -def grade_bug_fix_task( +def grade_bug_fix_task( task: ReviewTask, code: str, *, include_hidden: bool, timeout_s: float = 2.0, ) -> TaskGrade: - """Grade a bug-fix task against public or full test suites.""" - - compiled, compile_error = compile_code(code) - quality = quality_metrics(code, task.function_name) - similarity = similarity_score(code, task.reference_code) - details = { - "compile_error": compile_error, - "quality_notes": quality["quality_notes"], - "style_score": quality["style_score"], - "visibility": "full" if include_hidden else "public", + """Grade a bug-fix task against public or full test suites.""" + + compiled, compile_error = compile_code(code) + quality = quality_metrics(code, task.function_name) + similarity = similarity_score(code, task.reference_code) + details = { + "compile_error": compile_error, + "quality_notes": quality["quality_notes"], + "style_score": quality["style_score"], + "visibility": "full" if include_hidden else "public", } - if not compiled: - details["test_results"] = [] - details["test_summary"] = "Code does not compile." - return base_grade( - score=composite_grade_score( - correctness=0.0, - quality=0.05, - runtime=0.05, - syntax=0.0, - similarity=similarity, - baseline=0.04, - penalty=0.05, - ), - syntax_score=component_score(0.01), - tests_passed=0, - tests_total=len(task.public_cases) + (len(task.hidden_cases) if include_hidden else 0), - quality_score=component_score(0.01), - runtime_score=component_score(0.01), + if not compiled: + details["test_results"] = [] + details["test_summary"] = "Code does not compile." + return base_grade( + score=composite_grade_score( + correctness=0.0, + quality=0.05, + runtime=0.05, + syntax=0.0, + similarity=similarity, + baseline=0.04, + penalty=0.05, + ), + syntax_score=component_score(0.01), + tests_passed=0, + tests_total=len(task.public_cases) + (len(task.hidden_cases) if include_hidden else 0), + quality_score=component_score(0.01), + runtime_score=component_score(0.01), timed_out=False, details=details, ) cases = task.public_cases + (task.hidden_cases if include_hidden else []) - result = execute_cases(code, task.function_name, cases, timeout_s=timeout_s) - if result.get("timed_out"): - details["test_results"] = [] - details["test_summary"] = result["error"] - return base_grade( - score=composite_grade_score( - correctness=0.10, - quality=quality["score"], - runtime=0.0, - syntax=0.95, - similarity=similarity, - baseline=0.06, - penalty=0.12, - ), - syntax_score=component_score(0.95), - tests_passed=0, - tests_total=len(cases), - quality_score=quality["score"], - runtime_score=component_score(0.01), + result = execute_cases(code, task.function_name, cases, timeout_s=timeout_s) + if result.get("timed_out"): + details["test_results"] = [] + details["test_summary"] = result["error"] + return base_grade( + score=composite_grade_score( + correctness=0.10, + quality=quality["score"], + runtime=0.0, + syntax=0.95, + similarity=similarity, + baseline=0.06, + penalty=0.12, + ), + syntax_score=component_score(0.95), + tests_passed=0, + tests_total=len(cases), + quality_score=quality["score"], + runtime_score=component_score(0.01), timed_out=True, details=details, ) - if "error" in result: - details["test_results"] = [] - details["test_summary"] = result["error"] - return base_grade( - score=composite_grade_score( - correctness=0.12, - quality=quality["score"], - runtime=0.0, - syntax=0.95, - similarity=similarity, - baseline=0.06, - penalty=0.08, - ), - syntax_score=component_score(0.95), - tests_passed=0, - tests_total=len(cases), - quality_score=quality["score"], - runtime_score=component_score(0.01), + if "error" in result: + details["test_results"] = [] + details["test_summary"] = result["error"] + return base_grade( + score=composite_grade_score( + correctness=0.12, + quality=quality["score"], + runtime=0.0, + syntax=0.95, + similarity=similarity, + baseline=0.06, + penalty=0.08, + ), + syntax_score=component_score(0.95), + tests_passed=0, + tests_total=len(cases), + quality_score=quality["score"], + runtime_score=component_score(0.01), timed_out=False, details=details, ) - data = result["data"] - pass_rate = data["passed"] / max(data["total"], 1) - details["test_results"] = data["results"] - details["test_summary"] = summarize_results("Test results", data["results"]) - return base_grade( - score=composite_grade_score( - correctness=pass_rate, - quality=quality["score"], - runtime=0.05, - syntax=0.95, - similarity=similarity, - baseline=0.08, - ), - syntax_score=component_score(0.95), - tests_passed=data["passed"], - tests_total=data["total"], - quality_score=quality["score"], + data = result["data"] + pass_rate = data["passed"] / max(data["total"], 1) + details["test_results"] = data["results"] + details["test_summary"] = summarize_results("Test results", data["results"]) + return base_grade( + score=composite_grade_score( + correctness=pass_rate, + quality=quality["score"], + runtime=0.05, + syntax=0.95, + similarity=similarity, + baseline=0.08, + ), + syntax_score=component_score(0.95), + tests_passed=data["passed"], + tests_total=data["total"], + quality_score=quality["score"], runtime_score=component_score(0.01), timed_out=False, details=details, diff --git a/graders/dispatch.py b/graders/dispatch.py index 6b4deb21bfafce14bc133439a8c2a61ad9ba3e0e..43a02bef5b903cd94a570d6a5c56b6e301dcf544 100644 --- a/graders/dispatch.py +++ b/graders/dispatch.py @@ -3,10 +3,10 @@ from __future__ import annotations try: - from ..models import TaskGrade + from ..models import TaskGrade from ..tasks.catalog import ReviewTask except ImportError: - from models import TaskGrade + from models import TaskGrade from tasks.catalog import ReviewTask from .bug_fix import grade_bug_fix_task diff --git a/graders/optimization.py b/graders/optimization.py index 59ecae6aba0f376367770a1034af92359a238a11..7d261fb19275ce5ce46fff00e4a5ac542f706560 100644 --- a/graders/optimization.py +++ b/graders/optimization.py @@ -3,23 +3,23 @@ from __future__ import annotations try: - from ..models import TaskGrade + from ..models import TaskGrade from ..tasks.catalog import ReviewTask except ImportError: - from models import TaskGrade + from models import TaskGrade from tasks.catalog import ReviewTask -from .shared import ( - base_grade, - benchmark_candidate, - compile_code, - composite_grade_score, - component_score, - execute_cases, - quality_metrics, - similarity_score, - summarize_results, -) +from .shared import ( + base_grade, + benchmark_candidate, + compile_code, + composite_grade_score, + component_score, + execute_cases, + quality_metrics, + similarity_score, + summarize_results, +) def grade_optimization_task( @@ -29,81 +29,81 @@ def grade_optimization_task( include_hidden: bool, timeout_s: float = 3.0, ) -> TaskGrade: - """Grade an optimization/refactor task with correctness, quality, and runtime.""" - - compiled, compile_error = compile_code(code) - quality = quality_metrics(code, task.function_name) - similarity = similarity_score(code, task.reference_code) - details = { - "compile_error": compile_error, - "quality_notes": quality["quality_notes"], - "style_score": quality["style_score"], - "visibility": "full" if include_hidden else "public", + """Grade an optimization/refactor task with correctness, quality, and runtime.""" + + compiled, compile_error = compile_code(code) + quality = quality_metrics(code, task.function_name) + similarity = similarity_score(code, task.reference_code) + details = { + "compile_error": compile_error, + "quality_notes": quality["quality_notes"], + "style_score": quality["style_score"], + "visibility": "full" if include_hidden else "public", } - if not compiled: - details["test_results"] = [] - details["test_summary"] = "Code does not compile." - return base_grade( - score=composite_grade_score( - correctness=0.0, - quality=0.05, - runtime=0.0, - syntax=0.0, - similarity=similarity, - baseline=0.04, - penalty=0.06, - ), - syntax_score=component_score(0.01), - tests_passed=0, - tests_total=len(task.public_cases) + (len(task.hidden_cases) if include_hidden else 0), - quality_score=component_score(0.01), - runtime_score=component_score(0.01), + if not compiled: + details["test_results"] = [] + details["test_summary"] = "Code does not compile." + return base_grade( + score=composite_grade_score( + correctness=0.0, + quality=0.05, + runtime=0.0, + syntax=0.0, + similarity=similarity, + baseline=0.04, + penalty=0.06, + ), + syntax_score=component_score(0.01), + tests_passed=0, + tests_total=len(task.public_cases) + (len(task.hidden_cases) if include_hidden else 0), + quality_score=component_score(0.01), + runtime_score=component_score(0.01), timed_out=False, details=details, ) cases = task.public_cases + (task.hidden_cases if include_hidden else []) - result = execute_cases(code, task.function_name, cases, timeout_s=timeout_s) - if result.get("timed_out"): - details["test_results"] = [] - details["test_summary"] = result["error"] - return base_grade( - score=composite_grade_score( - correctness=0.08, - quality=quality["score"], - runtime=0.0, - syntax=0.95, - similarity=similarity, - baseline=0.05, - penalty=0.14, - ), - syntax_score=component_score(0.95), - tests_passed=0, - tests_total=len(cases), - quality_score=quality["score"], - runtime_score=component_score(0.01), + result = execute_cases(code, task.function_name, cases, timeout_s=timeout_s) + if result.get("timed_out"): + details["test_results"] = [] + details["test_summary"] = result["error"] + return base_grade( + score=composite_grade_score( + correctness=0.08, + quality=quality["score"], + runtime=0.0, + syntax=0.95, + similarity=similarity, + baseline=0.05, + penalty=0.14, + ), + syntax_score=component_score(0.95), + tests_passed=0, + tests_total=len(cases), + quality_score=quality["score"], + runtime_score=component_score(0.01), timed_out=True, details=details, ) - if "error" in result: - details["test_results"] = [] - details["test_summary"] = result["error"] - return base_grade( - score=composite_grade_score( - correctness=0.10, - quality=quality["score"], - runtime=0.0, - syntax=0.95, - similarity=similarity, - baseline=0.05, - penalty=0.08, - ), - syntax_score=component_score(0.95), - tests_passed=0, - tests_total=len(cases), - quality_score=quality["score"], - runtime_score=component_score(0.01), + if "error" in result: + details["test_results"] = [] + details["test_summary"] = result["error"] + return base_grade( + score=composite_grade_score( + correctness=0.10, + quality=quality["score"], + runtime=0.0, + syntax=0.95, + similarity=similarity, + baseline=0.05, + penalty=0.08, + ), + syntax_score=component_score(0.95), + tests_passed=0, + tests_total=len(cases), + quality_score=quality["score"], + runtime_score=component_score(0.01), timed_out=False, details=details, ) @@ -122,25 +122,25 @@ def grade_optimization_task( if timed_out: runtime_score = component_score(0.01) - details["test_results"] = data["results"] - details["test_summary"] = summarize_results("Test results", data["results"]) - details["benchmark"] = benchmark_summary - - runtime_progress = 0.0 if benchmark_summary == "Benchmark deferred until hidden evaluation." else runtime_score - return base_grade( - score=composite_grade_score( - correctness=pass_rate, - quality=quality["score"], - runtime=runtime_progress if include_hidden else 0.10, - syntax=0.95, - similarity=similarity, - baseline=0.08 if include_hidden else 0.07, - penalty=0.10 if timed_out else 0.0, - ), - syntax_score=component_score(0.95), - tests_passed=data["passed"], - tests_total=data["total"], - quality_score=quality["score"], + details["test_results"] = data["results"] + details["test_summary"] = summarize_results("Test results", data["results"]) + details["benchmark"] = benchmark_summary + + runtime_progress = 0.0 if benchmark_summary == "Benchmark deferred until hidden evaluation." else runtime_score + return base_grade( + score=composite_grade_score( + correctness=pass_rate, + quality=quality["score"], + runtime=runtime_progress if include_hidden else 0.10, + syntax=0.95, + similarity=similarity, + baseline=0.08 if include_hidden else 0.07, + penalty=0.10 if timed_out else 0.0, + ), + syntax_score=component_score(0.95), + tests_passed=data["passed"], + tests_total=data["total"], + quality_score=quality["score"], runtime_score=runtime_score, timed_out=timed_out, details=details, diff --git a/graders/shared.py b/graders/shared.py index 4334ad3a96ed74a1b1a6bde641eb05c89e3a9c05..b90363ba499b3bf7d4b74f44c971d3fddf3f469e 100644 --- a/graders/shared.py +++ b/graders/shared.py @@ -2,28 +2,28 @@ from __future__ import annotations -import ast -import difflib -import math -import multiprocessing as mp -import os -import time -import traceback +import ast +import difflib +import math +import multiprocessing as mp +import os +import time +import traceback from typing import Any, Callable, Dict, List try: - from ..models import TaskGrade + from ..models import TaskGrade from ..tasks.catalog import CallCase, ReviewTask except ImportError: - from models import TaskGrade + from models import TaskGrade from tasks.catalog import CallCase, ReviewTask -STRICT_SCORE_MIN = 0.01 -STRICT_SCORE_MAX = 0.99 -POOR_SCORE = 0.1 -NEAR_PERFECT_SCORE = 0.95 -EPS = 1e-6 +STRICT_SCORE_MIN = 0.01 +STRICT_SCORE_MAX = 0.99 +POOR_SCORE = 0.1 +NEAR_PERFECT_SCORE = 0.95 +EPS = 1e-6 def finite_float(value: Any, fallback: float = STRICT_SCORE_MIN) -> float: @@ -38,54 +38,54 @@ def finite_float(value: Any, fallback: float = STRICT_SCORE_MIN) -> float: return numeric -def clamp(value: float, lower: float = 0.0, upper: float = 1.0) -> float: - """Clamp a floating-point value to a closed interval.""" - - numeric = finite_float(value, fallback=lower) - return max(lower, min(upper, numeric)) - - -def safe_score(score: Any) -> float: - """Clamp any score to the strict OpenEnv-safe open interval (0, 1).""" - - bounded = max(EPS, min(1.0 - EPS, finite_float(score, fallback=EPS))) - assert 0 < bounded < 1, f"Score must be strictly between 0 and 1: {bounded}" - return bounded - - -def normalize_score(x: Any) -> float: - """Sigmoid-normalize a raw score and clamp it safely into (0, 1).""" - - numeric = finite_float(x, fallback=0.0) - bounded = max(-20.0, min(20.0, numeric)) - return safe_score(1.0 / (1.0 + math.exp(-bounded))) - - -def final_score_pipeline(raw_score: Any) -> float: - """Normalize arbitrary raw scoring signals into a strict OpenEnv-safe score.""" - - return normalize_score(raw_score) - - -def strict_score(value: Any, lower: float = STRICT_SCORE_MIN, upper: float = STRICT_SCORE_MAX) -> float: - """Clamp a score to the OpenEnv-safe open interval (0, 1).""" - - score = max(lower, min(upper, finite_float(value, fallback=lower))) - score = safe_score(score) - assert 0 < score < 1, f"Invalid score: {score}" - return score - - -def shaped_score(progress: Any, floor: float = POOR_SCORE, ceiling: float = NEAR_PERFECT_SCORE) -> float: - """Map progress in [0, 1] to a smooth score band within (0, 1).""" - - bounded_progress = clamp(finite_float(progress, fallback=0.0)) - centered_progress = (bounded_progress - 0.5) * 6.0 - smoothed_progress = final_score_pipeline(centered_progress) - score = floor + (ceiling - floor) * smoothed_progress - score = safe_score(score) - assert 0 < score < 1, f"Invalid score: {score}" - return score +def clamp(value: float, lower: float = 0.0, upper: float = 1.0) -> float: + """Clamp a floating-point value to a closed interval.""" + + numeric = finite_float(value, fallback=lower) + return max(lower, min(upper, numeric)) + + +def safe_score(score: Any) -> float: + """Clamp any score to the strict OpenEnv-safe open interval (0, 1).""" + + bounded = max(EPS, min(1.0 - EPS, finite_float(score, fallback=EPS))) + assert 0 < bounded < 1, f"Score must be strictly between 0 and 1: {bounded}" + return bounded + + +def normalize_score(x: Any) -> float: + """Sigmoid-normalize a raw score and clamp it safely into (0, 1).""" + + numeric = finite_float(x, fallback=0.0) + bounded = max(-20.0, min(20.0, numeric)) + return safe_score(1.0 / (1.0 + math.exp(-bounded))) + + +def final_score_pipeline(raw_score: Any) -> float: + """Normalize arbitrary raw scoring signals into a strict OpenEnv-safe score.""" + + return normalize_score(raw_score) + + +def strict_score(value: Any, lower: float = STRICT_SCORE_MIN, upper: float = STRICT_SCORE_MAX) -> float: + """Clamp a score to the OpenEnv-safe open interval (0, 1).""" + + score = max(lower, min(upper, finite_float(value, fallback=lower))) + score = safe_score(score) + assert 0 < score < 1, f"Invalid score: {score}" + return score + + +def shaped_score(progress: Any, floor: float = POOR_SCORE, ceiling: float = NEAR_PERFECT_SCORE) -> float: + """Map progress in [0, 1] to a smooth score band within (0, 1).""" + + bounded_progress = clamp(finite_float(progress, fallback=0.0)) + centered_progress = (bounded_progress - 0.5) * 6.0 + smoothed_progress = final_score_pipeline(centered_progress) + score = floor + (ceiling - floor) * smoothed_progress + score = safe_score(score) + assert 0 < score < 1, f"Invalid score: {score}" + return score def score_from_checks(passed: int, total: int, floor: float = POOR_SCORE, ceiling: float = NEAR_PERFECT_SCORE) -> float: @@ -104,59 +104,59 @@ def safe_ratio(numerator: Any, denominator: Any) -> float: return clamp(numer / denom) -def component_score(value: Any) -> float: - """Normalize component scores such as syntax, quality, and runtime.""" - - bounded_value = clamp(finite_float(value, fallback=0.0)) - return shaped_score(bounded_value, floor=0.02, ceiling=0.98) - - -def composite_progress( - *, - correctness: Any = 0.0, - quality: Any = 0.0, - runtime: Any = 0.0, - syntax: Any = 0.0, - similarity: Any = 0.0, - baseline: float = 0.05, - penalty: Any = 0.0, -) -> float: - """Blend multiple progress signals into a stable scalar progress estimate.""" - - progress = ( - finite_float(baseline, fallback=0.05) - + 0.45 * clamp(correctness) - + 0.20 * clamp(quality) - + 0.15 * clamp(runtime) - + 0.15 * clamp(syntax) - + 0.05 * clamp(similarity) - - 0.20 * clamp(penalty) - ) - return clamp(progress) - - -def composite_grade_score( - *, - correctness: Any = 0.0, - quality: Any = 0.0, - runtime: Any = 0.0, - syntax: Any = 0.0, - similarity: Any = 0.0, - baseline: float = 0.05, - penalty: Any = 0.0, -) -> float: - """Create a smooth task score from multiple bounded signals.""" - - progress = composite_progress( - correctness=correctness, - quality=quality, - runtime=runtime, - syntax=syntax, - similarity=similarity, - baseline=baseline, - penalty=penalty, - ) - return shaped_score(progress) +def component_score(value: Any) -> float: + """Normalize component scores such as syntax, quality, and runtime.""" + + bounded_value = clamp(finite_float(value, fallback=0.0)) + return shaped_score(bounded_value, floor=0.02, ceiling=0.98) + + +def composite_progress( + *, + correctness: Any = 0.0, + quality: Any = 0.0, + runtime: Any = 0.0, + syntax: Any = 0.0, + similarity: Any = 0.0, + baseline: float = 0.05, + penalty: Any = 0.0, +) -> float: + """Blend multiple progress signals into a stable scalar progress estimate.""" + + progress = ( + finite_float(baseline, fallback=0.05) + + 0.45 * clamp(correctness) + + 0.20 * clamp(quality) + + 0.15 * clamp(runtime) + + 0.15 * clamp(syntax) + + 0.05 * clamp(similarity) + - 0.20 * clamp(penalty) + ) + return clamp(progress) + + +def composite_grade_score( + *, + correctness: Any = 0.0, + quality: Any = 0.0, + runtime: Any = 0.0, + syntax: Any = 0.0, + similarity: Any = 0.0, + baseline: float = 0.05, + penalty: Any = 0.0, +) -> float: + """Create a smooth task score from multiple bounded signals.""" + + progress = composite_progress( + correctness=correctness, + quality=quality, + runtime=runtime, + syntax=syntax, + similarity=similarity, + baseline=baseline, + penalty=penalty, + ) + return shaped_score(progress) def compile_code(code: str) -> tuple[bool, str]: @@ -199,26 +199,18 @@ def run_with_timeout( payload: Dict[str, Any], timeout_s: float, ) -> Dict[str, Any]: - """Execute a worker in a subprocess and terminate on timeout. - - Some constrained Windows environments disallow spawned pipes or child - processes. In those cases, fall back to the inline timeout path so local - demos and tests still work deterministically. - """ - - try: - ctx = mp.get_context("spawn") - queue = ctx.Queue() - process = ctx.Process(target=_queue_worker, args=(worker, payload, queue)) - process.start() - process.join(timeout_s) - except (PermissionError, OSError): - return run_inline_with_timeout(worker, payload, timeout_s) - - if process.is_alive(): - process.terminate() - process.join() - return {"timed_out": True, "error": f"Execution exceeded {timeout_s:.1f}s timeout."} + """Execute a worker in a subprocess and terminate on timeout.""" + + ctx = mp.get_context("spawn") + queue = ctx.Queue() + process = ctx.Process(target=_queue_worker, args=(worker, payload, queue)) + process.start() + process.join(timeout_s) + + if process.is_alive(): + process.terminate() + process.join() + return {"timed_out": True, "error": f"Execution exceeded {timeout_s:.1f}s timeout."} if queue.empty(): return {"timed_out": False, "error": "Worker exited without returning a result."} @@ -227,31 +219,31 @@ def run_with_timeout( if not message["ok"]: return { "timed_out": False, - "error": f"{message['error']}\n{message['traceback']}", - } - return {"timed_out": False, "data": message["data"]} - - -def run_inline_with_timeout( - worker: Callable[[Dict[str, Any]], Dict[str, Any]], - payload: Dict[str, Any], - timeout_s: float, -) -> Dict[str, Any]: - """Fallback execution path for platforms where spawned workers are unreliable.""" - - started = time.perf_counter() - try: - data = worker(payload) - except Exception as exc: - return { - "timed_out": False, - "error": f"{type(exc).__name__}: {exc}\n{traceback.format_exc(limit=5)}", - } - - elapsed = time.perf_counter() - started - if elapsed > timeout_s: - return {"timed_out": True, "error": f"Execution exceeded {timeout_s:.1f}s timeout."} - return {"timed_out": False, "data": data} + "error": f"{message['error']}\n{message['traceback']}", + } + return {"timed_out": False, "data": message["data"]} + + +def run_inline_with_timeout( + worker: Callable[[Dict[str, Any]], Dict[str, Any]], + payload: Dict[str, Any], + timeout_s: float, +) -> Dict[str, Any]: + """Fallback execution path for platforms where spawned workers are unreliable.""" + + started = time.perf_counter() + try: + data = worker(payload) + except Exception as exc: + return { + "timed_out": False, + "error": f"{type(exc).__name__}: {exc}\n{traceback.format_exc(limit=5)}", + } + + elapsed = time.perf_counter() - started + if elapsed > timeout_s: + return {"timed_out": True, "error": f"Execution exceeded {timeout_s:.1f}s timeout."} + return {"timed_out": False, "data": data} def _execute_cases_worker(payload: Dict[str, Any]) -> Dict[str, Any]: @@ -456,7 +448,7 @@ def _benchmark_worker(payload: Dict[str, Any]) -> Dict[str, Any]: return {"baseline_seconds": baseline_seconds, "candidate_seconds": candidate_seconds} -def benchmark_candidate(task: ReviewTask, code: str, timeout_s: float) -> Dict[str, Any]: +def benchmark_candidate(task: ReviewTask, code: str, timeout_s: float) -> Dict[str, Any]: """Benchmark a candidate solution against the starter implementation.""" if not task.benchmark_config: @@ -470,10 +462,10 @@ def benchmark_candidate(task: ReviewTask, code: str, timeout_s: float) -> Dict[s "events": events, "iterations": task.benchmark_config.get("iterations", 5), } - if os.name == "nt": - result = run_inline_with_timeout(_benchmark_worker, payload, timeout_s=timeout_s) - else: - result = run_with_timeout(_benchmark_worker, payload, timeout_s=timeout_s) + if os.name == "nt": + result = run_inline_with_timeout(_benchmark_worker, payload, timeout_s=timeout_s) + else: + result = run_with_timeout(_benchmark_worker, payload, timeout_s=timeout_s) if result.get("timed_out"): return {"runtime_score": component_score(STRICT_SCORE_MIN), "timed_out": True, "details": result["error"]} if "error" in result: diff --git a/graders/syntax.py b/graders/syntax.py index 3b31c119d10e74acb5b7645370de085c458b1e13..c11111f192eef824e83ab06021cf578f4fa544fc 100644 --- a/graders/syntax.py +++ b/graders/syntax.py @@ -3,120 +3,120 @@ from __future__ import annotations try: - from ..models import TaskGrade + from ..models import TaskGrade from ..tasks.catalog import ReviewTask except ImportError: - from models import TaskGrade + from models import TaskGrade from tasks.catalog import ReviewTask -from .shared import ( - base_grade, - compile_code, - composite_grade_score, - component_score, - execute_cases, - quality_metrics, - similarity_score, - summarize_results, -) +from .shared import ( + base_grade, + compile_code, + composite_grade_score, + component_score, + execute_cases, + quality_metrics, + similarity_score, + summarize_results, +) -def grade_syntax_task(task: ReviewTask, code: str, timeout_s: float = 2.0) -> TaskGrade: - """Grade a syntax-fix task deterministically.""" +def grade_syntax_task(task: ReviewTask, code: str, timeout_s: float = 2.0) -> TaskGrade: + """Grade a syntax-fix task deterministically.""" + + compiled, compile_error = compile_code(code) + quality = quality_metrics(code, task.function_name) + similarity = similarity_score(code, task.reference_code) + details = { + "compile_error": compile_error, + "quality_notes": quality["quality_notes"], + "style_score": quality["style_score"], + } - compiled, compile_error = compile_code(code) - quality = quality_metrics(code, task.function_name) - similarity = similarity_score(code, task.reference_code) - details = { - "compile_error": compile_error, - "quality_notes": quality["quality_notes"], - "style_score": quality["style_score"], - } - - if not compiled: - details["test_results"] = [] - details["test_summary"] = "Code does not compile yet." - return base_grade( - score=composite_grade_score( - correctness=0.0, - quality=0.05, - runtime=0.05, - syntax=0.0, - similarity=similarity, - baseline=0.05, - penalty=0.05, - ), - syntax_score=component_score(0.01), - tests_passed=0, - tests_total=len(task.public_cases) + len(task.hidden_cases), - quality_score=component_score(0.01), - runtime_score=component_score(0.01), + if not compiled: + details["test_results"] = [] + details["test_summary"] = "Code does not compile yet." + return base_grade( + score=composite_grade_score( + correctness=0.0, + quality=0.05, + runtime=0.05, + syntax=0.0, + similarity=similarity, + baseline=0.05, + penalty=0.05, + ), + syntax_score=component_score(0.01), + tests_passed=0, + tests_total=len(task.public_cases) + len(task.hidden_cases), + quality_score=component_score(0.01), + runtime_score=component_score(0.01), timed_out=False, details=details, ) cases = task.public_cases + task.hidden_cases - result = execute_cases(code, task.function_name, cases, timeout_s=timeout_s) - if result.get("timed_out"): - details["test_results"] = [] - details["test_summary"] = result["error"] - return base_grade( - score=composite_grade_score( - correctness=0.15, - quality=quality["score"], - runtime=0.0, - syntax=0.95, - similarity=similarity, - baseline=0.08, - penalty=0.12, - ), - syntax_score=component_score(0.95), - tests_passed=0, - tests_total=len(cases), - quality_score=quality["score"], - runtime_score=component_score(0.01), + result = execute_cases(code, task.function_name, cases, timeout_s=timeout_s) + if result.get("timed_out"): + details["test_results"] = [] + details["test_summary"] = result["error"] + return base_grade( + score=composite_grade_score( + correctness=0.15, + quality=quality["score"], + runtime=0.0, + syntax=0.95, + similarity=similarity, + baseline=0.08, + penalty=0.12, + ), + syntax_score=component_score(0.95), + tests_passed=0, + tests_total=len(cases), + quality_score=quality["score"], + runtime_score=component_score(0.01), timed_out=True, details=details, ) - if "error" in result: - details["test_results"] = [] - details["test_summary"] = result["error"] - return base_grade( - score=composite_grade_score( - correctness=0.18, - quality=quality["score"], - runtime=0.0, - syntax=0.95, - similarity=similarity, - baseline=0.08, - penalty=0.08, - ), - syntax_score=component_score(0.95), - tests_passed=0, - tests_total=len(cases), - quality_score=quality["score"], - runtime_score=component_score(0.01), + if "error" in result: + details["test_results"] = [] + details["test_summary"] = result["error"] + return base_grade( + score=composite_grade_score( + correctness=0.18, + quality=quality["score"], + runtime=0.0, + syntax=0.95, + similarity=similarity, + baseline=0.08, + penalty=0.08, + ), + syntax_score=component_score(0.95), + tests_passed=0, + tests_total=len(cases), + quality_score=quality["score"], + runtime_score=component_score(0.01), timed_out=False, details=details, ) - data = result["data"] - details["test_results"] = data["results"] - details["test_summary"] = summarize_results("Validation checks", data["results"]) - pass_rate = data["passed"] / max(data["total"], 1) - return base_grade( - score=composite_grade_score( - correctness=pass_rate, - quality=quality["score"], - runtime=0.05, - syntax=0.95, - similarity=similarity, - baseline=0.10, - ), - syntax_score=component_score(0.95), - tests_passed=data["passed"], - tests_total=data["total"], - quality_score=quality["score"], + data = result["data"] + details["test_results"] = data["results"] + details["test_summary"] = summarize_results("Validation checks", data["results"]) + pass_rate = data["passed"] / max(data["total"], 1) + return base_grade( + score=composite_grade_score( + correctness=pass_rate, + quality=quality["score"], + runtime=0.05, + syntax=0.95, + similarity=similarity, + baseline=0.10, + ), + syntax_score=component_score(0.95), + tests_passed=data["passed"], + tests_total=data["total"], + quality_score=quality["score"], runtime_score=component_score(0.01), timed_out=False, details=details, diff --git a/inference.py b/inference.py index 9ede6c47a468c19322eda425403a76ac266b41ea..beada78d444cc14cf9c210a6132b24699430c198 100644 --- a/inference.py +++ b/inference.py @@ -1,12 +1,12 @@ -#!/usr/bin/env python3 -"""Root validator entrypoint.""" - -from __future__ import annotations - -import sys - -from app.env.runner import main - - -if __name__ == "__main__": - sys.exit(main()) +#!/usr/bin/env python3 +"""Root validator entrypoint.""" + +from __future__ import annotations + +import sys + +from app.env.runner import main + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/launch.py b/launch.py index 71d10c43e0b3a6a05a767902d2a022f7662bdeb1..c06c8d1cdf8c2a4a1dabf4cf54ca9534967d7212 100644 --- a/launch.py +++ b/launch.py @@ -1,35 +1,35 @@ -"""Launch the FastAPI backend and Streamlit UI in one Docker container.""" - -from __future__ import annotations - -import subprocess -import sys - - -def main() -> int: - """Start the API backend in the background and keep Streamlit in the foreground.""" - - api_process = subprocess.Popen( - ["uvicorn", "api.main:app", "--host", "0.0.0.0", "--port", "8001"], - ) - try: - return subprocess.call( - [ - "streamlit", - "run", - "app/streamlit_app.py", - "--server.port", - "8000", - "--server.address", - "0.0.0.0", - "--server.headless", - "true", - ] - ) - finally: - api_process.terminate() - api_process.wait(timeout=10) - - -if __name__ == "__main__": - sys.exit(main()) +"""Launch the FastAPI backend and Streamlit UI in one Docker container.""" + +from __future__ import annotations + +import subprocess +import sys + + +def main() -> int: + """Start the API backend in the background and keep Streamlit in the foreground.""" + + api_process = subprocess.Popen( + ["uvicorn", "api.main:app", "--host", "0.0.0.0", "--port", "8001"], + ) + try: + return subprocess.call( + [ + "streamlit", + "run", + "app/streamlit_app.py", + "--server.port", + "8000", + "--server.address", + "0.0.0.0", + "--server.headless", + "true", + ] + ) + finally: + api_process.terminate() + api_process.wait(timeout=10) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/models.py b/models.py index 5a83f60d050accc83ab05bb2bef8743c52237739..6002de2dd30c2ab339e1acf93c068d0160666960 100644 --- a/models.py +++ b/models.py @@ -1,4 +1,4 @@ -"""Typed models for the python_code_review_env environment.""" +"""Typed models for the python_code_review_env environment.""" from __future__ import annotations @@ -23,22 +23,22 @@ class HistoryEntry(BaseModel): reward: float = Field(..., gt=0.0, lt=1.0, description="Reward returned for the step.") -class RewardDetails(BaseModel): - """Transparent reward decomposition for debugging and training.""" - - value: float = Field(..., gt=0.0, lt=1.0, description="Clamped net reward in (0.0, 1.0).") - syntax_reward: float = Field(default=0.0) - test_reward: float = Field(default=0.0) - correctness_bonus: float = Field(default=0.0) - quality_bonus: float = Field(default=0.0) - error_reduction_bonus: float = Field(default=0.0) - completion_bonus: float = Field(default=0.0) - runtime_bonus: float = Field(default=0.0) - progress_delta: float = Field(default=0.0) - invalid_action_penalty: float = Field(default=0.0) - timeout_penalty: float = Field(default=0.0) - regression_penalty: float = Field(default=0.0) - stagnation_penalty: float = Field(default=0.0) +class RewardDetails(BaseModel): + """Transparent reward decomposition for debugging and training.""" + + value: float = Field(..., gt=0.0, lt=1.0, description="Clamped net reward in (0.0, 1.0).") + syntax_reward: float = Field(default=0.0) + test_reward: float = Field(default=0.0) + correctness_bonus: float = Field(default=0.0) + quality_bonus: float = Field(default=0.0) + error_reduction_bonus: float = Field(default=0.0) + completion_bonus: float = Field(default=0.0) + runtime_bonus: float = Field(default=0.0) + progress_delta: float = Field(default=0.0) + invalid_action_penalty: float = Field(default=0.0) + timeout_penalty: float = Field(default=0.0) + regression_penalty: float = Field(default=0.0) + stagnation_penalty: float = Field(default=0.0) reason: str = Field(..., description="Human-readable reward explanation.") prev_score: float = Field(default=0.01, gt=0.0, lt=1.0) curr_score: float = Field(default=0.01, gt=0.0, lt=1.0) @@ -66,17 +66,17 @@ class PythonCodeReviewObservation(Observation): current_code: str = Field(..., description="Latest code under review.") errors: str = Field(default="", description="Syntax or execution errors.") test_results: str = Field(default="", description="Public test and benchmark feedback.") - visible_tests: List[str] = Field(default_factory=list) - history: List[HistoryEntry] = Field(default_factory=list) - attempts_remaining: int = Field(..., ge=0) - last_action_status: str = Field(default="") - last_action_error: Optional[str] = Field(default=None) - score: float = Field(..., gt=0.0, lt=1.0) - reward: float = Field(default=0.1, gt=0.0, lt=1.0) - done: bool = Field(default=False) - reward_details: RewardDetails = Field( - default_factory=lambda: RewardDetails(value=0.1, reason="Environment reset.") - ) + visible_tests: List[str] = Field(default_factory=list) + history: List[HistoryEntry] = Field(default_factory=list) + attempts_remaining: int = Field(..., ge=0) + last_action_status: str = Field(default="") + last_action_error: Optional[str] = Field(default=None) + score: float = Field(..., gt=0.0, lt=1.0) + reward: float = Field(default=0.1, gt=0.0, lt=1.0) + done: bool = Field(default=False) + reward_details: RewardDetails = Field( + default_factory=lambda: RewardDetails(value=0.1, reason="Environment reset.") + ) class PythonCodeReviewState(State): diff --git a/models/__init__.py b/models/__init__.py index e850debc4c529344baf4fdc31f9f9f5f46b953ed..b2d760c568bd457e584c28c004c66be799de6106 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1,66 +1,76 @@ -"""PyTorch-backed model wrappers plus OpenEnv schema exports.""" - -from __future__ import annotations - -import importlib.util -import sys -from pathlib import Path - -from .pytorch_model import PyTorchCodeAnalyzerModel - - -def _load_schema_module(): - schema_path = Path(__file__).resolve().parent.parent / "models.py" - spec = importlib.util.spec_from_file_location("_python_env_schema_models", schema_path) - if spec is None or spec.loader is None: # pragma: no cover - raise ImportError(f"Unable to load schema models from {schema_path}") - if spec.name in sys.modules: - return sys.modules[spec.name] - module = importlib.util.module_from_spec(spec) - sys.modules[spec.name] = module - spec.loader.exec_module(module) - for model_name in ( - "HistoryEntry", - "RewardDetails", - "PythonCodeReviewAction", - "PythonCodeReviewObservation", - "PythonCodeReviewState", - "TaskDescriptor", - "TaskSummary", - "TaskGrade", - "HealthResponse", - ): - getattr(module, model_name).model_rebuild() - return module - - -_schema_models = _load_schema_module() - -HealthResponse = _schema_models.HealthResponse -HistoryEntry = _schema_models.HistoryEntry -PythonAction = _schema_models.PythonAction -PythonCodeReviewAction = _schema_models.PythonCodeReviewAction -PythonCodeReviewObservation = _schema_models.PythonCodeReviewObservation -PythonCodeReviewState = _schema_models.PythonCodeReviewState -PythonObservation = _schema_models.PythonObservation -PythonState = _schema_models.PythonState -RewardDetails = _schema_models.RewardDetails -TaskDescriptor = _schema_models.TaskDescriptor -TaskGrade = _schema_models.TaskGrade -TaskSummary = _schema_models.TaskSummary - -__all__ = [ - "HealthResponse", - "HistoryEntry", - "PyTorchCodeAnalyzerModel", - "PythonAction", - "PythonCodeReviewAction", - "PythonCodeReviewObservation", - "PythonCodeReviewState", - "PythonObservation", - "PythonState", - "RewardDetails", - "TaskDescriptor", - "TaskGrade", - "TaskSummary", -] +"""PyTorch-backed model wrappers plus OpenEnv schema exports.""" + +from __future__ import annotations + +import importlib.util +import sys +from pathlib import Path +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .pytorch_model import PyTorchCodeAnalyzerModel + + +def _load_schema_module(): + schema_path = Path(__file__).resolve().parent.parent / "models.py" + spec = importlib.util.spec_from_file_location("_python_env_schema_models", schema_path) + if spec is None or spec.loader is None: # pragma: no cover + raise ImportError(f"Unable to load schema models from {schema_path}") + if spec.name in sys.modules: + return sys.modules[spec.name] + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + for model_name in ( + "HistoryEntry", + "RewardDetails", + "PythonCodeReviewAction", + "PythonCodeReviewObservation", + "PythonCodeReviewState", + "TaskDescriptor", + "TaskSummary", + "TaskGrade", + "HealthResponse", + ): + getattr(module, model_name).model_rebuild() + return module + + +_schema_models = _load_schema_module() + +HealthResponse = _schema_models.HealthResponse +HistoryEntry = _schema_models.HistoryEntry +PythonAction = _schema_models.PythonAction +PythonCodeReviewAction = _schema_models.PythonCodeReviewAction +PythonCodeReviewObservation = _schema_models.PythonCodeReviewObservation +PythonCodeReviewState = _schema_models.PythonCodeReviewState +PythonObservation = _schema_models.PythonObservation +PythonState = _schema_models.PythonState +RewardDetails = _schema_models.RewardDetails +TaskDescriptor = _schema_models.TaskDescriptor +TaskGrade = _schema_models.TaskGrade +TaskSummary = _schema_models.TaskSummary + + +def __getattr__(name: str): + if name == "PyTorchCodeAnalyzerModel": + from .pytorch_model import PyTorchCodeAnalyzerModel as model_class + + return model_class + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + +__all__ = [ + "HealthResponse", + "HistoryEntry", + "PyTorchCodeAnalyzerModel", + "PythonAction", + "PythonCodeReviewAction", + "PythonCodeReviewObservation", + "PythonCodeReviewState", + "PythonObservation", + "PythonState", + "RewardDetails", + "TaskDescriptor", + "TaskGrade", + "TaskSummary", +] diff --git a/models/pytorch_model.py b/models/pytorch_model.py index b164b048a8d44d68de06b51136f5499f889d3a04..f3ff2e37177beaea1dc10b9b4a276d171bbfe112 100644 --- a/models/pytorch_model.py +++ b/models/pytorch_model.py @@ -1,4 +1,4 @@ -"""PyTorch + transformers model wrapper for code-quality scoring.""" +"""PyTorch + transformers model wrapper for multi-domain code scoring.""" from __future__ import annotations @@ -17,64 +17,34 @@ except Exception: DOMAIN_PROTOTYPES: Dict[str, List[str]] = { "dsa": [ - "Algorithmic Python with nested loops, recursion, dynamic programming, maps, and asymptotic analysis.", - "Competitive programming utility focused on arrays, graphs, search, and runtime complexity.", + "Binary search, hashmap optimization, recursion, dynamic programming, arrays, trees, graphs, stack, queue, complexity.", + "Competitive programming algorithm with loops, memoization, prefix sums, and asymptotic analysis.", ], "data_science": [ - "Pandas dataframe transformation, numpy vectorization, feature engineering, data cleaning, and leakage prevention.", - "Notebook-style data pipeline using joins, aggregations, and columnar operations.", + "Pandas dataframe transformation, numpy vectorization, feature leakage, train test split, iterrows misuse.", + "Data cleaning pipeline using pandas, numpy, aggregation, joins, and vectorized operations.", ], "ml_dl": [ - "PyTorch model inference or training loop with eval mode, no_grad, tensors, optimizer, and loss functions.", - "Machine learning code with torch, sklearn, batches, checkpoints, and metrics.", + "PyTorch model, training loop, optimizer, backward pass, eval mode, no_grad, loss function, dataloader.", + "Machine learning inference and training code with torch, sklearn, tensors, gradients, and model checkpoints.", ], "web": [ - "FastAPI backend endpoint with pydantic validation, dependency injection, request parsing, and API safety.", - "Python web-service route handling, serialization, authentication, and response contracts.", + "FastAPI endpoint, request validation, Pydantic models, async routes, API security, backend service design.", + "REST API backend with routers, dependency injection, input validation, serialization, and error handling.", ], "general": [ - "General Python utility code with readability, typing, small functions, tests, and maintainable abstractions.", + "General Python utility code with readable structure, typing, tests, and maintainable abstractions.", ], } QUALITY_ANCHORS: Dict[str, List[str]] = { "high": [ - "Production-ready Python code with clear naming, docstrings, validation, efficient loops, and low complexity.", - "Clean code with explicit error handling, typing, modular design, and testable functions.", + "Readable typed Python code with validation, efficient algorithms, vectorized operations, safe inference, and clean API boundaries.", + "Production-ready code with small functions, docstrings, low complexity, and clear error handling.", ], "low": [ - "Bug-prone Python with nested loops, missing validation, weak naming, duplicated logic, and hard-to-review structure.", - "Risky code with syntax drift, unclear behavior, mutable side effects, and repeated scans over data.", - ], -} - -MAINTAINABILITY_ANCHORS: Dict[str, List[str]] = { - "high": [ - "Readable functions, small logical units, strong typing, comments only where needed, and simple control flow.", - "Maintainable Python service with clean architecture, cohesive modules, and explicit contracts.", - ], - "low": [ - "Large unstructured function, missing docstrings, weak names, deeply nested branches, and difficult debugging.", - "Hard-to-maintain script with inconsistent style, brittle branching, and hidden side effects.", - ], -} - -ISSUE_ANCHORS: Dict[str, List[str]] = { - "correctness": [ - "Off-by-one bug, missing final append, incorrect boundary handling, failing assertions, wrong return value.", - "Logic regression caused by a missing branch, incorrect state update, or unhandled edge case.", - ], - "performance": [ - "Repeated full-list scans, brute-force nested loops, iterrows misuse, avoidable O(n^2) behavior, slow pipeline.", - "Performance regression from redundant iteration, poor data structures, or missing vectorization.", - ], - "security": [ - "Unsafe input handling, unchecked request payload, eval usage, missing validation, insecure backend pattern.", - "Security risk caused by trusting raw user input or bypassing schema validation.", - ], - "style": [ - "Readability issues from long lines, missing docstrings, inconsistent spacing, tabs, and trailing whitespace.", - "Style drift that makes code review harder and maintenance slower.", + "Brute-force nested loops, missing validation, unsafe input handling, missing eval mode, missing no_grad, and code smells.", + "Hard to maintain code with high complexity, repeated scans, mutable side effects, and unclear structure.", ], } @@ -148,79 +118,31 @@ class PyTorchCodeAnalyzerModel: self._prototype_cache[bucket] = self._embed_texts(texts) return self._prototype_cache[bucket] - @staticmethod - def _unit_similarity(candidate: torch.Tensor, matrix: torch.Tensor) -> float: - similarity = torch.matmul(candidate, matrix.T).max().item() - return round((similarity + 1.0) / 2.0, 4) - - @staticmethod - def _quality_label(score: float) -> str: - if score >= 0.82: - return "excellent" - if score >= 0.66: - return "good" - if score >= 0.45: - return "needs_work" - return "risky" - - def predict( - self, - code: str, - context_window: str, - traceback_text: str, - static_summary: Dict[str, object], - ) -> Dict[str, object]: - """Predict domain probabilities, quality, and issue risks for Python code.""" + def predict(self, code: str, context_window: str, static_summary: Dict[str, object]) -> Dict[str, object]: + """Predict domain probabilities and a model quality score.""" document = ( f"Code:\n{code.strip()[:4000]}\n\n" f"Context:\n{context_window.strip()[:1000]}\n\n" - f"Traceback:\n{traceback_text.strip()[:1000]}\n\n" f"Static hints:\n{static_summary}\n" ) candidate = self._embed_texts([document]) domain_scores: Dict[str, float] = {} for domain, texts in DOMAIN_PROTOTYPES.items(): - domain_scores[domain] = self._unit_similarity(candidate, self._prototype_matrix(f"domain:{domain}", texts)) + matrix = self._prototype_matrix(f"domain:{domain}", texts) + similarity = torch.matmul(candidate, matrix.T).max().item() + domain_scores[domain] = round((similarity + 1.0) / 2.0, 4) high_matrix = self._prototype_matrix("quality:high", QUALITY_ANCHORS["high"]) low_matrix = self._prototype_matrix("quality:low", QUALITY_ANCHORS["low"]) high_similarity = torch.matmul(candidate, high_matrix.T).max().item() low_similarity = torch.matmul(candidate, low_matrix.T).max().item() - ml_quality_score = round(float(torch.sigmoid(torch.tensor((high_similarity - low_similarity) * 4.0)).item()), 4) - - high_maintainability = torch.matmul( - candidate, - self._prototype_matrix("maintainability:high", MAINTAINABILITY_ANCHORS["high"]).T, - ).max().item() - low_maintainability = torch.matmul( - candidate, - self._prototype_matrix("maintainability:low", MAINTAINABILITY_ANCHORS["low"]).T, - ).max().item() - maintainability_score = round( - float(torch.sigmoid(torch.tensor((high_maintainability - low_maintainability) * 4.0)).item()), - 4, - ) - - issue_logits = [] - issue_labels = list(ISSUE_ANCHORS.keys()) - for label in issue_labels: - similarity = torch.matmul(candidate, self._prototype_matrix(f"issue:{label}", ISSUE_ANCHORS[label]).T).max().item() - issue_logits.append(similarity) - probabilities = torch.softmax(torch.tensor(issue_logits) * 3.0, dim=0) - issue_probabilities = { - label: round(float(probabilities[index].item()), 4) - for index, label in enumerate(issue_labels) - } + ml_quality_score = torch.sigmoid(torch.tensor((high_similarity - low_similarity) * 4.0)).item() return { "domain_scores": domain_scores, - "ml_quality_score": ml_quality_score, - "quality_score": ml_quality_score, - "quality_label": self._quality_label(ml_quality_score), - "maintainability_score": maintainability_score, - "issue_probabilities": issue_probabilities, + "ml_quality_score": round(float(ml_quality_score), 4), "backend_name": self.backend_name, "model_id": self.model_id, "notes": list(self.notes), diff --git a/openenv_python_code_review_env.egg-info/PKG-INFO b/openenv_python_code_review_env.egg-info/PKG-INFO index 72e36f3f27460840ae1d0602ab79dce6c9fd0972..f1b58d1ba337e4b13c86bbade15f1de13f3e4cd2 100644 --- a/openenv_python_code_review_env.egg-info/PKG-INFO +++ b/openenv_python_code_review_env.egg-info/PKG-INFO @@ -16,16 +16,6 @@ Provides-Extra: dev Requires-Dist: pytest>=8.0.0; extra == "dev" Requires-Dist: pytest-cov>=4.0.0; extra == "dev" ---- -title: Python Code Review Environment Server -sdk: docker -app_port: 8000 -base_path: /web -pinned: false -tags: - - openenv ---- - # OpenEnv Python Code Review Environment Production-ready hackathon submission for OpenEnv evaluation, deterministic validator runs, and Hugging Face Docker deployment. @@ -34,26 +24,25 @@ Production-ready hackathon submission for OpenEnv evaluation, deterministic vali ```text root -|- inference.py # Root validator entrypoint -|- openenv.yaml # OpenEnv manifest -|- app/ -| |- agents/ # Action policy and fallback strategy -| |- env/ # RL loop runner and stdout contract -| |- models/ # Inference dataclasses/config -| |- services/ # OpenAI client wrapper with retries -| `- utils/ # Formatting, task loading, log suppression -|- server/ -| |- env.py # OpenEnv environment and reward shaping -| |- app.py # FastAPI/OpenEnv app, optional Gradio mount -| `- Dockerfile # Alternate Docker build path -|- Dockerfile # Root deployment Docker image -|- graders/ # Syntax, bug-fix, optimization graders -|- tasks/ # Deterministic benchmark tasks and references -|- services/ # Multi-domain analysis services -|- analyzers/ # Domain-specific analyzers -|- models/ # Lazy-loaded PyTorch scoring model -|- schemas/ # API request/response contracts -`- tests/ # Local validation coverage +├── inference.py # Root validator entrypoint +├── openenv.yaml # OpenEnv manifest +├── app/ +│ ├── agents/ # Action policy and fallback strategy +│ ├── env/ # RL loop runner and stdout contract +│ ├── models/ # Inference dataclasses/config +│ ├── services/ # OpenAI client wrapper with retries +│ └── utils/ # Formatting, task loading, log suppression +├── server/ +│ ├── env.py # OpenEnv environment and reward shaping +│ ├── app.py # FastAPI/OpenEnv app, optional Gradio mount +│ └── Dockerfile # Hugging Face Docker image +├── graders/ # Syntax, bug-fix, optimization graders +├── tasks/ # Deterministic benchmark tasks and references +├── services/ # Multi-domain analysis services +├── analyzers/ # Domain-specific analyzers +├── models/ # Lazy-loaded PyTorch scoring model +├── schemas/ # API request/response contracts +└── tests/ # Local validation coverage ``` Runtime flow: @@ -71,8 +60,8 @@ inference.py - `inference.py` now lives at the repo root and delegates to a strict runner under `app/env`. - OpenAI usage is limited to the official Python client: - `client = OpenAI(base_url=API_BASE_URL, api_key=provider_token)`. -- Defaulted env vars are enforced for `API_BASE_URL` and `MODEL_NAME`; the runtime now selects `HF_TOKEN` for the Hugging Face router and `OPENAI_API_KEY` for direct OpenAI usage. + `client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)`. +- Defaulted env vars are enforced for `API_BASE_URL` and `MODEL_NAME`; `HF_TOKEN` is read without a default and handled explicitly. - Output now matches the required single-line contract exactly and always emits `[END]`, including failure paths. - The RL loop now uses `reset()` plus `step_result()` in a proper `while not done` loop. - Step errors now surface through `last_action_error` and are printed in `[STEP]`. @@ -107,7 +96,6 @@ Optional demo UI: ```bash set ENABLE_GRADIO_DEMO=true -set ENABLE_WEB_INTERFACE=true python -m uvicorn server.app:app --host 0.0.0.0 --port 8000 ``` @@ -120,9 +108,7 @@ Required environment variables: - `MODEL_NAME` Default: `Qwen/Qwen2.5-3B-Instruct` - `HF_TOKEN` - Required for `https://router.huggingface.co/v1` -- `OPENAI_API_KEY` - Required for `https://api.openai.com/v1` + Mandatory, no default is injected Example: @@ -133,13 +119,6 @@ set HF_TOKEN=hf_xxx python inference.py ``` -```bash -set API_BASE_URL=https://api.openai.com/v1 -set MODEL_NAME=gpt-4.1-mini -set OPENAI_API_KEY=sk-xxx -python inference.py -``` - Expected stdout shape: ```text @@ -156,7 +135,7 @@ Expected stdout shape: Build from the project root: ```bash -docker build -t openenv-python-code-review-env . +docker build -f server/Dockerfile . ``` Run locally: @@ -182,12 +161,11 @@ Recommended deployment steps: 1. Create a Docker Space. 2. Push this repository as-is. -3. Let Spaces build from the root `Dockerfile`. +3. Let Spaces build with `server/Dockerfile`. 4. Set Space secrets: `HF_TOKEN` 5. Set Space variables as needed: `API_BASE_URL`, `MODEL_NAME`, `ENABLE_GRADIO_DEMO=false` - `ENABLE_WEB_INTERFACE=false` is also supported for OpenEnv-managed deploys. 6. Confirm the app listens on port `8000`. 7. Smoke-test: `/health` diff --git a/openenv_python_code_review_env.egg-info/SOURCES.txt b/openenv_python_code_review_env.egg-info/SOURCES.txt index 941269c4a795f126e60518385ccca67f6d39299b..69092eb01996368c01e47ffc78f3e2751286ed8e 100644 --- a/openenv_python_code_review_env.egg-info/SOURCES.txt +++ b/openenv_python_code_review_env.egg-info/SOURCES.txt @@ -5,8 +5,7 @@ pyproject.toml ./compat.py ./inference.py ./launch.py -./models.py -./sitecustomize.py +./openenv_models.py ./triage.py ./triage_catalog.py ./triage_models.py diff --git a/pyproject.toml b/pyproject.toml index 7702215f4c915195d2e305948e0d45ea0113b704..ce80811fee9a7e89de02c22a43050f8337018034 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,9 +8,11 @@ version = "1.0.0" description = "TorchReview Copilot: AI-powered Python code triage with PyTorch and OpenEnv validation." readme = "README.md" requires-python = ">=3.10" + dependencies = [ "fastapi>=0.111.0", "gradio>=5.26.0", + "hf-xet>=1.4.3", "openai>=1.76.0", "openenv-core[core]>=0.2.2", "streamlit>=1.44.0", @@ -33,22 +35,7 @@ pythonpath = ["."] [tool.setuptools] include-package-data = true -packages = [ - "python_env", - "python_env.server", - "python_env.tasks", - "python_env.graders", - "python_env.api", - "python_env.app", - "python_env.app.agents", - "python_env.app.env", - "python_env.app.models", - "python_env.app.services", - "python_env.app.utils", - "python_env.analyzers", - "python_env.models", - "python_env.schemas", - "python_env.services", - "python_env.utils", -] -package-dir = { "python_env" = ".", "python_env.server" = "server", "python_env.tasks" = "tasks", "python_env.graders" = "graders", "python_env.api" = "api", "python_env.app" = "app", "python_env.app.agents" = "app/agents", "python_env.app.env" = "app/env", "python_env.app.models" = "app/models", "python_env.app.services" = "app/services", "python_env.app.utils" = "app/utils", "python_env.analyzers" = "analyzers", "python_env.models" = "models", "python_env.schemas" = "schemas", "python_env.services" = "services", "python_env.utils" = "utils" } + +[tool.setuptools.packages.find] +where = ["."] +include = ["*"] diff --git a/schemas/__init__.py b/schemas/__init__.py index d2008615adec6aada77a9ada69cf8d3d8d5fb4c1..e635325f1c40ea4e2797578f1fc3224f9548d1df 100644 --- a/schemas/__init__.py +++ b/schemas/__init__.py @@ -1,13 +1,13 @@ -"""Public schemas for the multi-domain analysis platform.""" - -from .request import AnalyzeCodeRequest -from .response import AnalyzeCodeResponse, AnalysisIssue, DomainAnalysis, ScoreBreakdown, StaticAnalysisSummary - -__all__ = [ - "AnalyzeCodeRequest", - "AnalyzeCodeResponse", - "AnalysisIssue", - "DomainAnalysis", - "ScoreBreakdown", - "StaticAnalysisSummary", -] +"""Public schemas for the multi-domain analysis platform.""" + +from .request import AnalyzeCodeRequest +from .response import AnalyzeCodeResponse, AnalysisIssue, DomainAnalysis, ScoreBreakdown, StaticAnalysisSummary + +__all__ = [ + "AnalyzeCodeRequest", + "AnalyzeCodeResponse", + "AnalysisIssue", + "DomainAnalysis", + "ScoreBreakdown", + "StaticAnalysisSummary", +] diff --git a/schemas/request.py b/schemas/request.py index 63f5e75e069a606e916fb5c84c8a1a4137ffa191..c53252a73269901cb3bf98e8a10b2b5d2140ca66 100644 --- a/schemas/request.py +++ b/schemas/request.py @@ -1,51 +1,19 @@ -"""Request schemas for the AI-powered code review workflow.""" +"""Request schemas for code analysis endpoints and UI.""" from __future__ import annotations from typing import Literal -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import BaseModel, Field -DomainHint = Literal["auto", "general", "dsa", "data_science", "ml_dl", "web"] +DomainHint = Literal["auto", "dsa", "data_science", "ml_dl", "web"] class AnalyzeCodeRequest(BaseModel): - """Validated input payload for Python code review requests.""" - - model_config = ConfigDict(str_strip_whitespace=True) - - code: str = Field(..., min_length=1, description="Python source code to analyze.") - context_window: str = Field( - default="", - max_length=4000, - description="Optional repository, pull request, or runtime context.", - ) - traceback_text: str = Field( - default="", - max_length=4000, - description="Optional traceback or failing test output.", - ) - domain_hint: DomainHint = Field( - default="auto", - description="Optional analysis lens for domain-aware suggestions.", - ) - filename: str = Field(default="snippet.py", max_length=255, description="Virtual filename for display.") - enable_suggestions: bool = Field( - default=True, - description="Whether the service should return a prioritized improvement plan.", - ) - - @field_validator("code") - @classmethod - def _reject_empty_code(cls, value: str) -> str: - stripped = value.strip() - if not stripped: - raise ValueError("code must not be empty") - return stripped - - @field_validator("filename") - @classmethod - def _normalize_filename(cls, value: str) -> str: - candidate = value.strip() or "snippet.py" - return candidate[:255] + """Validated input payload for multi-domain code analysis.""" + + code: str = Field(..., min_length=1, description="Source code to analyze.") + context_window: str = Field(default="", max_length=2000, description="Optional repository or task context.") + traceback_text: str = Field(default="", max_length=2000, description="Optional runtime or test failure output.") + domain_hint: DomainHint = Field(default="auto", description="Optional domain override when auto detection is not desired.") diff --git a/schemas/response.py b/schemas/response.py index d673a29e8ba63ab9d9d88c173dd734d392c388b1..568543fa94d66642b2cf7c16c7f8e848709313df 100644 --- a/schemas/response.py +++ b/schemas/response.py @@ -1,4 +1,4 @@ -"""Response schemas for the AI-powered code review platform.""" +"""Response schemas for the multi-domain analysis platform.""" from __future__ import annotations @@ -7,103 +7,67 @@ from typing import Dict, List, Literal from pydantic import BaseModel, Field +DomainType = Literal["dsa", "data_science", "ml_dl", "web", "general"] Severity = Literal["low", "medium", "high"] -IssueCategory = Literal["correctness", "maintainability", "performance", "security", "style"] -QualityLabel = Literal["excellent", "good", "needs_work", "risky"] -DetectedDomain = Literal["general", "dsa", "data_science", "ml_dl", "web"] class AnalysisIssue(BaseModel): """One detected issue or risk in the code snippet.""" title: str - category: IssueCategory = "maintainability" severity: Severity description: str line_hint: int | None = None class StaticAnalysisSummary(BaseModel): - """Python-specific static-analysis signals.""" + """Language-agnostic static-analysis signals.""" syntax_valid: bool syntax_error: str = "" cyclomatic_complexity: int = Field(..., ge=1) line_count: int = Field(..., ge=0) - max_nesting_depth: int = Field(..., ge=0) max_loop_depth: int = Field(..., ge=0) time_complexity: str = "Unknown" space_complexity: str = "Unknown" - lint_score: float = Field(..., ge=0.0, le=1.0) - docstring_coverage: float = Field(..., ge=0.0, le=1.0) detected_imports: List[str] = Field(default_factory=list) code_smells: List[str] = Field(default_factory=list) - issues: List[AnalysisIssue] = Field(default_factory=list) class DomainAnalysis(BaseModel): - """Domain-aware review signals used for context-specific suggestions.""" + """Domain-specific analysis payload returned by an analyzer.""" - domain: DetectedDomain + domain: DomainType domain_score: float = Field(..., ge=0.0, le=1.0) issues: List[AnalysisIssue] = Field(default_factory=list) suggestions: List[str] = Field(default_factory=list) highlights: Dict[str, float | str] = Field(default_factory=dict) -class ModelPrediction(BaseModel): - """PyTorch model output derived from pretrained code embeddings.""" - - quality_label: QualityLabel - quality_score: float = Field(..., ge=0.0, le=1.0) - maintainability_score: float = Field(..., ge=0.0, le=1.0) - issue_probabilities: Dict[str, float] = Field(default_factory=dict) - notes: List[str] = Field(default_factory=list) - - class ScoreBreakdown(BaseModel): - """Reward inputs and the final RL-ready scalar reward.""" + """Reward inputs and final normalized score.""" ml_score: float = Field(..., ge=0.0, le=1.0) domain_score: float = Field(..., ge=0.0, le=1.0) lint_score: float = Field(..., ge=0.0, le=1.0) complexity_penalty: float = Field(..., ge=0.0, le=1.0) - maintainability_score: float = Field(..., ge=0.0, le=1.0) - security_score: float = Field(..., ge=0.0, le=1.0) - readability_score: float = Field(..., ge=0.0, le=1.0) quality_signal: float = Field(..., ge=0.0, le=1.0) error_reduction_signal: float = Field(..., ge=0.0, le=1.0) completion_signal: float = Field(..., ge=0.0, le=1.0) reward: float = Field(..., ge=0.0, le=1.0) -class SuggestionItem(BaseModel): - """One prioritized improvement suggestion.""" - - priority: Literal["P0", "P1", "P2"] - title: str - rationale: str - action: str - category: IssueCategory - - class AnalyzeCodeResponse(BaseModel): """Top-level structured output for API and UI consumers.""" - language: Literal["python"] = "python" - detected_domain: DetectedDomain - domain_confidences: Dict[str, float] = Field(default_factory=dict) + detected_domain: DomainType + domain_confidences: Dict[str, float] score_breakdown: ScoreBreakdown static_analysis: StaticAnalysisSummary - model_prediction: ModelPrediction domain_analysis: DomainAnalysis - suggestions: List[SuggestionItem] = Field(default_factory=list) improvement_plan: List[str] = Field(default_factory=list) - auto_fix_preview: List[str] = Field(default_factory=list) - score_visualization: Dict[str, float] = Field(default_factory=dict) model_backend: str model_id: str summary: str context_window: str = "" - filename: str = "snippet.py" analysis_time_ms: float = Field(..., ge=0.0) diff --git a/server/app.py b/server/app.py index 1c4287e8a2d5124c8a17bc9163e029768d7af990..ca80dee39f8dd32319aaf01c55449049b96b0c12 100644 --- a/server/app.py +++ b/server/app.py @@ -53,10 +53,16 @@ def build_application(): served_app = api_app if gr is not None and _gradio_enabled(): try: - from .demo import build_demo + from .demo import CSS, build_demo except ImportError: - from server.demo import build_demo - served_app = gr.mount_gradio_app(api_app, build_demo(), path="/") + from server.demo import CSS, build_demo + served_app = gr.mount_gradio_app( + api_app, + build_demo(), + path="/", + theme=gr.themes.Soft(primary_hue="orange", secondary_hue="amber"), + css=CSS, + ) wrapper_app = FastAPI(title="python_code_review_env", version="1.0.0") @@ -74,7 +80,7 @@ app = build_application() def main(host: str = "0.0.0.0", port: int = 8000) -> None: import uvicorn - uvicorn.run(app, host=host, port=port) + uvicorn.run(app, host=host, port=port, access_log=False) if __name__ == "__main__": diff --git a/server/demo.py b/server/demo.py index 674e040abe7c3d280b971a6ce3224da59b40cc41..3d3ac716faafd7ba663562b1f7bbcdbd077a2cda 100644 --- a/server/demo.py +++ b/server/demo.py @@ -347,7 +347,7 @@ def build_demo() -> gr.Blocks: examples = get_default_engine().example_map() first_example = next(iter(examples.values())) - with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange", secondary_hue="amber"), css=CSS, title="TorchReview Copilot") as demo: + with gr.Blocks(title="TorchReview Copilot") as demo: gr.HTML( """
diff --git a/server/env.py b/server/env.py index d8bfd2ded82acd755bd3c3472219f41ec73776f3..018cbe06bc136765f6480e1613b39c163647248a 100644 --- a/server/env.py +++ b/server/env.py @@ -8,27 +8,27 @@ from uuid import uuid4 from openenv.core.env_server.interfaces import Environment from openenv.core.env_server.types import EnvironmentMetadata -try: - from ..graders import grade_task - from ..graders.shared import component_score, final_score_pipeline, safe_ratio, safe_score - from ..models import ( - HistoryEntry, - PythonCodeReviewAction, - PythonCodeReviewObservation, - PythonCodeReviewState, - RewardDetails, +try: + from ..graders import grade_task + from ..graders.shared import component_score, final_score_pipeline, safe_ratio, safe_score + from ..models import ( + HistoryEntry, + PythonCodeReviewAction, + PythonCodeReviewObservation, + PythonCodeReviewState, + RewardDetails, TaskGrade, ) from ..tasks import ReviewTask, list_tasks, select_task -except ImportError: - from graders import grade_task - from graders.shared import component_score, final_score_pipeline, safe_ratio, safe_score - from models import ( - HistoryEntry, - PythonCodeReviewAction, - PythonCodeReviewObservation, - PythonCodeReviewState, - RewardDetails, +except ImportError: + from graders import grade_task + from graders.shared import component_score, final_score_pipeline, safe_ratio, safe_score + from models import ( + HistoryEntry, + PythonCodeReviewAction, + PythonCodeReviewObservation, + PythonCodeReviewState, + RewardDetails, TaskGrade, ) from tasks import ReviewTask, list_tasks, select_task @@ -43,10 +43,10 @@ def _empty_grade() -> TaskGrade: quality_score=component_score(0.01), runtime_score=component_score(0.01), ) - - -def _reward_value(value: float) -> float: - return final_score_pipeline(value) + + +def _reward_value(value: float) -> float: + return final_score_pipeline(value) class PythonCodeReviewEnvironment( @@ -56,17 +56,17 @@ class PythonCodeReviewEnvironment( SUPPORTS_CONCURRENT_SESSIONS: bool = True - def __init__(self, verbose: bool = False, **_: Any) -> None: - super().__init__() - self.verbose = verbose - self._task: ReviewTask = list_tasks()[0] - self._current_code: str = self._task.starter_code - self._history: list[HistoryEntry] = [] - self._last_reward = RewardDetails(value=0.1, reason="Environment initialized.") - self._last_action_error: str | None = None - self._current_grade = _empty_grade() - self._state = PythonCodeReviewState(episode_id=str(uuid4()), step_count=0) - self.reset() + def __init__(self, verbose: bool = False, **_: Any) -> None: + super().__init__() + self.verbose = verbose + self._task: ReviewTask = list_tasks()[0] + self._current_code: str = self._task.starter_code + self._history: list[HistoryEntry] = [] + self._last_reward = RewardDetails(value=0.1, reason="Environment initialized.") + self._last_action_error: str | None = None + self._current_grade = _empty_grade() + self._state = PythonCodeReviewState(episode_id=str(uuid4()), step_count=0) + self.reset() def reset( self, @@ -74,17 +74,17 @@ class PythonCodeReviewEnvironment( episode_id: Optional[str] = None, **kwargs: Any, ) -> PythonCodeReviewObservation: - task_id = kwargs.get("task_id") - self._task = select_task(seed=seed, task_id=task_id) - self._current_code = self._task.starter_code - self._history = [] - self._last_action_error = None - self._last_reward = RewardDetails(value=0.1, reason="Environment reset.") - self._current_grade, self._last_action_error = self._safe_grade_task( - self._task, - self._current_code, - include_hidden=False, - ) + task_id = kwargs.get("task_id") + self._task = select_task(seed=seed, task_id=task_id) + self._current_code = self._task.starter_code + self._history = [] + self._last_action_error = None + self._last_reward = RewardDetails(value=0.1, reason="Environment reset.") + self._current_grade, self._last_action_error = self._safe_grade_task( + self._task, + self._current_code, + include_hidden=False, + ) self._state = PythonCodeReviewState( episode_id=episode_id or str(uuid4()), @@ -143,22 +143,22 @@ class PythonCodeReviewEnvironment( ) return observation, reward.value, observation.done, {"task_id": observation.task_id, "score": observation.score} - previous_grade = self._current_grade - status = "" - invalid_action = False - code_changed = False - use_hidden_grading = False - action_error: str | None = None - - if action.action_type == "edit_code": - if not action.code or not action.code.strip(): - invalid_action = True - status = "edit_code requires a non-empty code payload." - action_error = status - else: - code_changed = action.code != self._current_code - self._current_code = action.code - status = "Updated working copy from agent patch." + previous_grade = self._current_grade + status = "" + invalid_action = False + code_changed = False + use_hidden_grading = False + action_error: str | None = None + + if action.action_type == "edit_code": + if not action.code or not action.code.strip(): + invalid_action = True + status = "edit_code requires a non-empty code payload." + action_error = status + else: + code_changed = action.code != self._current_code + self._current_code = action.code + status = "Updated working copy from agent patch." elif action.action_type == "submit_solution": if action.code is not None and action.code.strip(): code_changed = action.code != self._current_code @@ -169,30 +169,30 @@ class PythonCodeReviewEnvironment( status = "Executed public validation suite." elif action.action_type == "analyze_code": status = "Generated static review summary." - else: # pragma: no cover - invalid_action = True - status = f"Unsupported action_type: {action.action_type}" - action_error = status + else: # pragma: no cover + invalid_action = True + status = f"Unsupported action_type: {action.action_type}" + action_error = status self._state.step_count += 1 - if invalid_action: - current_grade = previous_grade - else: - current_grade, grade_error = self._safe_grade_task( - self._task, - self._current_code, - include_hidden=use_hidden_grading, - timeout_s=timeout_s or 3.0, - ) - if grade_error: - action_error = grade_error - status = f"{status} Grading fallback used." - if action.action_type == "analyze_code": - status = self._analysis_status(current_grade) - elif action.action_type == "run_tests": - status = self._run_tests_status(current_grade, use_hidden_grading) - elif action.action_type == "submit_solution": + if invalid_action: + current_grade = previous_grade + else: + current_grade, grade_error = self._safe_grade_task( + self._task, + self._current_code, + include_hidden=use_hidden_grading, + timeout_s=timeout_s or 3.0, + ) + if grade_error: + action_error = grade_error + status = f"{status} Grading fallback used." + if action.action_type == "analyze_code": + status = self._analysis_status(current_grade) + elif action.action_type == "run_tests": + status = self._run_tests_status(current_grade, use_hidden_grading) + elif action.action_type == "submit_solution": status = self._submission_status(current_grade) done = use_hidden_grading or self._state.step_count >= self._task.max_steps @@ -217,11 +217,11 @@ class PythonCodeReviewEnvironment( reward=reward_details.value, ) ) - - self._current_grade = current_grade - self._last_reward = reward_details - self._last_action_error = action_error - attempts_remaining = max(self._task.max_steps - self._state.step_count, 0) + + self._current_grade = current_grade + self._last_reward = reward_details + self._last_action_error = action_error + attempts_remaining = max(self._task.max_steps - self._state.step_count, 0) self._state.task_id = self._task.task_id self._state.difficulty = self._task.difficulty @@ -234,19 +234,19 @@ class PythonCodeReviewEnvironment( self._state.score = current_grade.score self._state.done = done - observation = self._build_observation( - grade=current_grade, - status=status, - reward_details=reward_details, - ) - return observation, reward_details.value, observation.done, { - "task_id": observation.task_id, - "score": observation.score, - "done": observation.done, - "attempts_remaining": observation.attempts_remaining, - "last_action_status": observation.last_action_status, - "last_action_error": observation.last_action_error, - } + observation = self._build_observation( + grade=current_grade, + status=status, + reward_details=reward_details, + ) + return observation, reward_details.value, observation.done, { + "task_id": observation.task_id, + "score": observation.score, + "done": observation.done, + "attempts_remaining": observation.attempts_remaining, + "last_action_status": observation.last_action_status, + "last_action_error": observation.last_action_error, + } @property def state(self) -> PythonCodeReviewState: @@ -268,102 +268,102 @@ class PythonCodeReviewEnvironment( current_code=self._current_code, errors=self._format_errors(grade), test_results=self._format_test_results(grade), - visible_tests=list(self._task.visible_tests), - history=list(self._history), - attempts_remaining=self._state.attempts_remaining, - last_action_status=status, - last_action_error=self._last_action_error, - score=grade.score, - reward=reward_details.value, - done=self._state.done, - reward_details=reward_details, - metadata={ - "benchmark": "python_code_review_env", - "goal": self._task.goal, - "repo_summary": self._task.repo_summary, - "changed_files": self._task.changed_files, - "available_files": self._task.available_files, - "grade_details": grade.details, + visible_tests=list(self._task.visible_tests), + history=list(self._history), + attempts_remaining=self._state.attempts_remaining, + last_action_status=status, + last_action_error=self._last_action_error, + score=grade.score, + reward=reward_details.value, + done=self._state.done, + reward_details=reward_details, + metadata={ + "benchmark": "python_code_review_env", + "goal": self._task.goal, + "repo_summary": self._task.repo_summary, + "changed_files": self._task.changed_files, + "available_files": self._task.available_files, + "grade_details": grade.details, }, ) - def _compute_reward( - self, - *, - previous_grade: TaskGrade, + def _compute_reward( + self, + *, + previous_grade: TaskGrade, current_grade: TaskGrade, action: PythonCodeReviewAction, invalid_action: bool, timed_out: bool, code_changed: bool, final_submission: bool, - ) -> RewardDetails: - prev_score = previous_grade.score - curr_score = current_grade.score - prev_syntax = previous_grade.syntax_score - curr_syntax = current_grade.syntax_score - prev_quality = previous_grade.quality_score - curr_quality = current_grade.quality_score - prev_rate = safe_ratio(previous_grade.tests_passed, previous_grade.tests_total) - curr_rate = safe_ratio(current_grade.tests_passed, current_grade.tests_total) - prev_runtime = previous_grade.runtime_score - curr_runtime = current_grade.runtime_score - prev_compile_health = 0.1 if str(previous_grade.details.get("compile_error", "")).strip() else 0.95 - curr_compile_health = 0.1 if str(current_grade.details.get("compile_error", "")).strip() else 0.95 - - syntax_reward = max(curr_syntax - prev_syntax, 0.0) * 0.18 - test_reward = max(curr_rate - prev_rate, 0.0) * 0.22 - progress_delta = max(curr_score - prev_score, 0.0) * 0.24 - quality_bonus = max(curr_quality - prev_quality, 0.0) * 0.12 - runtime_bonus = max(curr_runtime - prev_runtime, 0.0) * 0.10 - error_reduction_bonus = max(curr_compile_health - prev_compile_health, 0.0) * 0.14 - completion_bonus = (0.04 + 0.10 * curr_rate) * float(final_submission) - correctness_bonus = max(curr_score - 0.5, 0.0) * 0.12 * float(final_submission) - - invalid_action_penalty = (0.04 + (0.08 * (1.0 - prev_score))) if invalid_action else 0.0 - timeout_penalty = (0.05 + (0.06 * max(curr_runtime, prev_runtime))) if timed_out else 0.0 - regression_penalty = max(prev_score - curr_score, 0.0) * 0.24 - stagnation_penalty = (0.02 + (0.04 * prev_score)) if action.action_type == "edit_code" and not code_changed else 0.0 - - raw_value = ( - 2.0 * (curr_score - 0.5) - + 1.2 * (curr_rate - prev_rate) - + 0.8 * (curr_quality - prev_quality) - + 0.7 * (curr_runtime - prev_runtime) - + 0.9 * (curr_syntax - prev_syntax) - + 0.6 * (curr_compile_health - prev_compile_health) - + syntax_reward - + test_reward - + progress_delta - + quality_bonus - + runtime_bonus - + error_reduction_bonus - + completion_bonus - + correctness_bonus - - invalid_action_penalty - - timeout_penalty - - regression_penalty - - stagnation_penalty - ) - value = _reward_value(raw_value) - - reason_parts = [] - if syntax_reward: - reason_parts.append("syntax fixed") + ) -> RewardDetails: + prev_score = previous_grade.score + curr_score = current_grade.score + prev_syntax = previous_grade.syntax_score + curr_syntax = current_grade.syntax_score + prev_quality = previous_grade.quality_score + curr_quality = current_grade.quality_score + prev_rate = safe_ratio(previous_grade.tests_passed, previous_grade.tests_total) + curr_rate = safe_ratio(current_grade.tests_passed, current_grade.tests_total) + prev_runtime = previous_grade.runtime_score + curr_runtime = current_grade.runtime_score + prev_compile_health = 0.1 if str(previous_grade.details.get("compile_error", "")).strip() else 0.95 + curr_compile_health = 0.1 if str(current_grade.details.get("compile_error", "")).strip() else 0.95 + + syntax_reward = max(curr_syntax - prev_syntax, 0.0) * 0.18 + test_reward = max(curr_rate - prev_rate, 0.0) * 0.22 + progress_delta = max(curr_score - prev_score, 0.0) * 0.24 + quality_bonus = max(curr_quality - prev_quality, 0.0) * 0.12 + runtime_bonus = max(curr_runtime - prev_runtime, 0.0) * 0.10 + error_reduction_bonus = max(curr_compile_health - prev_compile_health, 0.0) * 0.14 + completion_bonus = (0.04 + 0.10 * curr_rate) * float(final_submission) + correctness_bonus = max(curr_score - 0.5, 0.0) * 0.12 * float(final_submission) + + invalid_action_penalty = (0.04 + (0.08 * (1.0 - prev_score))) if invalid_action else 0.0 + timeout_penalty = (0.05 + (0.06 * max(curr_runtime, prev_runtime))) if timed_out else 0.0 + regression_penalty = max(prev_score - curr_score, 0.0) * 0.24 + stagnation_penalty = (0.02 + (0.04 * prev_score)) if action.action_type == "edit_code" and not code_changed else 0.0 + + raw_value = ( + 2.0 * (curr_score - 0.5) + + 1.2 * (curr_rate - prev_rate) + + 0.8 * (curr_quality - prev_quality) + + 0.7 * (curr_runtime - prev_runtime) + + 0.9 * (curr_syntax - prev_syntax) + + 0.6 * (curr_compile_health - prev_compile_health) + + syntax_reward + + test_reward + + progress_delta + + quality_bonus + + runtime_bonus + + error_reduction_bonus + + completion_bonus + + correctness_bonus + - invalid_action_penalty + - timeout_penalty + - regression_penalty + - stagnation_penalty + ) + value = _reward_value(raw_value) + + reason_parts = [] + if syntax_reward: + reason_parts.append("syntax fixed") if test_reward: reason_parts.append("public test progress") if progress_delta: reason_parts.append("overall score improved") - if quality_bonus: - reason_parts.append("code quality improved") - if error_reduction_bonus: - reason_parts.append("errors removed") - if completion_bonus: - reason_parts.append("task completed") - if runtime_bonus: - reason_parts.append("runtime improved") - if correctness_bonus: - reason_parts.append("full correctness bonus") + if quality_bonus: + reason_parts.append("code quality improved") + if error_reduction_bonus: + reason_parts.append("errors removed") + if completion_bonus: + reason_parts.append("task completed") + if runtime_bonus: + reason_parts.append("runtime improved") + if correctness_bonus: + reason_parts.append("full correctness bonus") if invalid_action_penalty: reason_parts.append("invalid action penalty") if timeout_penalty: @@ -372,53 +372,53 @@ class PythonCodeReviewEnvironment( reason_parts.append("regression penalty") if stagnation_penalty: reason_parts.append("unchanged patch penalty") - if not reason_parts: - reason_parts.append("no meaningful state change") - - return RewardDetails( - value=safe_score(value), - syntax_reward=round(syntax_reward, 6), - test_reward=round(test_reward, 6), - correctness_bonus=round(correctness_bonus, 6), - quality_bonus=round(quality_bonus, 6), - error_reduction_bonus=round(error_reduction_bonus, 6), - completion_bonus=round(completion_bonus, 6), - runtime_bonus=round(runtime_bonus, 6), - progress_delta=round(progress_delta, 6), - invalid_action_penalty=round(invalid_action_penalty, 6), - timeout_penalty=round(timeout_penalty, 6), - regression_penalty=round(regression_penalty, 6), - stagnation_penalty=round(stagnation_penalty, 6), - reason=", ".join(reason_parts), - prev_score=safe_score(prev_score), - curr_score=safe_score(curr_score), - code_changed=code_changed, - ) - - def _format_errors(self, grade: TaskGrade) -> str: - compile_error = str(grade.details.get("compile_error", "")).strip() - if compile_error: - return compile_error - return "Code parses successfully." - - def _safe_grade_task( - self, - task: ReviewTask, - code: str, - *, - include_hidden: bool, - timeout_s: float = 3.0, - ) -> tuple[TaskGrade, str | None]: - try: - return ( - grade_task(task, code, include_hidden=include_hidden, timeout_s=timeout_s), - None, - ) - except Exception as exc: # pragma: no cover - return _empty_grade(), f"{type(exc).__name__}: {exc}" - - def _format_test_results(self, grade: TaskGrade) -> str: - parts = [grade.details.get("test_summary", "No test feedback available.")] + if not reason_parts: + reason_parts.append("no meaningful state change") + + return RewardDetails( + value=safe_score(value), + syntax_reward=round(syntax_reward, 6), + test_reward=round(test_reward, 6), + correctness_bonus=round(correctness_bonus, 6), + quality_bonus=round(quality_bonus, 6), + error_reduction_bonus=round(error_reduction_bonus, 6), + completion_bonus=round(completion_bonus, 6), + runtime_bonus=round(runtime_bonus, 6), + progress_delta=round(progress_delta, 6), + invalid_action_penalty=round(invalid_action_penalty, 6), + timeout_penalty=round(timeout_penalty, 6), + regression_penalty=round(regression_penalty, 6), + stagnation_penalty=round(stagnation_penalty, 6), + reason=", ".join(reason_parts), + prev_score=safe_score(prev_score), + curr_score=safe_score(curr_score), + code_changed=code_changed, + ) + + def _format_errors(self, grade: TaskGrade) -> str: + compile_error = str(grade.details.get("compile_error", "")).strip() + if compile_error: + return compile_error + return "Code parses successfully." + + def _safe_grade_task( + self, + task: ReviewTask, + code: str, + *, + include_hidden: bool, + timeout_s: float = 3.0, + ) -> tuple[TaskGrade, str | None]: + try: + return ( + grade_task(task, code, include_hidden=include_hidden, timeout_s=timeout_s), + None, + ) + except Exception as exc: # pragma: no cover + return _empty_grade(), f"{type(exc).__name__}: {exc}" + + def _format_test_results(self, grade: TaskGrade) -> str: + parts = [grade.details.get("test_summary", "No test feedback available.")] benchmark = grade.details.get("benchmark") if isinstance(benchmark, dict): parts.append( diff --git a/server/requirements.runtime.txt b/server/requirements.runtime.txt new file mode 100644 index 0000000000000000000000000000000000000000..72770419dcc6e8e81f68f1cfff99b0abc07d39e9 --- /dev/null +++ b/server/requirements.runtime.txt @@ -0,0 +1,4 @@ +openenv-core>=0.2.2 +fastapi>=0.111.0 +openai>=1.76.0 +uvicorn>=0.30.0 diff --git a/server/requirements.txt b/server/requirements.txt index fd8007025698719912d1904e87ef48ad34c543fa..f18e480e278b0d86cabfc279d871a6fcdf715203 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -1,6 +1,8 @@ openenv-core[core]>=0.2.2 fastapi>=0.111.0 +gradio>=5.26.0 uvicorn>=0.30.0 openai>=1.76.0 +streamlit>=1.44.0 torch>=2.2.0 transformers>=4.45.0 diff --git a/services/__init__.py b/services/__init__.py index 411796a65cc40502bd32c5069c45edd05c8c0d95..f550466fcabc6cf41f476d66419384c4a1faaa22 100644 --- a/services/__init__.py +++ b/services/__init__.py @@ -1,7 +1,7 @@ -"""Service layer for orchestrating analysis, suggestions, and rewards.""" - -from .analysis_service import AnalysisService -from .reward_service import RewardService -from .suggestion_service import SuggestionService - -__all__ = ["AnalysisService", "RewardService", "SuggestionService"] +"""Service layer for orchestrating analysis, suggestions, and rewards.""" + +from .analysis_service import AnalysisService +from .reward_service import RewardService +from .suggestion_service import SuggestionService + +__all__ = ["AnalysisService", "RewardService", "SuggestionService"] diff --git a/services/analysis_service.py b/services/analysis_service.py index bf64c2fe68c587d7930280fa3b7d1995fa572838..f7ce4fd3289d90a528128ac09aae9683242856bb 100644 --- a/services/analysis_service.py +++ b/services/analysis_service.py @@ -1,86 +1,33 @@ -"""Orchestration layer for AI-powered Python code review.""" +"""Orchestration layer for multi-domain code analysis.""" from __future__ import annotations import time -from typing import Any, Callable +from typing import Any, Callable, Dict from analyzers import analyze_data_science_code, analyze_dsa_code, analyze_ml_code, analyze_web_code from models import PyTorchCodeAnalyzerModel from schemas.request import AnalyzeCodeRequest -from schemas.response import ( - AnalysisIssue, - AnalyzeCodeResponse, - DomainAnalysis, - ModelPrediction, - StaticAnalysisSummary, -) +from schemas.response import AnalyzeCodeResponse, DomainAnalysis, StaticAnalysisSummary from services.reward_service import RewardService from services.suggestion_service import SuggestionService from utils import estimate_complexity, parse_code_structure -def _clamp_unit(value: float) -> float: - return max(0.0, min(1.0, float(value))) - - -def _lint_score(parsed: dict[str, Any]) -> float: +def _lint_score(parsed: Dict[str, Any]) -> float: """Convert structural smells into a normalized lint-style score.""" score = 1.0 if not parsed.get("syntax_valid", True): score -= 0.45 - score -= min(int(parsed.get("long_lines", 0) or 0), 5) * 0.03 + score -= min(parsed.get("long_lines", 0), 5) * 0.03 if parsed.get("tabs_used"): score -= 0.1 if parsed.get("trailing_whitespace_lines"): score -= 0.05 if parsed.get("docstring_ratio", 0.0) == 0.0 and parsed.get("function_names"): score -= 0.08 - return round(_clamp_unit(score), 4) - - -def _static_issues(parsed: dict[str, Any], complexity: dict[str, Any]) -> list[AnalysisIssue]: - """Turn parser and complexity heuristics into review issues.""" - - issues: list[AnalysisIssue] = [] - if not parsed.get("syntax_valid", True): - issues.append( - AnalysisIssue( - title="Syntax error blocks execution", - category="correctness", - severity="high", - description=str(parsed.get("syntax_error", "Python failed to parse the snippet.")), - ) - ) - if int(parsed.get("max_loop_depth", 0) or 0) >= 2: - issues.append( - AnalysisIssue( - title="Nested loops increase runtime risk", - category="performance", - severity="medium", - description="The current control flow suggests a brute-force path that may not scale on larger inputs.", - ) - ) - if int(complexity.get("cyclomatic_complexity", 1) or 1) >= 7: - issues.append( - AnalysisIssue( - title="Cyclomatic complexity is elevated", - category="maintainability", - severity="medium", - description="Branch-heavy code is harder to review, test, and optimize confidently.", - ) - ) - if parsed.get("docstring_ratio", 0.0) == 0.0 and parsed.get("function_names"): - issues.append( - AnalysisIssue( - title="Missing public-function documentation", - category="style", - severity="low", - description="Short docstrings would make the expected contract and edge cases easier to review.", - ) - ) - return issues + return round(max(0.0, min(1.0, score)), 4) class AnalysisService: @@ -90,7 +37,7 @@ class AnalysisService: self._model: PyTorchCodeAnalyzerModel | None = None self.reward_service = RewardService() self.suggestion_service = SuggestionService() - self._analyzers: dict[str, Callable[[str, dict[str, Any], dict[str, Any]], DomainAnalysis]] = { + self._analyzers: Dict[str, Callable[[str, Dict[str, Any], Dict[str, Any]], DomainAnalysis]] = { "dsa": analyze_dsa_code, "data_science": analyze_data_science_code, "ml_dl": analyze_ml_code, @@ -103,156 +50,90 @@ class AnalysisService: self._model = PyTorchCodeAnalyzerModel() return self._model - def _heuristic_domain_scores(self, parsed: dict[str, Any], code: str) -> dict[str, float]: + def _heuristic_domain_scores(self, parsed: Dict[str, Any], code: str) -> Dict[str, float]: """Derive domain priors from imports and syntax-level hints.""" scores = { - "dsa": 0.22 - + (0.18 if parsed.get("uses_recursion") else 0.0) - + (0.18 if int(parsed.get("max_loop_depth", 0) or 0) >= 1 else 0.0), - "data_science": 0.22 + (0.38 if parsed.get("uses_pandas") or parsed.get("uses_numpy") else 0.0), - "ml_dl": 0.22 + (0.38 if parsed.get("uses_torch") or parsed.get("uses_sklearn") else 0.0), - "web": 0.22 - + (0.38 if parsed.get("uses_fastapi") or parsed.get("uses_flask") else 0.0) - + (0.12 if parsed.get("route_decorators") else 0.0), - "general": 0.26, + "dsa": 0.2 + (0.15 if parsed.get("uses_recursion") else 0.0) + (0.15 if parsed.get("max_loop_depth", 0) >= 1 else 0.0), + "data_science": 0.2 + (0.35 if parsed.get("uses_pandas") or parsed.get("uses_numpy") else 0.0), + "ml_dl": 0.2 + (0.35 if parsed.get("uses_torch") or parsed.get("uses_sklearn") else 0.0), + "web": 0.2 + (0.35 if parsed.get("uses_fastapi") or parsed.get("uses_flask") else 0.0) + (0.1 if parsed.get("route_decorators") else 0.0), + "general": 0.2, } - lowered = code.lower() - if "fastapi" in lowered: - scores["web"] += 0.12 - if "pandas" in lowered or "numpy" in lowered: + if "fastapi" in code.lower(): + scores["web"] += 0.1 + if "pandas" in code.lower() or "numpy" in code.lower(): scores["data_science"] += 0.1 - if "torch" in lowered or "sklearn" in lowered: + if "torch" in code.lower(): scores["ml_dl"] += 0.1 if "while" in code or "for" in code: - scores["dsa"] += 0.06 + scores["dsa"] += 0.05 return {key: round(min(value, 0.99), 4) for key, value in scores.items()} - def _general_domain_analysis(self, parsed: dict[str, Any], complexity: dict[str, Any]) -> DomainAnalysis: - """Fallback analysis when no specialized domain is strongly selected.""" - - suggestions = [ - "Keep functions small, validate inputs explicitly, and add focused tests for edge cases.", - ] - if int(parsed.get("max_loop_depth", 0) or 0) >= 2: - suggestions.append("Consider replacing repeated scans with a precomputed dictionary or set.") - return DomainAnalysis( - domain="general", - domain_score=round(_clamp_unit(0.62 - (0.12 * float(complexity["complexity_penalty"]))), 4), - issues=_static_issues(parsed, complexity)[:2], - suggestions=suggestions, - highlights={ - "cyclomatic_complexity": float(complexity["cyclomatic_complexity"]), - "max_loop_depth": float(parsed.get("max_loop_depth", 0) or 0), - "lint_score": float(_lint_score(parsed)), - }, - ) - def analyze(self, request: AnalyzeCodeRequest) -> AnalyzeCodeResponse: - """Run the complete static-plus-ML code review pipeline.""" + """Run the complete multi-domain analysis pipeline.""" started = time.perf_counter() parsed = parse_code_structure(request.code) complexity = estimate_complexity(parsed, request.code) - lint_score = _lint_score(parsed) - model_prediction = self.model.predict( - request.code, - request.context_window, - request.traceback_text, - parsed, - ) + model_prediction = self.model.predict(request.code, request.context_window, parsed) heuristic_scores = self._heuristic_domain_scores(parsed, request.code) - combined_scores: dict[str, float] = {} + combined_scores = {} for domain, heuristic_score in heuristic_scores.items(): model_score = float(model_prediction["domain_scores"].get(domain, 0.2)) - combined_scores[domain] = round((0.65 * model_score) + (0.35 * heuristic_score), 4) + combined_scores[domain] = round((0.6 * model_score) + (0.4 * heuristic_score), 4) detected_domain = request.domain_hint if request.domain_hint != "auto" else max(combined_scores, key=combined_scores.get) analyzer = self._analyzers.get(detected_domain) domain_analysis = ( analyzer(request.code, parsed, complexity) if analyzer is not None - else self._general_domain_analysis(parsed, complexity) + else DomainAnalysis( + domain="general", + domain_score=0.6, + issues=[], + suggestions=["Add stronger domain-specific context for deeper analysis."], + highlights={}, + ) + ) + + lint_score = _lint_score(parsed) + score_breakdown = self.reward_service.compute( + ml_score=float(model_prediction["ml_quality_score"]), + domain_score=domain_analysis.domain_score, + lint_score=lint_score, + complexity_penalty=float(complexity["complexity_penalty"]), ) - static_issues = _static_issues(parsed, complexity) static_analysis = StaticAnalysisSummary( syntax_valid=bool(parsed["syntax_valid"]), syntax_error=str(parsed["syntax_error"]), cyclomatic_complexity=int(complexity["cyclomatic_complexity"]), line_count=int(parsed["line_count"]), - max_nesting_depth=int(parsed["max_nesting_depth"]), max_loop_depth=int(parsed["max_loop_depth"]), time_complexity=str(complexity["time_complexity"]), space_complexity=str(complexity["space_complexity"]), - lint_score=lint_score, - docstring_coverage=float(parsed["docstring_ratio"]), detected_imports=list(parsed["imports"]), code_smells=list(parsed["code_smells"]), - issues=static_issues, - ) - - score_breakdown = self.reward_service.compute( - ml_score=float(model_prediction["ml_quality_score"]), - domain_score=domain_analysis.domain_score, - lint_score=lint_score, - complexity_penalty=float(complexity["complexity_penalty"]), - maintainability_score=float(model_prediction["maintainability_score"]), - issue_probabilities=dict(model_prediction["issue_probabilities"]), - ) - suggestions = self.suggestion_service.build_suggestions( - domain_analysis=domain_analysis, - static_analysis=static_analysis, ) improvement_plan = self.suggestion_service.build_improvement_plan( domain_analysis=domain_analysis, static_analysis=static_analysis, ) - auto_fix_preview = self.suggestion_service.build_auto_fix_preview( - domain_analysis=domain_analysis, - static_analysis=static_analysis, - ) - summary = ( - f"Reviewed Python code as `{detected_domain}` with an ML quality score of {score_breakdown.ml_score:.0%}, " - f"lint score {score_breakdown.lint_score:.0%}, and RL-ready reward {score_breakdown.reward:.0%}." + f"Detected `{detected_domain}` code with a model score of {score_breakdown.ml_score:.0%}, " + f"domain score {score_breakdown.domain_score:.0%}, and final reward {score_breakdown.reward:.0%}." ) - model_notes = list(model_prediction["notes"]) - if static_issues: - model_notes.append(f"Static analyzer found {len(static_issues)} review issue(s).") - return AnalyzeCodeResponse( detected_domain=detected_domain, # type: ignore[arg-type] domain_confidences=combined_scores, score_breakdown=score_breakdown, static_analysis=static_analysis, - model_prediction=ModelPrediction( - quality_label=str(model_prediction["quality_label"]), # type: ignore[arg-type] - quality_score=float(model_prediction["quality_score"]), - maintainability_score=float(model_prediction["maintainability_score"]), - issue_probabilities=dict(model_prediction["issue_probabilities"]), - notes=model_notes, - ), domain_analysis=domain_analysis, - suggestions=suggestions if request.enable_suggestions else [], - improvement_plan=improvement_plan if request.enable_suggestions else [], - auto_fix_preview=auto_fix_preview if request.enable_suggestions else [], - score_visualization={ - "reward": score_breakdown.reward, - "ml_quality": score_breakdown.ml_score, - "lint_score": score_breakdown.lint_score, - "maintainability": score_breakdown.maintainability_score, - "security": score_breakdown.security_score, - "readability": score_breakdown.readability_score, - "quality_signal": score_breakdown.quality_signal, - "error_reduction_signal": score_breakdown.error_reduction_signal, - "completion_signal": score_breakdown.completion_signal, - "complexity_penalty": score_breakdown.complexity_penalty, - }, + improvement_plan=improvement_plan, model_backend=str(model_prediction["backend_name"]), model_id=str(model_prediction["model_id"]), summary=summary, context_window=request.context_window, - filename=request.filename, analysis_time_ms=round((time.perf_counter() - started) * 1000.0, 2), ) diff --git a/services/reward_service.py b/services/reward_service.py index f51954281626205d7f098c0b29cbe7747fb06cb0..b2f5607ee66241da0012f5c93d93157b60d1d85e 100644 --- a/services/reward_service.py +++ b/services/reward_service.py @@ -5,50 +5,32 @@ from __future__ import annotations from schemas.response import ScoreBreakdown -def _clamp_unit(value: float) -> float: - return max(0.0, min(1.0, float(value))) - - class RewardService: - """Compute reward scores from model, lint, complexity, and issue-risk signals.""" - - def compute( - self, - *, - ml_score: float, - domain_score: float, - lint_score: float, - complexity_penalty: float, - maintainability_score: float, - issue_probabilities: dict[str, float], - ) -> ScoreBreakdown: - """Apply RL-friendly reward shaping to the code review analysis signals.""" - - security_score = _clamp_unit(1.0 - issue_probabilities.get("security", 0.0)) - readability_score = _clamp_unit((0.6 * lint_score) + (0.4 * maintainability_score)) - quality_signal = _clamp_unit((0.55 * ml_score) + (0.25 * maintainability_score) + (0.20 * domain_score)) - error_reduction_signal = _clamp_unit((0.7 * lint_score) + (0.3 * (1.0 - complexity_penalty))) - completion_signal = _clamp_unit( - (0.4 * quality_signal) + (0.25 * readability_score) + (0.2 * security_score) + (0.15 * domain_score) + """Compute reward scores from model, domain, lint, and complexity signals.""" + + def compute(self, *, ml_score: float, domain_score: float, lint_score: float, complexity_penalty: float) -> ScoreBreakdown: + """Apply dynamic reward shaping based on quality, errors, and completion.""" + + quality_signal = max(0.0, min(1.0, (0.45 * ml_score) + (0.3 * domain_score) + (0.25 * lint_score))) + error_reduction_signal = max(0.0, min(1.0, lint_score - (0.6 * complexity_penalty))) + completion_signal = max(0.0, min(1.0, (ml_score + domain_score + lint_score) / 3.0)) + reward = max( + 0.0, + min( + 1.0, + (0.35 * quality_signal) + + (0.25 * completion_signal) + + (0.2 * error_reduction_signal) + + (0.1 * ml_score) + + (0.1 * domain_score) + - (0.15 * complexity_penalty), + ), ) - - reward = _clamp_unit( - (0.5 * ml_score) - + (0.18 * lint_score) - + (0.12 * maintainability_score) - + (0.10 * domain_score) - + (0.10 * security_score) - - (0.20 * complexity_penalty) - ) - return ScoreBreakdown( ml_score=round(ml_score, 4), domain_score=round(domain_score, 4), lint_score=round(lint_score, 4), complexity_penalty=round(complexity_penalty, 4), - maintainability_score=round(maintainability_score, 4), - security_score=round(security_score, 4), - readability_score=round(readability_score, 4), quality_signal=round(quality_signal, 4), error_reduction_signal=round(error_reduction_signal, 4), completion_signal=round(completion_signal, 4), diff --git a/services/suggestion_service.py b/services/suggestion_service.py index a57831a02a7d3b0e6172e7c38738b142fe4b89f7..2c5683754cc2ed4a9bb9f630d30e5ffea24a9b72 100644 --- a/services/suggestion_service.py +++ b/services/suggestion_service.py @@ -2,77 +2,11 @@ from __future__ import annotations -from schemas.response import DomainAnalysis, StaticAnalysisSummary, SuggestionItem +from schemas.response import DomainAnalysis, StaticAnalysisSummary class SuggestionService: - """Build high-signal improvement suggestions from analysis output.""" - - def build_suggestions( - self, - *, - domain_analysis: DomainAnalysis, - static_analysis: StaticAnalysisSummary, - ) -> list[SuggestionItem]: - """Return prioritized fixes tailored to the detected review signals.""" - - suggestions: list[SuggestionItem] = [] - - if not static_analysis.syntax_valid: - suggestions.append( - SuggestionItem( - priority="P0", - title="Fix the syntax error", - rationale="Static parsing failed, so downstream tests and model signals are less reliable.", - action=f"Resolve the parser issue first: {static_analysis.syntax_error}.", - category="correctness", - ) - ) - - if static_analysis.cyclomatic_complexity >= 6 or static_analysis.max_loop_depth >= 2: - suggestions.append( - SuggestionItem( - priority="P1", - title="Reduce branching or nested loops", - rationale="Higher structural complexity makes bugs more likely and lowers the RL reward.", - action="Extract helper functions or replace repeated scans with a dictionary, set, Counter, or vectorized operation.", - category="performance", - ) - ) - - if static_analysis.docstring_coverage == 0 and static_analysis.line_count > 0: - suggestions.append( - SuggestionItem( - priority="P2", - title="Add function-level documentation", - rationale="Docstrings improve review speed and make behavior clearer for future edits.", - action="Document the expected inputs, outputs, and edge cases in a short function docstring.", - category="style", - ) - ) - - for issue in domain_analysis.issues[:2]: - suggestions.append( - SuggestionItem( - priority="P1" if issue.severity != "high" else "P0", - title=issue.title, - rationale=issue.description, - action=domain_analysis.suggestions[0] if domain_analysis.suggestions else "Refactor the risky section and re-run analysis.", - category=issue.category, - ) - ) - - if not suggestions: - suggestions.append( - SuggestionItem( - priority="P2", - title="Strengthen review confidence", - rationale="No severe issues were detected, but explicit edge-case coverage still improves maintainability.", - action="Add targeted tests for empty input, boundary values, and malformed payloads.", - category="maintainability", - ) - ) - return suggestions[:4] + """Build high-signal improvement steps from analysis output.""" def build_improvement_plan(self, *, domain_analysis: DomainAnalysis, static_analysis: StaticAnalysisSummary) -> list[str]: """Return a compact three-step plan optimized for developer action.""" @@ -92,22 +26,3 @@ class SuggestionService: if not static_analysis.syntax_valid: step_one = f"Step 1 - Correctness and safety: fix the syntax error first ({static_analysis.syntax_error})." return [step_one, step_two, step_three] - - def build_auto_fix_preview( - self, - *, - domain_analysis: DomainAnalysis, - static_analysis: StaticAnalysisSummary, - ) -> list[str]: - """Generate compact auto-fix hints for the UI preview panel.""" - - preview: list[str] = [] - if not static_analysis.syntax_valid: - preview.append(f"Repair parser failure: {static_analysis.syntax_error}") - if static_analysis.max_loop_depth >= 2: - preview.append("Replace nested scans with a precomputed lookup table or aggregation structure.") - if static_analysis.docstring_coverage == 0: - preview.append("Add a short docstring describing the function contract and edge cases.") - if domain_analysis.suggestions: - preview.append(domain_analysis.suggestions[0]) - return preview[:3] diff --git a/tests/test_inference_runner.py b/tests/test_inference_runner.py index f415966f599f16c6ecf7d183804f71e2b1078503..980037ab2567a0072a8663d2cef17ee361ca5f85 100644 --- a/tests/test_inference_runner.py +++ b/tests/test_inference_runner.py @@ -1,108 +1,107 @@ -"""Smoke tests for the strict inference output contract.""" - +"""Smoke tests for the strict inference output contract.""" + from __future__ import annotations from dataclasses import dataclass, field from app.env.runner import InferenceRunner from app.models.inference import AgentDecision, InferenceConfig -from app.utils.runtime import format_reward - - -@dataclass -class _FakeObservation: - task_id: str - attempts_remaining: int - score: float - done: bool - history: list[object] = field(default_factory=list) - current_code: str = "print('broken')" - last_action_error: str | None = None - - -class _FakeEnv: - def __init__(self) -> None: - self._step = 0 - - def reset(self, *, task_id: str) -> _FakeObservation: - return _FakeObservation(task_id=task_id, attempts_remaining=4, score=0.2, done=False) - - def step_result(self, action: object) -> tuple[_FakeObservation, float, bool, dict[str, object]]: - self._step += 1 - if self._step == 1: - return ( - _FakeObservation("demo_task", 3, 0.45, False, current_code="candidate"), - 0.45, - False, - {"last_action_error": None}, - ) - if self._step == 2: - return ( - _FakeObservation("demo_task", 2, 0.97, True, current_code="reference"), - 0.97, - True, - {"last_action_error": None}, - ) - raise AssertionError("runner stepped too many times") - - -class _FakeAgent: - def __init__(self) -> None: - self._step = 0 - - def act(self, observation: object) -> AgentDecision: - self._step += 1 - if self._step == 1: - return AgentDecision(action_type="run_tests") - return AgentDecision(action_type="submit_solution") - - -class _LowScoreEnv(_FakeEnv): - def step_result(self, action: object) -> tuple[_FakeObservation, float, bool, dict[str, object]]: - self._step += 1 - return ( - _FakeObservation("demo_task", 2, 0.60, True, current_code="candidate"), - 0.60, - True, - {"last_action_error": None}, - ) - - -def test_inference_runner_emits_strict_lines(capsys) -> None: - runner = InferenceRunner(InferenceConfig.from_env()) - runner.agent = _FakeAgent() - runner._create_env = lambda: _FakeEnv() # type: ignore[method-assign] - runner.run_task("demo_task") - - captured = capsys.readouterr().out.strip().splitlines() - assert captured == [ - f"[START] task=demo_task env={runner.config.benchmark_name} model={runner.config.model_name}", - "[STEP] step=1 action=run_tests reward=0.45 done=false error=null", - "[STEP] step=2 action=submit_solution reward=0.97 done=true error=null", - "[END] success=true steps=2 rewards=0.45,0.97", - ] - - -def test_inference_runner_marks_low_score_submission_unsuccessful(capsys) -> None: - runner = InferenceRunner(InferenceConfig.from_env()) - runner.agent = _FakeAgent() - runner._create_env = lambda: _LowScoreEnv() # type: ignore[method-assign] - runner.run_task("demo_task") - - captured = capsys.readouterr().out.strip().splitlines() - assert captured[-1] == "[END] success=false steps=1 rewards=0.60" - - -def test_inference_config_prefers_openai_key_for_openai_base_url(monkeypatch) -> None: - monkeypatch.setenv("API_BASE_URL", "https://api.openai.com/v1") - monkeypatch.setenv("OPENAI_API_KEY", "openai-key") - monkeypatch.setenv("HF_TOKEN", "hf-key") - - config = InferenceConfig.from_env() - - assert config.api_key == "openai-key" - - + + +@dataclass +class _FakeObservation: + task_id: str + attempts_remaining: int + score: float + done: bool + history: list[object] = field(default_factory=list) + current_code: str = "print('broken')" + last_action_error: str | None = None + + +class _FakeEnv: + def __init__(self) -> None: + self._step = 0 + + def reset(self, *, task_id: str) -> _FakeObservation: + return _FakeObservation(task_id=task_id, attempts_remaining=4, score=0.2, done=False) + + def step_result(self, action: object) -> tuple[_FakeObservation, float, bool, dict[str, object]]: + self._step += 1 + if self._step == 1: + return ( + _FakeObservation("demo_task", 3, 0.45, False, current_code="candidate"), + 0.45, + False, + {"last_action_error": None}, + ) + if self._step == 2: + return ( + _FakeObservation("demo_task", 2, 0.97, True, current_code="reference"), + 0.97, + True, + {"last_action_error": None}, + ) + raise AssertionError("runner stepped too many times") + + +class _FakeAgent: + def __init__(self) -> None: + self._step = 0 + + def act(self, observation: object) -> AgentDecision: + self._step += 1 + if self._step == 1: + return AgentDecision(action_type="run_tests") + return AgentDecision(action_type="submit_solution") + + +class _LowScoreEnv(_FakeEnv): + def step_result(self, action: object) -> tuple[_FakeObservation, float, bool, dict[str, object]]: + self._step += 1 + return ( + _FakeObservation("demo_task", 2, 0.60, True, current_code="candidate"), + 0.60, + True, + {"last_action_error": None}, + ) + + +def test_inference_runner_emits_strict_lines(capsys) -> None: + runner = InferenceRunner(InferenceConfig.from_env()) + runner.agent = _FakeAgent() + runner._create_env = lambda: _FakeEnv() # type: ignore[method-assign] + runner.run_task("demo_task") + + captured = capsys.readouterr().out.strip().splitlines() + assert captured == [ + f"[START] task=demo_task env={runner.config.benchmark_name} model={runner.config.model_name}", + "[STEP] step=1 action=run_tests reward=0.45 done=false error=null", + "[STEP] step=2 action=submit_solution reward=0.97 done=true error=null", + "[END] success=true steps=2 rewards=0.45,0.97", + ] + + +def test_inference_runner_marks_low_score_submission_unsuccessful(capsys) -> None: + runner = InferenceRunner(InferenceConfig.from_env()) + runner.agent = _FakeAgent() + runner._create_env = lambda: _LowScoreEnv() # type: ignore[method-assign] + runner.run_task("demo_task") + + captured = capsys.readouterr().out.strip().splitlines() + assert captured[-1] == "[END] success=false steps=1 rewards=0.60" + + +def test_inference_config_prefers_openai_key_for_openai_base_url(monkeypatch) -> None: + monkeypatch.setenv("API_BASE_URL", "https://api.openai.com/v1") + monkeypatch.setenv("OPENAI_API_KEY", "openai-key") + monkeypatch.setenv("HF_TOKEN", "hf-key") + + config = InferenceConfig.from_env() + + assert config.api_key == "openai-key" + + def test_inference_config_prefers_hf_key_for_hf_router(monkeypatch) -> None: monkeypatch.setenv("API_BASE_URL", "https://router.huggingface.co/v1") monkeypatch.setenv("OPENAI_API_KEY", "openai-key") @@ -111,8 +110,3 @@ def test_inference_config_prefers_hf_key_for_hf_router(monkeypatch) -> None: config = InferenceConfig.from_env() assert config.api_key == "hf-key" - - -def test_reward_formatting_stays_in_strict_two_decimal_interval() -> None: - assert format_reward(0.999999) == "0.99" - assert format_reward(0.000001) == "0.01" diff --git a/tests/test_multi_domain_platform.py b/tests/test_multi_domain_platform.py index a749a493c3b1e7b23c751b570e6269e4d5f4f456..c74d07976a8526e6ba3c4e66b74caf40f12bb724 100644 --- a/tests/test_multi_domain_platform.py +++ b/tests/test_multi_domain_platform.py @@ -1,52 +1,52 @@ -from __future__ import annotations - -from fastapi.testclient import TestClient - -from api.main import app -from schemas.request import AnalyzeCodeRequest -from services.analysis_service import AnalysisService - - -def test_analysis_service_detects_web_code() -> None: - service = AnalysisService() - request = AnalyzeCodeRequest( - code="from fastapi import FastAPI\napp = FastAPI()\n\n@app.get('/health')\ndef health():\n return {'status': 'ok'}\n", - domain_hint="auto", - ) - - result = service.analyze(request) - - assert result.detected_domain == "web" - assert 0.0 <= result.score_breakdown.reward <= 1.0 - assert len(result.improvement_plan) == 3 - - -def test_analysis_service_detects_dsa_code() -> None: - service = AnalysisService() - request = AnalyzeCodeRequest( - code="def has_pair(nums, target):\n for i in range(len(nums)):\n for j in range(i + 1, len(nums)):\n if nums[i] + nums[j] == target:\n return True\n return False\n", - domain_hint="auto", - ) - - result = service.analyze(request) - - assert result.detected_domain == "dsa" - assert result.static_analysis.time_complexity in {"O(n^2)", "O(n^3)"} - - -def test_api_analyze_endpoint_returns_valid_payload() -> None: - client = TestClient(app) - response = client.post( - "/analyze", - json={ - "code": "import torch\n\ndef predict(model, x):\n return model(x)\n", - "context_window": "Inference helper for a classifier", - "traceback_text": "", - "domain_hint": "auto", - }, - ) - - assert response.status_code == 200 - payload = response.json() - assert "detected_domain" in payload - assert "score_breakdown" in payload +from __future__ import annotations + +from fastapi.testclient import TestClient + +from api.main import app +from schemas.request import AnalyzeCodeRequest +from services.analysis_service import AnalysisService + + +def test_analysis_service_detects_web_code() -> None: + service = AnalysisService() + request = AnalyzeCodeRequest( + code="from fastapi import FastAPI\napp = FastAPI()\n\n@app.get('/health')\ndef health():\n return {'status': 'ok'}\n", + domain_hint="auto", + ) + + result = service.analyze(request) + + assert result.detected_domain == "web" + assert 0.0 <= result.score_breakdown.reward <= 1.0 + assert len(result.improvement_plan) == 3 + + +def test_analysis_service_detects_dsa_code() -> None: + service = AnalysisService() + request = AnalyzeCodeRequest( + code="def has_pair(nums, target):\n for i in range(len(nums)):\n for j in range(i + 1, len(nums)):\n if nums[i] + nums[j] == target:\n return True\n return False\n", + domain_hint="auto", + ) + + result = service.analyze(request) + + assert result.detected_domain == "dsa" + assert result.static_analysis.time_complexity in {"O(n^2)", "O(n^3)"} + + +def test_api_analyze_endpoint_returns_valid_payload() -> None: + client = TestClient(app) + response = client.post( + "/analyze", + json={ + "code": "import torch\n\ndef predict(model, x):\n return model(x)\n", + "context_window": "Inference helper for a classifier", + "traceback_text": "", + "domain_hint": "auto", + }, + ) + + assert response.status_code == 200 + payload = response.json() + assert "detected_domain" in payload + assert "score_breakdown" in payload diff --git a/tests/test_scoring.py b/tests/test_scoring.py index 7f964e58227eadc11f5e5b30bde1790a742876b2..9955b5e2a7eff3fefa8cf62ec033129e615a8fb9 100644 --- a/tests/test_scoring.py +++ b/tests/test_scoring.py @@ -1,30 +1,30 @@ -from __future__ import annotations - -from graders import grade_task -from graders.shared import component_score, final_score_pipeline, safe_score, shaped_score -from models import PythonCodeReviewAction -from server.env import PythonCodeReviewEnvironment -from tasks import list_tasks - - -def assert_open_unit_interval(value: float) -> None: - assert 0 < value < 1, f"Invalid score: {value}" - - -def test_score_helpers_clamp_extremes_into_open_interval() -> None: - for value in (0.0, 1.0, -999999.0, 999999.0): - assert_open_unit_interval(safe_score(value)) - assert_open_unit_interval(final_score_pipeline(value)) - - for progress in (0.0, 0.5, 1.0): - assert_open_unit_interval(shaped_score(progress)) - assert_open_unit_interval(component_score(progress)) - - -def test_task_grades_stay_strictly_between_zero_and_one() -> None: - for task in list_tasks(): - starter_grade = grade_task(task, task.starter_code, include_hidden=False) - reference_grade = grade_task(task, task.reference_code, include_hidden=True) +from __future__ import annotations + +from graders import grade_task +from graders.shared import component_score, final_score_pipeline, safe_score, shaped_score +from models import PythonCodeReviewAction +from server.env import PythonCodeReviewEnvironment +from tasks import list_tasks + + +def assert_open_unit_interval(value: float) -> None: + assert 0 < value < 1, f"Invalid score: {value}" + + +def test_score_helpers_clamp_extremes_into_open_interval() -> None: + for value in (0.0, 1.0, -999999.0, 999999.0): + assert_open_unit_interval(safe_score(value)) + assert_open_unit_interval(final_score_pipeline(value)) + + for progress in (0.0, 0.5, 1.0): + assert_open_unit_interval(shaped_score(progress)) + assert_open_unit_interval(component_score(progress)) + + +def test_task_grades_stay_strictly_between_zero_and_one() -> None: + for task in list_tasks(): + starter_grade = grade_task(task, task.starter_code, include_hidden=False) + reference_grade = grade_task(task, task.reference_code, include_hidden=True) for grade in (starter_grade, reference_grade): assert_open_unit_interval(grade.score) diff --git a/tests/test_triage_pipeline.py b/tests/test_triage_pipeline.py index 0ce5b5f07e0484d0fc7a249d22f0facc599b09b0..f39e02851541dc1455a09fd3c4987f4b792bf86d 100644 --- a/tests/test_triage_pipeline.py +++ b/tests/test_triage_pipeline.py @@ -1,46 +1,46 @@ -from __future__ import annotations - -from fastapi.testclient import TestClient - -from triage import CodeTriageEngine, HashingEmbeddingBackend -from triage_catalog import build_examples - - -def test_hashing_backend_returns_normalized_embeddings() -> None: - backend = HashingEmbeddingBackend(dimensions=32) - embeddings = backend.embed_texts(["def foo():\n return 1", "for x in items:\n pass"]) - - assert embeddings.shape == (2, 32) - for row in embeddings: - assert round(float(row.norm().item()), 5) == 1.0 - - -def test_examples_map_to_expected_labels_with_fallback_backend() -> None: - examples = build_examples() - engine = CodeTriageEngine(backend=HashingEmbeddingBackend()) - - for example in examples: - result = engine.triage(example.code, example.traceback_text, example.context_window) - assert result.issue_label == example.label - assert 0.0 <= result.reward_score <= 1.0 - - -def test_syntax_example_exposes_parser_signal() -> None: - example = next(item for item in build_examples() if item.label == "syntax") - engine = CodeTriageEngine(backend=HashingEmbeddingBackend()) - - result = engine.triage(example.code, example.traceback_text, example.context_window) - - assert any(signal.name == "syntax_parse" and signal.value == "fails" for signal in result.extracted_signals) - assert result.matched_pattern.task_id == example.task_id - assert result.repair_plan[0].startswith("Step 1 - Syntax checking and bug fixes") - - -def test_composed_app_preserves_health_route() -> None: - from server.app import build_application - - client = TestClient(build_application()) - response = client.get("/health") - - assert response.status_code == 200 - assert response.json()["status"] == "ok" +from __future__ import annotations + +from fastapi.testclient import TestClient + +from triage import CodeTriageEngine, HashingEmbeddingBackend +from triage_catalog import build_examples + + +def test_hashing_backend_returns_normalized_embeddings() -> None: + backend = HashingEmbeddingBackend(dimensions=32) + embeddings = backend.embed_texts(["def foo():\n return 1", "for x in items:\n pass"]) + + assert embeddings.shape == (2, 32) + for row in embeddings: + assert round(float(row.norm().item()), 5) == 1.0 + + +def test_examples_map_to_expected_labels_with_fallback_backend() -> None: + examples = build_examples() + engine = CodeTriageEngine(backend=HashingEmbeddingBackend()) + + for example in examples: + result = engine.triage(example.code, example.traceback_text, example.context_window) + assert result.issue_label == example.label + assert 0.0 <= result.reward_score <= 1.0 + + +def test_syntax_example_exposes_parser_signal() -> None: + example = next(item for item in build_examples() if item.label == "syntax") + engine = CodeTriageEngine(backend=HashingEmbeddingBackend()) + + result = engine.triage(example.code, example.traceback_text, example.context_window) + + assert any(signal.name == "syntax_parse" and signal.value == "fails" for signal in result.extracted_signals) + assert result.matched_pattern.task_id == example.task_id + assert result.repair_plan[0].startswith("Step 1 - Syntax checking and bug fixes") + + +def test_composed_app_preserves_health_route() -> None: + from server.app import build_application + + client = TestClient(build_application()) + response = client.get("/health") + + assert response.status_code == 200 + assert response.json()["status"] == "ok" diff --git a/triage.py b/triage.py index 632647d3ad3428eb80ca3ecc5d3fc54e24d909a4..755f4d82cb3a79b44504a8a36ed1e3307de12d98 100644 --- a/triage.py +++ b/triage.py @@ -1,473 +1,473 @@ -"""PyTorch-backed triage pipeline for TorchReview Copilot.""" - -from __future__ import annotations - -import ast -import hashlib -import os -import re -import time -from functools import lru_cache -from typing import List, Sequence - -import torch -import torch.nn.functional as F - -try: - from transformers import AutoModel, AutoTokenizer -except Exception: - AutoModel = None # type: ignore[assignment] - AutoTokenizer = None # type: ignore[assignment] - -try: - from .triage_catalog import build_examples, build_prototypes - from .triage_models import ( - IssueLabel, - PrototypeMatch, - TriageExample, - TriagePrototype, - TriageResult, - TriageSignal, - ) -except ImportError: - from triage_catalog import build_examples, build_prototypes - from triage_models import ( - IssueLabel, - PrototypeMatch, - TriageExample, - TriagePrototype, - TriageResult, - TriageSignal, - ) - - -MODEL_ID = os.getenv("TRIAGE_MODEL_ID", "huggingface/CodeBERTa-small-v1") -MODEL_MAX_LENGTH = int(os.getenv("TRIAGE_MODEL_MAX_LENGTH", "256")) -LABELS: tuple[IssueLabel, ...] = ("syntax", "logic", "performance") - - -class _LoopDepthVisitor(ast.NodeVisitor): - """Track the maximum loop nesting depth in a code snippet.""" - - def __init__(self) -> None: - self.depth = 0 - self.max_depth = 0 - - def _visit_loop(self, node: ast.AST) -> None: - self.depth += 1 - self.max_depth = max(self.max_depth, self.depth) - self.generic_visit(node) - self.depth -= 1 - - def visit_For(self, node: ast.For) -> None: # noqa: N802 - self._visit_loop(node) - - def visit_While(self, node: ast.While) -> None: # noqa: N802 - self._visit_loop(node) - - def visit_comprehension(self, node: ast.comprehension) -> None: # noqa: N802 - self._visit_loop(node) - - -class HashingEmbeddingBackend: - """Deterministic torch-native fallback when pretrained weights are unavailable.""" - - def __init__(self, dimensions: int = 96) -> None: - self.dimensions = dimensions - self.model_id = "hashed-token-fallback" - self.backend_name = "hashed-token-fallback" - self.notes = ["Using hashed torch embeddings because pretrained weights are unavailable."] - - def embed_texts(self, texts: Sequence[str]) -> torch.Tensor: - rows = torch.zeros((len(texts), self.dimensions), dtype=torch.float32) - for row_index, text in enumerate(texts): - tokens = re.findall(r"[A-Za-z_]+|\d+|==|!=|<=|>=|\S", text.lower())[:512] - if not tokens: - rows[row_index, 0] = 1.0 - continue - for token in tokens: - digest = hashlib.md5(token.encode("utf-8")).hexdigest() - bucket = int(digest[:8], 16) % self.dimensions - sign = -1.0 if int(digest[8:10], 16) % 2 else 1.0 - rows[row_index, bucket] += sign - return F.normalize(rows + 1e-6, dim=1) - - -class TransformersEmbeddingBackend: - """Mean-pool CodeBERTa embeddings via torch + transformers.""" - - def __init__(self, model_id: str = MODEL_ID, force_fallback: bool = False) -> None: - self.model_id = model_id - self.force_fallback = force_fallback - self.backend_name = model_id - self.notes: List[str] = [] - self._fallback = HashingEmbeddingBackend() - self._tokenizer = None - self._model = None - self._load_error = "" - if force_fallback: - self.backend_name = self._fallback.backend_name - self.notes = list(self._fallback.notes) - - def _ensure_loaded(self) -> None: - if self.force_fallback or self._model is not None or self._load_error: - return - if AutoTokenizer is None or AutoModel is None: - self._load_error = "transformers is not installed." - else: - try: - self._tokenizer = AutoTokenizer.from_pretrained(self.model_id) - self._model = AutoModel.from_pretrained(self.model_id) - self._model.eval() - self.notes.append(f"Loaded pretrained encoder `{self.model_id}` for inference.") - except Exception as exc: - self._load_error = f"{type(exc).__name__}: {exc}" - - if self._load_error: - self.backend_name = self._fallback.backend_name - self.notes = list(self._fallback.notes) + [f"Pretrained load failed: {self._load_error}"] - - def embed_texts(self, texts: Sequence[str]) -> torch.Tensor: - self._ensure_loaded() - if self._model is None or self._tokenizer is None: - return self._fallback.embed_texts(texts) - - encoded = self._tokenizer( - list(texts), - padding=True, - truncation=True, - max_length=MODEL_MAX_LENGTH, - return_tensors="pt", - ) - with torch.no_grad(): - outputs = self._model(**encoded) - hidden_state = outputs.last_hidden_state - mask = encoded["attention_mask"].unsqueeze(-1) - pooled = (hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1) - return F.normalize(pooled, dim=1) - - -def _sanitize_text(value: str) -> str: - text = (value or "").strip() - return text[:4000] - - -def _safe_softmax(scores: dict[IssueLabel, float]) -> dict[str, float]: - tensor = torch.tensor([scores[label] for label in LABELS], dtype=torch.float32) - probabilities = torch.softmax(tensor * 4.0, dim=0) - return {label: round(float(probabilities[index]), 4) for index, label in enumerate(LABELS)} - - -def _loop_depth(code: str) -> int: - try: - tree = ast.parse(code) - except SyntaxError: - return 0 - visitor = _LoopDepthVisitor() - visitor.visit(tree) - return visitor.max_depth - - -def _repair_risk(label: IssueLabel, confidence: float, signal_count: int) -> str: - base = {"syntax": 0.25, "logic": 0.55, "performance": 0.7}[label] - if confidence < 0.55: - base += 0.12 - if signal_count >= 4: - base += 0.08 - if base < 0.4: - return "low" - if base < 0.72: - return "medium" - return "high" - - -def _clamp_unit(value: float) -> float: - return round(max(0.0, min(1.0, float(value))), 4) - - -def _lint_score(code: str) -> float: - stripped_lines = [line.rstrip("\n") for line in code.splitlines()] - if not stripped_lines: - return 0.2 - - score = 1.0 - if any(len(line) > 88 for line in stripped_lines): - score -= 0.15 - if any(line.rstrip() != line for line in stripped_lines): - score -= 0.1 - if any("\t" in line for line in stripped_lines): - score -= 0.1 - try: - tree = ast.parse(code) - functions = [node for node in tree.body if isinstance(node, ast.FunctionDef)] - if functions and not ast.get_docstring(functions[0]): - score -= 0.08 - except SyntaxError: - score -= 0.45 - return _clamp_unit(score) - - -def _complexity_penalty(code: str) -> float: - try: - tree = ast.parse(code) - except SyntaxError: - return 0.95 - branch_nodes = sum(isinstance(node, (ast.If, ast.For, ast.While, ast.Try, ast.Match)) for node in ast.walk(tree)) - loop_depth = _loop_depth(code) - penalty = 0.1 + min(branch_nodes, 8) * 0.07 + min(loop_depth, 4) * 0.12 - return _clamp_unit(penalty) - - -class CodeTriageEngine: - """Combine static signals with PyTorch embeddings to classify code issues.""" - - def __init__( - self, - *, - backend: TransformersEmbeddingBackend | HashingEmbeddingBackend | None = None, - prototypes: Sequence[TriagePrototype] | None = None, - examples: Sequence[TriageExample] | None = None, - ) -> None: - self.backend = backend or TransformersEmbeddingBackend() - self.prototypes = list(prototypes or build_prototypes()) - self.examples = list(examples or build_examples()) - self._prototype_matrix: torch.Tensor | None = None - self._reference_code_matrix: torch.Tensor | None = None - - def example_map(self) -> dict[str, TriageExample]: - """Return UI examples keyed by task id.""" - - return {example.key: example for example in self.examples} - - def _build_document(self, code: str, traceback_text: str) -> str: - trace = _sanitize_text(traceback_text) or "No traceback supplied." - snippet = _sanitize_text(code) or "# No code supplied." - return f"Candidate code:\n{snippet}\n\nObserved failure:\n{trace}\n" - - def _build_review_document(self, code: str, traceback_text: str, context_window: str) -> str: - context = _sanitize_text(context_window) or "No additional context window supplied." - return ( - f"{self._build_document(code, traceback_text)}\n" - f"Context window:\n{context}\n" - ) - - def _prototype_embeddings(self) -> torch.Tensor: - if self._prototype_matrix is None: - reference_texts = [prototype.reference_text for prototype in self.prototypes] - self._prototype_matrix = self.backend.embed_texts(reference_texts) - return self._prototype_matrix - - def _reference_code_embeddings(self) -> torch.Tensor: - if self._reference_code_matrix is None: - reference_codes = [prototype.reference_code for prototype in self.prototypes] - self._reference_code_matrix = self.backend.embed_texts(reference_codes) - return self._reference_code_matrix - - def _extract_signals(self, code: str, traceback_text: str) -> tuple[list[TriageSignal], dict[IssueLabel, float], list[str]]: - trace = (traceback_text or "").lower() - heuristic_scores: dict[IssueLabel, float] = {label: 0.15 for label in LABELS} - signals: list[TriageSignal] = [] - notes: list[str] = [] - - try: - ast.parse(code) - signals.append( - TriageSignal( - name="syntax_parse", - value="passes", - impact="syntax", - weight=0.1, - evidence="Python AST parsing succeeded.", - ) - ) - heuristic_scores["logic"] += 0.05 - except SyntaxError as exc: - evidence = f"{exc.msg} at line {exc.lineno}" - signals.append( - TriageSignal( - name="syntax_parse", - value="fails", - impact="syntax", - weight=0.95, - evidence=evidence, - ) - ) - heuristic_scores["syntax"] += 0.85 - notes.append(f"Parser failure detected: {evidence}") - - if any(token in trace for token in ("syntaxerror", "indentationerror", "expected ':'")): - signals.append( - TriageSignal( - name="traceback_keyword", - value="syntaxerror", - impact="syntax", - weight=0.8, - evidence="Traceback contains a parser error.", - ) - ) - heuristic_scores["syntax"] += 0.55 - - if any(token in trace for token in ("assertionerror", "expected:", "actual:", "boundary", "missing", "incorrect")): - signals.append( - TriageSignal( - name="test_failure_signal", - value="assertion-style failure", - impact="logic", - weight=0.7, - evidence="Failure text points to behavioral mismatch instead of parser issues.", - ) - ) - heuristic_scores["logic"] += 0.55 - - if any(token in trace for token in ("timeout", "benchmark", "slow", "latency", "performance", "profiler")): - signals.append( - TriageSignal( - name="performance_trace", - value="latency regression", - impact="performance", - weight=0.85, - evidence="Traceback mentions benchmark or latency pressure.", - ) - ) - heuristic_scores["performance"] += 0.7 - - loop_depth = _loop_depth(code) - if loop_depth >= 2: - signals.append( - TriageSignal( - name="loop_depth", - value=str(loop_depth), - impact="performance", - weight=0.65, - evidence="Nested iteration increases runtime risk on larger fixtures.", - ) - ) - heuristic_scores["performance"] += 0.35 - - if "Counter(" in code or "defaultdict(" in code or "set(" in code: - heuristic_scores["performance"] += 0.05 - - if "return sessions" in code and "sessions.append" not in code: - signals.append( - TriageSignal( - name="state_update_gap", - value="possible missing final append", - impact="logic", - weight=0.45, - evidence="A collection is returned without an obvious final state flush.", - ) - ) - heuristic_scores["logic"] += 0.18 - - return signals, heuristic_scores, notes - - def _nearest_match(self, embedding: torch.Tensor) -> tuple[TriagePrototype, float, dict[str, float]]: - similarities = torch.matmul(embedding, self._prototype_embeddings().T)[0] - indexed_scores = { - self.prototypes[index].task_id: round(float((similarities[index] + 1.0) / 2.0), 4) - for index in range(len(self.prototypes)) - } - best_index = int(torch.argmax(similarities).item()) - best_prototype = self.prototypes[best_index] - best_similarity = float((similarities[best_index] + 1.0) / 2.0) - return best_prototype, best_similarity, indexed_scores - - def _repair_plan(self, label: IssueLabel, matched: TriagePrototype, context_window: str) -> list[str]: - context = _sanitize_text(context_window) - step_one = { - "syntax": "Step 1 - Syntax checking and bug fixes: resolve the parser break before touching behavior, then align the function with the expected contract.", - "logic": "Step 1 - Syntax checking and bug fixes: confirm the code parses cleanly, then patch the failing branch or state update causing the incorrect result.", - "performance": "Step 1 - Syntax checking and bug fixes: keep the implementation correct first, then isolate the slow section without changing external behavior.", - }[label] - step_two = ( - "Step 2 - Edge case handling: verify empty input, boundary values, missing fields, and final-state flush behavior " - f"against the known pattern `{matched.title}`." - ) - step_three = ( - "Step 3 - Scalability of code: remove repeated full scans, prefer linear-time data structures, " - "and benchmark the path on a production-like fixture." - ) - if context: - step_two = f"{step_two} Context window to preserve: {context}" - return [step_one, step_two, step_three] - - def _reference_quality_score(self, code: str, matched: TriagePrototype) -> float: - candidate = self.backend.embed_texts([_sanitize_text(code) or "# empty"]) - match_index = next(index for index, prototype in enumerate(self.prototypes) if prototype.task_id == matched.task_id) - reference = self._reference_code_embeddings()[match_index : match_index + 1] - score = float(torch.matmul(candidate, reference.T)[0][0].item()) - return _clamp_unit((score + 1.0) / 2.0) - - def triage(self, code: str, traceback_text: str = "", context_window: str = "") -> TriageResult: - """Run the full triage pipeline on code plus optional failure context.""" - - started = time.perf_counter() - document = self._build_review_document(code, traceback_text, context_window) - signals, heuristic_scores, notes = self._extract_signals(code, traceback_text) - - candidate_embedding = self.backend.embed_texts([document]) - matched, matched_similarity, prototype_scores = self._nearest_match(candidate_embedding) - - label_similarity = {label: 0.18 for label in LABELS} - for prototype in self.prototypes: - label_similarity[prototype.label] = max( - label_similarity[prototype.label], - prototype_scores[prototype.task_id], - ) - - combined_scores = { - label: 0.72 * label_similarity[label] + 0.28 * heuristic_scores[label] - for label in LABELS - } - confidence_scores = _safe_softmax(combined_scores) - issue_label = max(LABELS, key=lambda label: confidence_scores[label]) - top_confidence = confidence_scores[issue_label] - - top_signal = signals[0].evidence if signals else "Model similarity dominated the decision." - ml_quality_score = self._reference_quality_score(code, matched) - lint_score = _lint_score(code) - complexity_penalty = _complexity_penalty(code) - reward_score = _clamp_unit((0.5 * ml_quality_score) + (0.3 * lint_score) - (0.2 * complexity_penalty)) - summary = ( - f"Detected a {issue_label} issue with {top_confidence:.0%} confidence. " - f"The closest known failure pattern is `{matched.title}`, which indicates {matched.summary.lower()}. " - f"Predicted quality score is {ml_quality_score:.0%} with an RL-ready reward of {reward_score:.0%}." - ) - suggested_next_action = { - "syntax": "Fix the parser error first, then rerun validation before changing behavior.", - "logic": "Step through the smallest failing case and confirm the final branch/update behavior.", - "performance": "Replace repeated full-list scans with a linear-time aggregation strategy, then benchmark it.", - }[issue_label] - - return TriageResult( - issue_label=issue_label, - confidence_scores=confidence_scores, - repair_risk=_repair_risk(issue_label, top_confidence, len(signals)), - ml_quality_score=ml_quality_score, - lint_score=lint_score, - complexity_penalty=complexity_penalty, - reward_score=reward_score, - summary=summary, - matched_pattern=PrototypeMatch( - task_id=matched.task_id, - title=matched.title, - label=matched.label, - similarity=round(matched_similarity, 4), - summary=matched.summary, - rationale=top_signal, - ), - repair_plan=self._repair_plan(issue_label, matched, context_window), - suggested_next_action=suggested_next_action, - extracted_signals=signals, - model_backend=self.backend.backend_name, - model_id=self.backend.model_id, - inference_notes=list(self.backend.notes) + notes, - analysis_time_ms=round((time.perf_counter() - started) * 1000.0, 2), - ) - - -@lru_cache(maxsize=1) -def get_default_engine() -> CodeTriageEngine: - """Return a cached triage engine for the running process.""" - - return CodeTriageEngine() +"""PyTorch-backed triage pipeline for TorchReview Copilot.""" + +from __future__ import annotations + +import ast +import hashlib +import os +import re +import time +from functools import lru_cache +from typing import List, Sequence + +import torch +import torch.nn.functional as F + +try: + from transformers import AutoModel, AutoTokenizer +except Exception: + AutoModel = None # type: ignore[assignment] + AutoTokenizer = None # type: ignore[assignment] + +try: + from .triage_catalog import build_examples, build_prototypes + from .triage_models import ( + IssueLabel, + PrototypeMatch, + TriageExample, + TriagePrototype, + TriageResult, + TriageSignal, + ) +except ImportError: + from triage_catalog import build_examples, build_prototypes + from triage_models import ( + IssueLabel, + PrototypeMatch, + TriageExample, + TriagePrototype, + TriageResult, + TriageSignal, + ) + + +MODEL_ID = os.getenv("TRIAGE_MODEL_ID", "huggingface/CodeBERTa-small-v1") +MODEL_MAX_LENGTH = int(os.getenv("TRIAGE_MODEL_MAX_LENGTH", "256")) +LABELS: tuple[IssueLabel, ...] = ("syntax", "logic", "performance") + + +class _LoopDepthVisitor(ast.NodeVisitor): + """Track the maximum loop nesting depth in a code snippet.""" + + def __init__(self) -> None: + self.depth = 0 + self.max_depth = 0 + + def _visit_loop(self, node: ast.AST) -> None: + self.depth += 1 + self.max_depth = max(self.max_depth, self.depth) + self.generic_visit(node) + self.depth -= 1 + + def visit_For(self, node: ast.For) -> None: # noqa: N802 + self._visit_loop(node) + + def visit_While(self, node: ast.While) -> None: # noqa: N802 + self._visit_loop(node) + + def visit_comprehension(self, node: ast.comprehension) -> None: # noqa: N802 + self._visit_loop(node) + + +class HashingEmbeddingBackend: + """Deterministic torch-native fallback when pretrained weights are unavailable.""" + + def __init__(self, dimensions: int = 96) -> None: + self.dimensions = dimensions + self.model_id = "hashed-token-fallback" + self.backend_name = "hashed-token-fallback" + self.notes = ["Using hashed torch embeddings because pretrained weights are unavailable."] + + def embed_texts(self, texts: Sequence[str]) -> torch.Tensor: + rows = torch.zeros((len(texts), self.dimensions), dtype=torch.float32) + for row_index, text in enumerate(texts): + tokens = re.findall(r"[A-Za-z_]+|\d+|==|!=|<=|>=|\S", text.lower())[:512] + if not tokens: + rows[row_index, 0] = 1.0 + continue + for token in tokens: + digest = hashlib.md5(token.encode("utf-8")).hexdigest() + bucket = int(digest[:8], 16) % self.dimensions + sign = -1.0 if int(digest[8:10], 16) % 2 else 1.0 + rows[row_index, bucket] += sign + return F.normalize(rows + 1e-6, dim=1) + + +class TransformersEmbeddingBackend: + """Mean-pool CodeBERTa embeddings via torch + transformers.""" + + def __init__(self, model_id: str = MODEL_ID, force_fallback: bool = False) -> None: + self.model_id = model_id + self.force_fallback = force_fallback + self.backend_name = model_id + self.notes: List[str] = [] + self._fallback = HashingEmbeddingBackend() + self._tokenizer = None + self._model = None + self._load_error = "" + if force_fallback: + self.backend_name = self._fallback.backend_name + self.notes = list(self._fallback.notes) + + def _ensure_loaded(self) -> None: + if self.force_fallback or self._model is not None or self._load_error: + return + if AutoTokenizer is None or AutoModel is None: + self._load_error = "transformers is not installed." + else: + try: + self._tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self._model = AutoModel.from_pretrained(self.model_id) + self._model.eval() + self.notes.append(f"Loaded pretrained encoder `{self.model_id}` for inference.") + except Exception as exc: + self._load_error = f"{type(exc).__name__}: {exc}" + + if self._load_error: + self.backend_name = self._fallback.backend_name + self.notes = list(self._fallback.notes) + [f"Pretrained load failed: {self._load_error}"] + + def embed_texts(self, texts: Sequence[str]) -> torch.Tensor: + self._ensure_loaded() + if self._model is None or self._tokenizer is None: + return self._fallback.embed_texts(texts) + + encoded = self._tokenizer( + list(texts), + padding=True, + truncation=True, + max_length=MODEL_MAX_LENGTH, + return_tensors="pt", + ) + with torch.no_grad(): + outputs = self._model(**encoded) + hidden_state = outputs.last_hidden_state + mask = encoded["attention_mask"].unsqueeze(-1) + pooled = (hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1) + return F.normalize(pooled, dim=1) + + +def _sanitize_text(value: str) -> str: + text = (value or "").strip() + return text[:4000] + + +def _safe_softmax(scores: dict[IssueLabel, float]) -> dict[str, float]: + tensor = torch.tensor([scores[label] for label in LABELS], dtype=torch.float32) + probabilities = torch.softmax(tensor * 4.0, dim=0) + return {label: round(float(probabilities[index]), 4) for index, label in enumerate(LABELS)} + + +def _loop_depth(code: str) -> int: + try: + tree = ast.parse(code) + except SyntaxError: + return 0 + visitor = _LoopDepthVisitor() + visitor.visit(tree) + return visitor.max_depth + + +def _repair_risk(label: IssueLabel, confidence: float, signal_count: int) -> str: + base = {"syntax": 0.25, "logic": 0.55, "performance": 0.7}[label] + if confidence < 0.55: + base += 0.12 + if signal_count >= 4: + base += 0.08 + if base < 0.4: + return "low" + if base < 0.72: + return "medium" + return "high" + + +def _clamp_unit(value: float) -> float: + return round(max(0.0, min(1.0, float(value))), 4) + + +def _lint_score(code: str) -> float: + stripped_lines = [line.rstrip("\n") for line in code.splitlines()] + if not stripped_lines: + return 0.2 + + score = 1.0 + if any(len(line) > 88 for line in stripped_lines): + score -= 0.15 + if any(line.rstrip() != line for line in stripped_lines): + score -= 0.1 + if any("\t" in line for line in stripped_lines): + score -= 0.1 + try: + tree = ast.parse(code) + functions = [node for node in tree.body if isinstance(node, ast.FunctionDef)] + if functions and not ast.get_docstring(functions[0]): + score -= 0.08 + except SyntaxError: + score -= 0.45 + return _clamp_unit(score) + + +def _complexity_penalty(code: str) -> float: + try: + tree = ast.parse(code) + except SyntaxError: + return 0.95 + branch_nodes = sum(isinstance(node, (ast.If, ast.For, ast.While, ast.Try, ast.Match)) for node in ast.walk(tree)) + loop_depth = _loop_depth(code) + penalty = 0.1 + min(branch_nodes, 8) * 0.07 + min(loop_depth, 4) * 0.12 + return _clamp_unit(penalty) + + +class CodeTriageEngine: + """Combine static signals with PyTorch embeddings to classify code issues.""" + + def __init__( + self, + *, + backend: TransformersEmbeddingBackend | HashingEmbeddingBackend | None = None, + prototypes: Sequence[TriagePrototype] | None = None, + examples: Sequence[TriageExample] | None = None, + ) -> None: + self.backend = backend or TransformersEmbeddingBackend() + self.prototypes = list(prototypes or build_prototypes()) + self.examples = list(examples or build_examples()) + self._prototype_matrix: torch.Tensor | None = None + self._reference_code_matrix: torch.Tensor | None = None + + def example_map(self) -> dict[str, TriageExample]: + """Return UI examples keyed by task id.""" + + return {example.key: example for example in self.examples} + + def _build_document(self, code: str, traceback_text: str) -> str: + trace = _sanitize_text(traceback_text) or "No traceback supplied." + snippet = _sanitize_text(code) or "# No code supplied." + return f"Candidate code:\n{snippet}\n\nObserved failure:\n{trace}\n" + + def _build_review_document(self, code: str, traceback_text: str, context_window: str) -> str: + context = _sanitize_text(context_window) or "No additional context window supplied." + return ( + f"{self._build_document(code, traceback_text)}\n" + f"Context window:\n{context}\n" + ) + + def _prototype_embeddings(self) -> torch.Tensor: + if self._prototype_matrix is None: + reference_texts = [prototype.reference_text for prototype in self.prototypes] + self._prototype_matrix = self.backend.embed_texts(reference_texts) + return self._prototype_matrix + + def _reference_code_embeddings(self) -> torch.Tensor: + if self._reference_code_matrix is None: + reference_codes = [prototype.reference_code for prototype in self.prototypes] + self._reference_code_matrix = self.backend.embed_texts(reference_codes) + return self._reference_code_matrix + + def _extract_signals(self, code: str, traceback_text: str) -> tuple[list[TriageSignal], dict[IssueLabel, float], list[str]]: + trace = (traceback_text or "").lower() + heuristic_scores: dict[IssueLabel, float] = {label: 0.15 for label in LABELS} + signals: list[TriageSignal] = [] + notes: list[str] = [] + + try: + ast.parse(code) + signals.append( + TriageSignal( + name="syntax_parse", + value="passes", + impact="syntax", + weight=0.1, + evidence="Python AST parsing succeeded.", + ) + ) + heuristic_scores["logic"] += 0.05 + except SyntaxError as exc: + evidence = f"{exc.msg} at line {exc.lineno}" + signals.append( + TriageSignal( + name="syntax_parse", + value="fails", + impact="syntax", + weight=0.95, + evidence=evidence, + ) + ) + heuristic_scores["syntax"] += 0.85 + notes.append(f"Parser failure detected: {evidence}") + + if any(token in trace for token in ("syntaxerror", "indentationerror", "expected ':'")): + signals.append( + TriageSignal( + name="traceback_keyword", + value="syntaxerror", + impact="syntax", + weight=0.8, + evidence="Traceback contains a parser error.", + ) + ) + heuristic_scores["syntax"] += 0.55 + + if any(token in trace for token in ("assertionerror", "expected:", "actual:", "boundary", "missing", "incorrect")): + signals.append( + TriageSignal( + name="test_failure_signal", + value="assertion-style failure", + impact="logic", + weight=0.7, + evidence="Failure text points to behavioral mismatch instead of parser issues.", + ) + ) + heuristic_scores["logic"] += 0.55 + + if any(token in trace for token in ("timeout", "benchmark", "slow", "latency", "performance", "profiler")): + signals.append( + TriageSignal( + name="performance_trace", + value="latency regression", + impact="performance", + weight=0.85, + evidence="Traceback mentions benchmark or latency pressure.", + ) + ) + heuristic_scores["performance"] += 0.7 + + loop_depth = _loop_depth(code) + if loop_depth >= 2: + signals.append( + TriageSignal( + name="loop_depth", + value=str(loop_depth), + impact="performance", + weight=0.65, + evidence="Nested iteration increases runtime risk on larger fixtures.", + ) + ) + heuristic_scores["performance"] += 0.35 + + if "Counter(" in code or "defaultdict(" in code or "set(" in code: + heuristic_scores["performance"] += 0.05 + + if "return sessions" in code and "sessions.append" not in code: + signals.append( + TriageSignal( + name="state_update_gap", + value="possible missing final append", + impact="logic", + weight=0.45, + evidence="A collection is returned without an obvious final state flush.", + ) + ) + heuristic_scores["logic"] += 0.18 + + return signals, heuristic_scores, notes + + def _nearest_match(self, embedding: torch.Tensor) -> tuple[TriagePrototype, float, dict[str, float]]: + similarities = torch.matmul(embedding, self._prototype_embeddings().T)[0] + indexed_scores = { + self.prototypes[index].task_id: round(float((similarities[index] + 1.0) / 2.0), 4) + for index in range(len(self.prototypes)) + } + best_index = int(torch.argmax(similarities).item()) + best_prototype = self.prototypes[best_index] + best_similarity = float((similarities[best_index] + 1.0) / 2.0) + return best_prototype, best_similarity, indexed_scores + + def _repair_plan(self, label: IssueLabel, matched: TriagePrototype, context_window: str) -> list[str]: + context = _sanitize_text(context_window) + step_one = { + "syntax": "Step 1 - Syntax checking and bug fixes: resolve the parser break before touching behavior, then align the function with the expected contract.", + "logic": "Step 1 - Syntax checking and bug fixes: confirm the code parses cleanly, then patch the failing branch or state update causing the incorrect result.", + "performance": "Step 1 - Syntax checking and bug fixes: keep the implementation correct first, then isolate the slow section without changing external behavior.", + }[label] + step_two = ( + "Step 2 - Edge case handling: verify empty input, boundary values, missing fields, and final-state flush behavior " + f"against the known pattern `{matched.title}`." + ) + step_three = ( + "Step 3 - Scalability of code: remove repeated full scans, prefer linear-time data structures, " + "and benchmark the path on a production-like fixture." + ) + if context: + step_two = f"{step_two} Context window to preserve: {context}" + return [step_one, step_two, step_three] + + def _reference_quality_score(self, code: str, matched: TriagePrototype) -> float: + candidate = self.backend.embed_texts([_sanitize_text(code) or "# empty"]) + match_index = next(index for index, prototype in enumerate(self.prototypes) if prototype.task_id == matched.task_id) + reference = self._reference_code_embeddings()[match_index : match_index + 1] + score = float(torch.matmul(candidate, reference.T)[0][0].item()) + return _clamp_unit((score + 1.0) / 2.0) + + def triage(self, code: str, traceback_text: str = "", context_window: str = "") -> TriageResult: + """Run the full triage pipeline on code plus optional failure context.""" + + started = time.perf_counter() + document = self._build_review_document(code, traceback_text, context_window) + signals, heuristic_scores, notes = self._extract_signals(code, traceback_text) + + candidate_embedding = self.backend.embed_texts([document]) + matched, matched_similarity, prototype_scores = self._nearest_match(candidate_embedding) + + label_similarity = {label: 0.18 for label in LABELS} + for prototype in self.prototypes: + label_similarity[prototype.label] = max( + label_similarity[prototype.label], + prototype_scores[prototype.task_id], + ) + + combined_scores = { + label: 0.72 * label_similarity[label] + 0.28 * heuristic_scores[label] + for label in LABELS + } + confidence_scores = _safe_softmax(combined_scores) + issue_label = max(LABELS, key=lambda label: confidence_scores[label]) + top_confidence = confidence_scores[issue_label] + + top_signal = signals[0].evidence if signals else "Model similarity dominated the decision." + ml_quality_score = self._reference_quality_score(code, matched) + lint_score = _lint_score(code) + complexity_penalty = _complexity_penalty(code) + reward_score = _clamp_unit((0.5 * ml_quality_score) + (0.3 * lint_score) - (0.2 * complexity_penalty)) + summary = ( + f"Detected a {issue_label} issue with {top_confidence:.0%} confidence. " + f"The closest known failure pattern is `{matched.title}`, which indicates {matched.summary.lower()}. " + f"Predicted quality score is {ml_quality_score:.0%} with an RL-ready reward of {reward_score:.0%}." + ) + suggested_next_action = { + "syntax": "Fix the parser error first, then rerun validation before changing behavior.", + "logic": "Step through the smallest failing case and confirm the final branch/update behavior.", + "performance": "Replace repeated full-list scans with a linear-time aggregation strategy, then benchmark it.", + }[issue_label] + + return TriageResult( + issue_label=issue_label, + confidence_scores=confidence_scores, + repair_risk=_repair_risk(issue_label, top_confidence, len(signals)), + ml_quality_score=ml_quality_score, + lint_score=lint_score, + complexity_penalty=complexity_penalty, + reward_score=reward_score, + summary=summary, + matched_pattern=PrototypeMatch( + task_id=matched.task_id, + title=matched.title, + label=matched.label, + similarity=round(matched_similarity, 4), + summary=matched.summary, + rationale=top_signal, + ), + repair_plan=self._repair_plan(issue_label, matched, context_window), + suggested_next_action=suggested_next_action, + extracted_signals=signals, + model_backend=self.backend.backend_name, + model_id=self.backend.model_id, + inference_notes=list(self.backend.notes) + notes, + analysis_time_ms=round((time.perf_counter() - started) * 1000.0, 2), + ) + + +@lru_cache(maxsize=1) +def get_default_engine() -> CodeTriageEngine: + """Return a cached triage engine for the running process.""" + + return CodeTriageEngine() diff --git a/triage_catalog.py b/triage_catalog.py index e62fea7e39d082f996f0578de965d2f90a390c3d..07d44522c002a6f71da1ed120a66c18fdd75a7b5 100644 --- a/triage_catalog.py +++ b/triage_catalog.py @@ -1,134 +1,134 @@ -"""Curated prototypes and example inputs for TorchReview Copilot.""" - -from __future__ import annotations - -from typing import Dict, List - -try: - from .triage_models import IssueLabel, TriageExample, TriagePrototype - from .tasks import list_tasks -except ImportError: - from triage_models import IssueLabel, TriageExample, TriagePrototype - from tasks import list_tasks - - -TASK_KIND_TO_LABEL: Dict[str, IssueLabel] = { - "syntax_fix": "syntax", - "bug_fix": "logic", - "optimization": "performance", -} - -TRACEBACK_BY_TASK_ID: Dict[str, str] = { - "syntax_fix_invoice_totals": ( - "Traceback (most recent call last):\n" - " File \"services/billing/reconciliation.py\", line 3\n" - " for record in records\n" - " ^\n" - "SyntaxError: expected ':'" - ), - "bug_fix_session_windows": ( - "AssertionError: collapse_sessions([{'minute': 1}, {'minute': 3}, {'minute': 8}], 4)\n" - "Expected: [(1, 3), (8, 8)]\n" - "Actual: [(1, 8)]\n" - "Boundary handling merges the final session instead of starting a new one." - ), - "optimization_rank_active_users": ( - "BenchmarkWarning: rank_active_users exceeded the 450ms budget on a nightly export fixture.\n" - "Profiler hint: repeated scans over the full event list and nested loops dominate runtime." - ), -} - -SUMMARY_BY_TASK_ID: Dict[str, str] = { - "syntax_fix_invoice_totals": "Broken parser state in a billing helper blocks reconciliation jobs.", - "bug_fix_session_windows": "Session-boundary logic fails on inclusive idle-timeout edges.", - "optimization_rank_active_users": "A nightly ranking job is correct on small fixtures but too slow at production scale.", -} - -CONTEXT_BY_TASK_ID: Dict[str, str] = { - "syntax_fix_invoice_totals": ( - "Context window: this helper runs in an end-of-day billing reconciliation job. " - "Keep the public function signature intact and restore correct totals for mixed integer/string inputs." - ), - "bug_fix_session_windows": ( - "Context window: this function groups sorted product analytics events into sessions for retention dashboards. " - "Boundary behavior must stay deterministic because downstream reports depend on it." - ), - "optimization_rank_active_users": ( - "Context window: this pipeline feeds a nightly export on a small CPU instance. " - "Maintain identical output ordering while improving scalability on larger event volumes." - ), -} - - -def _prototype_text( - task_id: str, - title: str, - description: str, - repo_summary: str, - goal: str, - visible_tests: List[str], - starter_code: str, - traceback_text: str, -) -> str: - visible = "\n".join(f"- {item}" for item in visible_tests) or "- none" - return ( - f"Title: {title}\n" - f"Problem: {description}\n" - f"Repo context: {repo_summary}\n" - f"Goal: {goal}\n" - f"Observed failure:\n{traceback_text}\n" - f"Visible checks:\n{visible}\n" - f"Candidate code:\n{starter_code}\n" - f"Task id: {task_id}\n" - ) - - -def build_examples() -> List[TriageExample]: - """Create stable UI examples from the task catalog.""" - - examples: List[TriageExample] = [] - for task in list_tasks(): - label = TASK_KIND_TO_LABEL[task.task_kind] - examples.append( - TriageExample( - key=task.task_id, - title=task.title, - label=label, - summary=SUMMARY_BY_TASK_ID[task.task_id], - code=task.starter_code, - traceback_text=TRACEBACK_BY_TASK_ID[task.task_id], - context_window=CONTEXT_BY_TASK_ID[task.task_id], - task_id=task.task_id, - ) - ) - return examples - - -def build_prototypes() -> List[TriagePrototype]: - """Build canonical triage prototypes from the OpenEnv tasks.""" - - prototypes: List[TriagePrototype] = [] - for task in list_tasks(): - traceback_text = TRACEBACK_BY_TASK_ID[task.task_id] - prototypes.append( - TriagePrototype( - task_id=task.task_id, - title=task.title, - label=TASK_KIND_TO_LABEL[task.task_kind], - summary=SUMMARY_BY_TASK_ID[task.task_id], - reference_text=_prototype_text( - task.task_id, - task.title, - task.task_description, - task.repo_summary, - task.goal, - list(task.visible_tests), - task.reference_code, - traceback_text, - ), - starter_code=task.starter_code, - reference_code=task.reference_code, - traceback_text=traceback_text, - ) - ) - return prototypes +"""Curated prototypes and example inputs for TorchReview Copilot.""" + +from __future__ import annotations + +from typing import Dict, List + +try: + from .triage_models import IssueLabel, TriageExample, TriagePrototype + from .tasks import list_tasks +except ImportError: + from triage_models import IssueLabel, TriageExample, TriagePrototype + from tasks import list_tasks + + +TASK_KIND_TO_LABEL: Dict[str, IssueLabel] = { + "syntax_fix": "syntax", + "bug_fix": "logic", + "optimization": "performance", +} + +TRACEBACK_BY_TASK_ID: Dict[str, str] = { + "syntax_fix_invoice_totals": ( + "Traceback (most recent call last):\n" + " File \"services/billing/reconciliation.py\", line 3\n" + " for record in records\n" + " ^\n" + "SyntaxError: expected ':'" + ), + "bug_fix_session_windows": ( + "AssertionError: collapse_sessions([{'minute': 1}, {'minute': 3}, {'minute': 8}], 4)\n" + "Expected: [(1, 3), (8, 8)]\n" + "Actual: [(1, 8)]\n" + "Boundary handling merges the final session instead of starting a new one." + ), + "optimization_rank_active_users": ( + "BenchmarkWarning: rank_active_users exceeded the 450ms budget on a nightly export fixture.\n" + "Profiler hint: repeated scans over the full event list and nested loops dominate runtime." + ), +} + +SUMMARY_BY_TASK_ID: Dict[str, str] = { + "syntax_fix_invoice_totals": "Broken parser state in a billing helper blocks reconciliation jobs.", + "bug_fix_session_windows": "Session-boundary logic fails on inclusive idle-timeout edges.", + "optimization_rank_active_users": "A nightly ranking job is correct on small fixtures but too slow at production scale.", +} + +CONTEXT_BY_TASK_ID: Dict[str, str] = { + "syntax_fix_invoice_totals": ( + "Context window: this helper runs in an end-of-day billing reconciliation job. " + "Keep the public function signature intact and restore correct totals for mixed integer/string inputs." + ), + "bug_fix_session_windows": ( + "Context window: this function groups sorted product analytics events into sessions for retention dashboards. " + "Boundary behavior must stay deterministic because downstream reports depend on it." + ), + "optimization_rank_active_users": ( + "Context window: this pipeline feeds a nightly export on a small CPU instance. " + "Maintain identical output ordering while improving scalability on larger event volumes." + ), +} + + +def _prototype_text( + task_id: str, + title: str, + description: str, + repo_summary: str, + goal: str, + visible_tests: List[str], + starter_code: str, + traceback_text: str, +) -> str: + visible = "\n".join(f"- {item}" for item in visible_tests) or "- none" + return ( + f"Title: {title}\n" + f"Problem: {description}\n" + f"Repo context: {repo_summary}\n" + f"Goal: {goal}\n" + f"Observed failure:\n{traceback_text}\n" + f"Visible checks:\n{visible}\n" + f"Candidate code:\n{starter_code}\n" + f"Task id: {task_id}\n" + ) + + +def build_examples() -> List[TriageExample]: + """Create stable UI examples from the task catalog.""" + + examples: List[TriageExample] = [] + for task in list_tasks(): + label = TASK_KIND_TO_LABEL[task.task_kind] + examples.append( + TriageExample( + key=task.task_id, + title=task.title, + label=label, + summary=SUMMARY_BY_TASK_ID[task.task_id], + code=task.starter_code, + traceback_text=TRACEBACK_BY_TASK_ID[task.task_id], + context_window=CONTEXT_BY_TASK_ID[task.task_id], + task_id=task.task_id, + ) + ) + return examples + + +def build_prototypes() -> List[TriagePrototype]: + """Build canonical triage prototypes from the OpenEnv tasks.""" + + prototypes: List[TriagePrototype] = [] + for task in list_tasks(): + traceback_text = TRACEBACK_BY_TASK_ID[task.task_id] + prototypes.append( + TriagePrototype( + task_id=task.task_id, + title=task.title, + label=TASK_KIND_TO_LABEL[task.task_kind], + summary=SUMMARY_BY_TASK_ID[task.task_id], + reference_text=_prototype_text( + task.task_id, + task.title, + task.task_description, + task.repo_summary, + task.goal, + list(task.visible_tests), + task.reference_code, + traceback_text, + ), + starter_code=task.starter_code, + reference_code=task.reference_code, + traceback_text=traceback_text, + ) + ) + return prototypes diff --git a/triage_models.py b/triage_models.py index 8ecc3a345adbe22292294654a77eea9f87796667..3b8e905806867dd8945a968ec24d841bd4e72db0 100644 --- a/triage_models.py +++ b/triage_models.py @@ -1,79 +1,79 @@ -"""Typed models for TorchReview Copilot outputs and examples.""" - -from __future__ import annotations - -from typing import Dict, List, Literal - -from pydantic import BaseModel, Field - - -IssueLabel = Literal["syntax", "logic", "performance"] -RiskLevel = Literal["low", "medium", "high"] - - -class TriageSignal(BaseModel): - """One extracted signal used during issue classification.""" - - name: str - value: str - impact: Literal["syntax", "logic", "performance", "mixed"] = "mixed" - weight: float = Field(..., ge=0.0, le=1.0) - evidence: str = "" - - -class PrototypeMatch(BaseModel): - """Nearest known bug pattern from the built-in task catalog.""" - - task_id: str - title: str - label: IssueLabel - similarity: float = Field(..., ge=0.0, le=1.0) - summary: str - rationale: str - - -class TriageExample(BaseModel): - """Example payload exposed in the demo UI.""" - - key: str - title: str - label: IssueLabel - summary: str - code: str - traceback_text: str - context_window: str - task_id: str - - -class TriagePrototype(BaseModel): - """Canonical issue-pattern representation embedded by the triage engine.""" - - task_id: str - title: str - label: IssueLabel - summary: str - reference_text: str - starter_code: str - reference_code: str - traceback_text: str - - -class TriageResult(BaseModel): - """Structured output produced by the triage pipeline.""" - - issue_label: IssueLabel - confidence_scores: Dict[str, float] - repair_risk: RiskLevel - ml_quality_score: float = Field(..., ge=0.0, le=1.0) - lint_score: float = Field(..., ge=0.0, le=1.0) - complexity_penalty: float = Field(..., ge=0.0, le=1.0) - reward_score: float = Field(..., ge=0.0, le=1.0) - summary: str - matched_pattern: PrototypeMatch - repair_plan: List[str] - suggested_next_action: str - extracted_signals: List[TriageSignal] = Field(default_factory=list) - model_backend: str - model_id: str - inference_notes: List[str] = Field(default_factory=list) - analysis_time_ms: float = Field(..., ge=0.0) +"""Typed models for TorchReview Copilot outputs and examples.""" + +from __future__ import annotations + +from typing import Dict, List, Literal + +from pydantic import BaseModel, Field + + +IssueLabel = Literal["syntax", "logic", "performance"] +RiskLevel = Literal["low", "medium", "high"] + + +class TriageSignal(BaseModel): + """One extracted signal used during issue classification.""" + + name: str + value: str + impact: Literal["syntax", "logic", "performance", "mixed"] = "mixed" + weight: float = Field(..., ge=0.0, le=1.0) + evidence: str = "" + + +class PrototypeMatch(BaseModel): + """Nearest known bug pattern from the built-in task catalog.""" + + task_id: str + title: str + label: IssueLabel + similarity: float = Field(..., ge=0.0, le=1.0) + summary: str + rationale: str + + +class TriageExample(BaseModel): + """Example payload exposed in the demo UI.""" + + key: str + title: str + label: IssueLabel + summary: str + code: str + traceback_text: str + context_window: str + task_id: str + + +class TriagePrototype(BaseModel): + """Canonical issue-pattern representation embedded by the triage engine.""" + + task_id: str + title: str + label: IssueLabel + summary: str + reference_text: str + starter_code: str + reference_code: str + traceback_text: str + + +class TriageResult(BaseModel): + """Structured output produced by the triage pipeline.""" + + issue_label: IssueLabel + confidence_scores: Dict[str, float] + repair_risk: RiskLevel + ml_quality_score: float = Field(..., ge=0.0, le=1.0) + lint_score: float = Field(..., ge=0.0, le=1.0) + complexity_penalty: float = Field(..., ge=0.0, le=1.0) + reward_score: float = Field(..., ge=0.0, le=1.0) + summary: str + matched_pattern: PrototypeMatch + repair_plan: List[str] + suggested_next_action: str + extracted_signals: List[TriageSignal] = Field(default_factory=list) + model_backend: str + model_id: str + inference_notes: List[str] = Field(default_factory=list) + analysis_time_ms: float = Field(..., ge=0.0) diff --git a/utils/__init__.py b/utils/__init__.py index 4bc736a197907087eadf9bfaf47d737ca460d64b..0121832ece22b25bcafa896d94b50ae4587d1b99 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1,6 +1,6 @@ -"""Utility helpers for AST parsing and complexity scoring.""" - -from .ast_parser import parse_code_structure -from .complexity import estimate_complexity - -__all__ = ["parse_code_structure", "estimate_complexity"] +"""Utility helpers for AST parsing and complexity scoring.""" + +from .ast_parser import parse_code_structure +from .complexity import estimate_complexity + +__all__ = ["parse_code_structure", "estimate_complexity"] diff --git a/utils/ast_parser.py b/utils/ast_parser.py index 81dba793be86bf4a327fc91709432ee070a69552..d0eb1e80bf7adcde7e8017f6bffefecc5aa7882c 100644 --- a/utils/ast_parser.py +++ b/utils/ast_parser.py @@ -1,248 +1,144 @@ -"""AST-based parsing helpers for Python code review.""" +"""Static parsing helpers for multi-domain Python code analysis.""" from __future__ import annotations import ast -from dataclasses import dataclass, field -from typing import Any - - -@dataclass(slots=True) -class _StructureVisitor(ast.NodeVisitor): - """Collect lightweight structural signals from Python source.""" - - imports: set[str] = field(default_factory=set) - route_decorators: set[str] = field(default_factory=set) - function_names: list[str] = field(default_factory=list) - class_names: list[str] = field(default_factory=list) - code_smells: list[str] = field(default_factory=list) - branch_count: int = 0 - max_loop_depth: int = 0 - max_nesting_depth: int = 0 - current_loop_depth: int = 0 - current_nesting_depth: int = 0 - recursive_functions: set[str] = field(default_factory=set) - current_function: str | None = None - docstring_total: int = 0 - docstring_with_docs: int = 0 - backward_calls: int = 0 - optimizer_step_calls: int = 0 - container_builds: int = 0 - - def visit_Import(self, node: ast.Import) -> None: # noqa: N802 - for alias in node.names: - self.imports.add(alias.name.split(".")[0]) - self.generic_visit(node) +from typing import Any, Dict, List - def visit_ImportFrom(self, node: ast.ImportFrom) -> None: # noqa: N802 - if node.module: - self.imports.add(node.module.split(".")[0]) - self.generic_visit(node) - def _push_nesting(self) -> None: - self.current_nesting_depth += 1 - self.max_nesting_depth = max(self.max_nesting_depth, self.current_nesting_depth) +class _LoopDepthVisitor(ast.NodeVisitor): + """Collect loop nesting depth for a parsed Python module.""" - def _pop_nesting(self) -> None: - self.current_nesting_depth = max(0, self.current_nesting_depth - 1) + def __init__(self) -> None: + self.depth = 0 + self.max_depth = 0 def _visit_loop(self, node: ast.AST) -> None: - self.branch_count += 1 - self.current_loop_depth += 1 - self.max_loop_depth = max(self.max_loop_depth, self.current_loop_depth) - self._push_nesting() + self.depth += 1 + self.max_depth = max(self.max_depth, self.depth) self.generic_visit(node) - self._pop_nesting() - self.current_loop_depth = max(0, self.current_loop_depth - 1) + self.depth -= 1 def visit_For(self, node: ast.For) -> None: # noqa: N802 self._visit_loop(node) - def visit_AsyncFor(self, node: ast.AsyncFor) -> None: # noqa: N802 - self._visit_loop(node) - def visit_While(self, node: ast.While) -> None: # noqa: N802 self._visit_loop(node) - def visit_If(self, node: ast.If) -> None: # noqa: N802 - self.branch_count += 1 - self._push_nesting() - self.generic_visit(node) - self._pop_nesting() - - def visit_Try(self, node: ast.Try) -> None: # noqa: N802 - self.branch_count += 1 - self._push_nesting() - self.generic_visit(node) - self._pop_nesting() - - def visit_With(self, node: ast.With) -> None: # noqa: N802 - self._push_nesting() - self.generic_visit(node) - self._pop_nesting() - - def visit_AsyncWith(self, node: ast.AsyncWith) -> None: # noqa: N802 - self._push_nesting() - self.generic_visit(node) - self._pop_nesting() - def visit_comprehension(self, node: ast.comprehension) -> None: # noqa: N802 self._visit_loop(node) - def visit_FunctionDef(self, node: ast.FunctionDef) -> None: # noqa: N802 - self.function_names.append(node.name) - self.docstring_total += 1 - if ast.get_docstring(node): - self.docstring_with_docs += 1 - prior = self.current_function - self.current_function = node.name - for decorator in node.decorator_list: - decorator_name = self._decorator_name(decorator) - if decorator_name in {"get", "post", "put", "patch", "delete"}: - self.route_decorators.add(decorator_name) - self.generic_visit(node) - self.current_function = prior - - def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: # noqa: N802 - self.visit_FunctionDef(node) - - def visit_ClassDef(self, node: ast.ClassDef) -> None: # noqa: N802 - self.class_names.append(node.name) - self.generic_visit(node) - def visit_Call(self, node: ast.Call) -> None: # noqa: N802 - dotted_name = self._call_name(node.func) - if dotted_name.endswith(".backward") or dotted_name == "backward": - self.backward_calls += 1 - if dotted_name.endswith(".step") or dotted_name == "step": - if "optimizer" in dotted_name: - self.optimizer_step_calls += 1 - if dotted_name in {"list", "dict", "set", "tuple"}: - self.container_builds += 1 - if self.current_function and dotted_name == self.current_function: - self.recursive_functions.add(self.current_function) - self.generic_visit(node) +def parse_code_structure(code: str) -> Dict[str, Any]: + """Parse Python code into reusable structural signals.""" - @staticmethod - def _call_name(node: ast.AST) -> str: - if isinstance(node, ast.Name): - return node.id - if isinstance(node, ast.Attribute): - left = _StructureVisitor._call_name(node.value) - return f"{left}.{node.attr}" if left else node.attr - return "" - - @staticmethod - def _decorator_name(node: ast.AST) -> str: - if isinstance(node, ast.Call): - return _StructureVisitor._decorator_name(node.func) - if isinstance(node, ast.Attribute): - return node.attr.lower() - if isinstance(node, ast.Name): - return node.id.lower() - return "" - - -def _line_smells(lines: list[str]) -> tuple[int, list[int], bool]: - long_lines = sum(1 for line in lines if len(line) > 88) - trailing_whitespace_lines = [index + 1 for index, line in enumerate(lines) if line.rstrip() != line] - tabs_used = any("\t" in line for line in lines) - return long_lines, trailing_whitespace_lines, tabs_used - - -def parse_code_structure(code: str) -> dict[str, Any]: - """Extract deterministic syntax, import, and structure signals from Python code.""" - - normalized_code = code or "" - lines = normalized_code.splitlines() - long_lines, trailing_whitespace_lines, tabs_used = _line_smells(lines) - - result: dict[str, Any] = { + summary: Dict[str, Any] = { "syntax_valid": True, "syntax_error": "", - "line_count": len(lines), "imports": [], "function_names": [], "class_names": [], - "long_lines": long_lines, - "trailing_whitespace_lines": trailing_whitespace_lines, - "tabs_used": tabs_used, - "docstring_ratio": 0.0, - "uses_recursion": False, + "loop_count": 0, + "branch_count": 0, "max_loop_depth": 0, - "max_nesting_depth": 0, - "route_decorators": [], - "code_smells": [], - "uses_pandas": False, + "line_count": len(code.splitlines()), + "long_lines": 0, + "tabs_used": "\t" in code, + "trailing_whitespace_lines": 0, "uses_numpy": False, + "uses_pandas": False, "uses_torch": False, "uses_sklearn": False, "uses_fastapi": False, "uses_flask": False, "uses_pydantic": False, + "uses_recursion": False, + "calls_eval": False, + "calls_no_grad": False, "calls_backward": False, "calls_optimizer_step": False, - "branch_count": 0, - "container_builds": 0, + "route_decorators": [], + "docstring_ratio": 0.0, + "code_smells": [], } + lines = code.splitlines() + summary["long_lines"] = sum(1 for line in lines if len(line) > 88) + summary["trailing_whitespace_lines"] = sum(1 for line in lines if line.rstrip() != line) + try: - tree = ast.parse(normalized_code or "\n") + tree = ast.parse(code) except SyntaxError as exc: - result["syntax_valid"] = False - result["syntax_error"] = f"{exc.msg} (line {exc.lineno}, column {exc.offset})" - result["code_smells"] = ["Code does not parse.", "Fix syntax before deeper review."] - return result + summary["syntax_valid"] = False + summary["syntax_error"] = f"{exc.msg} (line {exc.lineno})" + summary["code_smells"].append("Code does not parse.") + return summary - visitor = _StructureVisitor() + visitor = _LoopDepthVisitor() visitor.visit(tree) - - imports = sorted(visitor.imports) - uses_pandas = "pandas" in imports or "pd" in normalized_code - uses_numpy = "numpy" in imports or "np." in normalized_code - uses_torch = "torch" in imports or "torch." in normalized_code - uses_sklearn = "sklearn" in imports - uses_fastapi = "fastapi" in imports - uses_flask = "flask" in imports - uses_pydantic = "pydantic" in imports or "BaseModel" in normalized_code - - code_smells = list(visitor.code_smells) - if visitor.max_loop_depth >= 2: - code_smells.append("Nested loops may create avoidable performance pressure.") - if long_lines: - code_smells.append("Long lines reduce readability and reviewability.") - if trailing_whitespace_lines: - code_smells.append("Trailing whitespace suggests style drift.") - if visitor.docstring_total and visitor.docstring_with_docs == 0: - code_smells.append("Public functions are missing docstrings.") - if not visitor.function_names: - code_smells.append("Encapsulate behavior in functions for testability.") - - result.update( - { - "imports": imports, - "function_names": visitor.function_names, - "class_names": visitor.class_names, - "docstring_ratio": round( - visitor.docstring_with_docs / max(visitor.docstring_total, 1), - 4, - ), - "uses_recursion": bool(visitor.recursive_functions), - "max_loop_depth": visitor.max_loop_depth, - "max_nesting_depth": visitor.max_nesting_depth, - "route_decorators": sorted(visitor.route_decorators), - "code_smells": code_smells, - "uses_pandas": uses_pandas, - "uses_numpy": uses_numpy, - "uses_torch": uses_torch, - "uses_sklearn": uses_sklearn, - "uses_fastapi": uses_fastapi, - "uses_flask": uses_flask, - "uses_pydantic": uses_pydantic, - "calls_backward": visitor.backward_calls > 0, - "calls_optimizer_step": visitor.optimizer_step_calls > 0, - "branch_count": visitor.branch_count, - "container_builds": visitor.container_builds, - } + summary["max_loop_depth"] = visitor.max_depth + + functions = [node for node in tree.body if isinstance(node, ast.FunctionDef)] + summary["function_names"] = [node.name for node in functions] + summary["class_names"] = [node.name for node in tree.body if isinstance(node, ast.ClassDef)] + summary["docstring_ratio"] = ( + sum(1 for node in functions if ast.get_docstring(node)) / len(functions) + if functions + else 0.0 ) - return result + + imports: List[str] = [] + for node in ast.walk(tree): + if isinstance(node, ast.Import): + imports.extend(alias.name.split(".")[0] for alias in node.names) + elif isinstance(node, ast.ImportFrom) and node.module: + imports.append(node.module.split(".")[0]) + elif isinstance(node, (ast.For, ast.While, ast.comprehension)): + summary["loop_count"] += 1 + elif isinstance(node, (ast.If, ast.Try, ast.Match)): + summary["branch_count"] += 1 + elif isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute): + attr = node.func.attr + if attr == "eval": + summary["calls_eval"] = True + elif attr == "backward": + summary["calls_backward"] = True + elif attr == "step": + summary["calls_optimizer_step"] = True + elif isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == "print": + summary["code_smells"].append("Debug print statements are present.") + elif isinstance(node, ast.With): + if any(isinstance(item.context_expr, ast.Call) and isinstance(item.context_expr.func, ast.Attribute) and item.context_expr.func.attr == "no_grad" for item in node.items): + summary["calls_no_grad"] = True + + import_set = sorted(set(imports)) + summary["imports"] = import_set + summary["uses_numpy"] = "numpy" in import_set or "np" in code + summary["uses_pandas"] = "pandas" in import_set or "pd" in code + summary["uses_torch"] = "torch" in import_set + summary["uses_sklearn"] = "sklearn" in import_set + summary["uses_fastapi"] = "fastapi" in import_set + summary["uses_flask"] = "flask" in import_set + summary["uses_pydantic"] = "pydantic" in import_set or "BaseModel" in code + + for node in functions: + for child in ast.walk(node): + if isinstance(child, ast.Call) and isinstance(child.func, ast.Name) and child.func.id == node.name: + summary["uses_recursion"] = True + + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef): + for decorator in node.decorator_list: + if isinstance(decorator, ast.Call) and isinstance(decorator.func, ast.Attribute): + summary["route_decorators"].append(decorator.func.attr) + elif isinstance(decorator, ast.Attribute): + summary["route_decorators"].append(decorator.attr) + + if summary["long_lines"]: + summary["code_smells"].append("Long lines reduce readability.") + if summary["tabs_used"]: + summary["code_smells"].append("Tabs detected; prefer spaces for consistency.") + if summary["trailing_whitespace_lines"]: + summary["code_smells"].append("Trailing whitespace found.") + + return summary diff --git a/utils/complexity.py b/utils/complexity.py index 5010732e8733a764e2e0ad93c431bbcc2102b0fa..02890c2bf4f9cc791c8d6b49321a1de963eb60e5 100644 --- a/utils/complexity.py +++ b/utils/complexity.py @@ -1,70 +1,37 @@ -"""Complexity heuristics for Python code review.""" +"""Complexity heuristics for DSA-style and general Python code.""" from __future__ import annotations -import ast -from typing import Any +from typing import Any, Dict -def _clamp_unit(value: float) -> float: - return max(0.0, min(1.0, float(value))) +def estimate_complexity(parsed: Dict[str, Any], code: str) -> Dict[str, Any]: + """Estimate cyclomatic complexity and rough Big-O heuristics.""" - -def _estimate_time_complexity(loop_depth: int, uses_recursion: bool) -> str: - if uses_recursion and loop_depth >= 1: - return "O(n^2)" - if loop_depth >= 3: - return "O(n^3)" - if loop_depth == 2: - return "O(n^2)" - if loop_depth == 1: - return "O(n)" - if uses_recursion: - return "O(n)" - return "O(1)" - - -def _estimate_space_complexity(code: str, uses_recursion: bool) -> str: - if uses_recursion: - return "O(n)" - if any(token in code for token in ("[]", "{}", "set(", "dict(", "list(", "Counter(")): - return "O(n)" - return "O(1)" - - -def _cyclomatic_complexity(code: str) -> int: - try: - tree = ast.parse(code or "\n") - except SyntaxError: - return 1 - decision_points = sum( - isinstance(node, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.ExceptHandler, ast.Match, ast.BoolOp)) - for node in ast.walk(tree) - ) - return max(1, 1 + decision_points) - - -def estimate_complexity(parsed: dict[str, Any], code: str) -> dict[str, Any]: - """Estimate Python complexity signals from parsed structure plus source text.""" - - cyclomatic_complexity = _cyclomatic_complexity(code) - loop_depth = int(parsed.get("max_loop_depth", 0) or 0) - max_nesting_depth = int(parsed.get("max_nesting_depth", 0) or 0) + cyclomatic = 1 + int(parsed.get("branch_count", 0)) + loop_depth = int(parsed.get("max_loop_depth", 0)) uses_recursion = bool(parsed.get("uses_recursion", False)) - line_count = int(parsed.get("line_count", 0) or 0) - - complexity_penalty = _clamp_unit( - 0.08 - + min(cyclomatic_complexity, 12) * 0.045 - + min(loop_depth, 4) * 0.11 - + min(max_nesting_depth, 4) * 0.06 - + (0.06 if uses_recursion else 0.0) - + min(line_count, 200) * 0.0009 - ) + if loop_depth >= 3: + time_complexity = "O(n^3)" + elif loop_depth == 2: + time_complexity = "O(n^2)" + elif "sorted(" in code or ".sort(" in code: + time_complexity = "O(n log n)" + elif loop_depth == 1 or uses_recursion: + time_complexity = "O(n)" + else: + time_complexity = "O(1)" + + if "append(" in code or "list(" in code or "dict(" in code or "set(" in code: + space_complexity = "O(n)" + else: + space_complexity = "O(1)" + + complexity_penalty = min(0.99, 0.08 + (cyclomatic * 0.04) + (loop_depth * 0.12)) return { - "cyclomatic_complexity": cyclomatic_complexity, - "time_complexity": _estimate_time_complexity(loop_depth, uses_recursion), - "space_complexity": _estimate_space_complexity(code, uses_recursion), + "cyclomatic_complexity": cyclomatic, + "time_complexity": time_complexity, + "space_complexity": space_complexity, "complexity_penalty": round(complexity_penalty, 4), } diff --git a/uv.lock b/uv.lock index 584df0eeec58c8faaf2e53f6ff23cc8cd1d8339e..496be518ae45526396abb39241b45932977692cd 100644 --- a/uv.lock +++ b/uv.lock @@ -1926,6 +1926,7 @@ source = { editable = "." } dependencies = [ { name = "fastapi" }, { name = "gradio" }, + { name = "hf-xet" }, { name = "openai" }, { name = "openenv-core", extra = ["core"] }, { name = "streamlit" }, @@ -1944,6 +1945,7 @@ dev = [ requires-dist = [ { name = "fastapi", specifier = ">=0.111.0" }, { name = "gradio", specifier = ">=5.26.0" }, + { name = "hf-xet", specifier = ">=1.4.3" }, { name = "openai", specifier = ">=1.76.0" }, { name = "openenv-core", extras = ["core"], specifier = ">=0.2.2" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" },