# pull-request-validator / main_ai_version.py
# Sgridda — modified (commit 733f0e1)
import json
import re
import time

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
# ----------------------------
# 1. Configuration
# ----------------------------
# Hugging Face Hub id of the code-generation model served by this API.
MODEL_NAME = "Salesforce/codegen-350M-mono"
# NOTE(review): DEVICE is computed here but load_model() pins the model to
# CPU via device_map="cpu" — confirm whether GPU use is actually intended.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ----------------------------
# 2. FastAPI App Initialization
# ----------------------------
# Module-level ASGI application; endpoints are registered on it below.
app = FastAPI(
title="AI Code Review Service",
description="An API to get AI-powered code reviews for pull request diffs.",
version="1.0.0",
)
# ----------------------------
# 3. AI Model Loading
# ----------------------------
# Lazily-populated singletons; filled in by load_model() at startup and
# read by run_ai_inference() / health_check().
model = None
tokenizer = None
def load_model():
    """Load the tokenizer and model into the module globals (idempotent).

    Called once from the startup hook; repeated calls are no-ops while
    ``model`` is already set.
    """
    global model, tokenizer
    if model is None:
        print(f"Loading model: {MODEL_NAME} on device: {DEVICE}...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        # CodeGen uses a GPT-2-style tokenizer that ships without a pad
        # token; define one so any tokenizer(..., padding=...) call and
        # batch encoding cannot raise. EOS-as-pad is the GPT-2 convention.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float32,
            # NOTE(review): pinned to CPU even when DEVICE == "cuda" —
            # confirm intent before changing (inputs are built on CPU too).
            device_map="cpu",
        )
        # Inference-only service: disable dropout/batch-norm training paths.
        model.eval()
        print("Model loaded successfully.")
@app.on_event("startup")
async def startup_event():
    """Warm the model cache once as soon as the server boots."""
    # Loading here (rather than at import time) keeps module import cheap.
    print("Server starting up...")
    load_model()
# ----------------------------
# 4. API Request/Response Models
# ----------------------------
class ReviewRequest(BaseModel):
    """Request body for POST /review."""
    # Raw unified diff text of the pull request to review.
    diff: str
class ReviewComment(BaseModel):
    """A single review comment anchored to a file/line."""
    # Path of the file the comment refers to.
    file_path: str
    # 1-based line number within that file.
    line_number: int
    # Natural-language review text.
    comment_text: str
class ReviewResponse(BaseModel):
    """Response body for POST /review: all generated comments."""
    comments: list[ReviewComment]
# ----------------------------
# 5. The AI Review Logic
# ----------------------------
def run_ai_inference(diff: str) -> str:
    """Generate a one-line natural-language review comment for ``diff``.

    Args:
        diff: Raw diff text; only the first 800 characters are used.

    Returns:
        A single review sentence, or a canned fallback suggestion when the
        model produced only code-like lines.

    Raises:
        RuntimeError: if the model/tokenizer have not been loaded yet.
    """
    if not model or not tokenizer:
        raise RuntimeError("Model is not loaded.")
    # Improved prompt for codegen-350M-mono
    prompt = (
        "Below is a Python function. Please provide a code review comment with suggestions for improvement, in natural language. "
        "Do not repeat the code.\n"
        f"{diff[:800]}\n"
        "Review comment:"
    )
    # FIX: dropped padding="max_length". We encode a single sequence, so
    # padding to 1024 tokens only wastes compute — and the CodeGen
    # tokenizer has no pad token by default, which makes padding raise.
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        max_length=1024,
        truncation=True,
    )
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]
    # EOS doubles as the pad id (GPT-2-family convention); fall back to the
    # tokenizer's pad id only if EOS is somehow undefined.
    stop_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.pad_token_id
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=128,
            do_sample=True,  # sampled decoding: output is non-deterministic by design
            temperature=0.7,
            top_p=0.95,
            num_return_sequences=1,
            pad_token_id=stop_token_id,
            eos_token_id=stop_token_id,
            use_cache=True,
        )
    # Decode only the newly generated tokens (skip the prompt prefix).
    response_text = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
    # Post-process: filter out code-like lines and fallback if needed
    review_lines = [line.strip() for line in response_text.strip().split('\n') if line.strip()]
    # Drop lines that look like code rather than prose.
    comment_lines = [
        l for l in review_lines
        if not l.startswith("def ")
        and not l.startswith("class ")
        and not l.endswith(":")
        and not l.startswith("#")
    ]
    review = comment_lines[0] if comment_lines else "Consider adding a docstring and input validation."
    return review
def parse_ai_response(response_text: str) -> list[ReviewComment]:
    """Wrap the model's raw review text in a single-element comment list."""
    # codegen-350M-mono emits free text rather than JSON, so the whole
    # response becomes one comment anchored at line 1 of a placeholder file.
    comment = ReviewComment(
        file_path="code_reviewed.py",
        line_number=1,
        comment_text=response_text.strip(),
    )
    return [comment]
# ----------------------------
# 6. The API Endpoint
# ----------------------------
@app.post("/review", response_model=ReviewResponse)
async def get_code_review(request: ReviewRequest):
    """Run the AI reviewer over a diff and return structured comments.

    Raises:
        HTTPException: 400 when the diff is empty, 500 when inference or
            parsing fails for any reason.
    """
    if not request.diff:
        raise HTTPException(status_code=400, detail="Diff content cannot be empty.")
    # FIX: `import time` hoisted to the module-level import block.
    start_time = time.time()
    print(f"Starting review request at {start_time}")
    try:
        print("Running AI inference...")
        ai_response_text = run_ai_inference(request.diff)
        print(f"AI inference completed in {time.time() - start_time:.2f} seconds")
        print("Parsing AI response...")
        parsed_comments = parse_ai_response(ai_response_text)
        print(f"Total processing time: {time.time() - start_time:.2f} seconds")
        return ReviewResponse(comments=parsed_comments)
    except Exception as e:
        # Boundary handler: log and map any failure to a generic 500 so
        # internal details never leak to the client; chain the cause for
        # server-side tracebacks.
        print(f"An unexpected error occurred after {time.time() - start_time:.2f} seconds: {e}")
        raise HTTPException(status_code=500, detail="An internal error occurred while processing the review.") from e
# ----------------------------
# 7. Health Check Endpoint
# ----------------------------
@app.get("/health")
async def health_check():
    """Liveness probe: reports whether the model has finished loading."""
    loaded = model is not None
    return {"status": "ok", "model_loaded": loaded}